Buckle in and focus, because this is a heavy example that shows the power of PyFlink. First thing, however: you may have noticed that Flink's Async I/O API is not available for PyFlink. Fortunately, I'll be showing you a way around this. In this code we will see:
- Using state to determine logic
- Async I/O using ThreadPoolExecutor
- Making requests to Flask
- Pushing results immediately to Redis
Here are the sample records we'll be streaming in from Kafka:
{'color': 'red', 'value': 3}
{'color': 'red', 'value': 7}
{'color': 'red', 'value': 7}
{'color': 'red', 'value': 4}
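These records are assumed to already be sitting on the Kafka topic. If you want to follow along locally, one quick way to get them there is with kafka-python; this little producer is not part of the Flink job, and the broker address is just the placeholder reused from the consumer config below.
# Hypothetical helper, separate from the Flink job: push the sample records
# onto the topic with kafka-python (broker/topic are the placeholders used below).
import json
from kafka import KafkaProducer

producer = KafkaProducer(
    bootstrap_servers='server01:1111',
    value_serializer=lambda v: json.dumps(v).encode('utf-8'),
)
for record in [
    {'color': 'red', 'value': 3},
    {'color': 'red', 'value': 7},
    {'color': 'red', 'value': 7},
    {'color': 'red', 'value': 4},
]:
    producer.send('flink-major-incs', record)
producer.flush()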
Okay, this function will get pretty big, so I'll explain it block by block. First, the imports:
from pyflink.common.typeinfo import Types
from pyflink.common import Configuration
from pyflink.datastream import StreamExecutionEnvironment, RuntimeContext, MapFunction
from pyflink.datastream.functions import KeyedProcessFunction
from pyflink.datastream.state import ValueStateDescriptor, ListStateDescriptor, MapStateDescriptor
from pyflink.datastream.connectors import FlinkKafkaConsumer, KafkaSource
from pyflink.datastream.connectors.kafka import KafkaOffsetsInitializer, KafkaOffsetResetStrategy
from pyflink.common.serialization import SimpleStringSchema
import json
import requests
from redis import StrictRedis
from concurrent.futures import ThreadPoolExecutor
Okay, now that we've imported everything, we need to wire up our KeyedProcessFunction. The stream is keyed by a specific field, in our case 'color', and the function keeps memory (state) for each group.
env = StreamExecutionEnvironment.get_execution_environment()
# the Kafka connector JAR needs to be available to the job, e.g. via env.add_jars(...)
properties = {'bootstrap.servers': 'server01:1111', 'group.id': 'test-topic'}
source = FlinkKafkaConsumer("flink-major-incs", SimpleStringSchema(), properties).set_start_from_earliest()
topic_stream = env.add_source(source).map(json.loads)
keyed_ds = topic_stream.key_by(lambda i: i["color"])
result_stream = keyed_ds.process(PowerfulStuff())
result_stream.print()
Okay, now the fun stuff: the KeyedProcessFunction. The __init__ method is where you can define variables that do not depend on Flink's runtime context. The open method, however, is where we define the variables that do depend on Flink's runtime context. We create speed_state with a key type of Types.STRING(); its value is a map (or dictionary) whose keys are strings and whose values are PICKLED_BYTE_ARRAYs, which can store many types of data.
class PowerfulStuff(KeyedProcessFunction):
def __init__(self):
self.speed_state = None
self.include_fields = ['value']
def open(self, runtime_context: RuntimeContext):
# keep the memory of the records here
descriptor = MapStateDescriptor("speed_state", Types.STRING(), Types.MAP(Types.STRING(), Types.PICKLED_BYTE_ARRAY()))
self.speed_state = runtime_context.get_map_state(descriptor)
def process_element(self, record, ctx: 'KeyedProcessFunction.Context'):
current_key = ctx.get_current_key()
previous_record = self.speed_state.get(current_key)
# put current record into current memory
self.speed_state.put(current_key, record)
Since this can be a bit confusing, let’s walk through the data step-by-step as it gets put into memory (or state) from the Kafka topic:
First Record: {'color': 'red', 'value': 3}
- This first goes into process_element.
- We get current_key = ctx.get_current_key(), which is 'red'.
- We then get the previous_record, which is None for now.
- We take the current record and store it in the state.
Next Record: {'color': 'red', 'value': 7}
- We get the current_key, which is 'red'.
- When we check the state, we see the previous record stored there!
- Now, in our function, we have access to both the previous record and the current record.
This is a simple example of state management, but you can go crazy with it.
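If it helps, here's the same walkthrough as a stand-alone Python sketch, with a plain dict standing in for Flink's MapState (illustration only, not part of the job):
# Plain-Python illustration of the state logic above (not Flink code).
speed_state = {}

def see(record):
    key = record['color']
    previous = speed_state.get(key)   # None the first time we see this key
    speed_state[key] = record         # remember the current record for next time
    print(f"key={key!r} previous={previous} current={record}")

see({'color': 'red', 'value': 3})     # previous=None
see({'color': 'red', 'value': 7})     # previous={'color': 'red', 'value': 3}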
Anyways let’s continue on:
def process_element(self, record, ctx: 'KeyedProcessFunction.Context'):
current_key = ctx.get_current_key()
previous_record = self.speed_state.get(current_key)
# put current record into current memory
self.speed_state.put(current_key, record)
with requests.Session() as session, StrictRedis(host='server123', port=1234, decode_responses=True) as redis_client:
if (record.get('value') == 4):
self.clear_redis(record, redis_client)
yield record
elif self.was_record_updated(previous_record, record):
                # Let's do some thread-pool async I/O logic
                with ThreadPoolExecutor(max_workers=2) as executor:
                    future_square = executor.submit(self.the_square, record, session, redis_client)
                    future_is_prime = executor.submit(self.is_prime, record, session, redis_client)
                    record['square'] = future_square.result()
                    record['is_prime'] = future_is_prime.result()
yield record
def was_record_updated(self, previous_record, current_record):
# Return true if brand new record
if previous_record is None:
return True
# see if a field we care about has been updated
for key in current_record:
if key in self.include_fields and current_record[key] != previous_record[key]:
return True
return False
We create the API session and the Redis client, and hit our first if statement. Our current record's value is 3, so we fall through to the elif: elif self.was_record_updated(previous_record, record). This function decides whether or not we want to consider a record updated. In our case, self.include_fields does care about 'value' changing, so it will return True and we will enter our thread-pool async I/O.
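As a quick sanity check, here's how was_record_updated behaves on our sample records (a stand-alone illustration; constructing the class directly like this is only for demonstration):
# Stand-alone illustration of the update check (not part of the running job).
ps = PowerfulStuff()   # include_fields = ['value']
print(ps.was_record_updated(None, {'color': 'red', 'value': 3}))                           # True  - brand new key
print(ps.was_record_updated({'color': 'red', 'value': 3}, {'color': 'red', 'value': 7}))   # True  - 'value' changed
print(ps.was_record_updated({'color': 'red', 'value': 7}, {'color': 'red', 'value': 7}))   # False - nothing we care about changed
The two functions the executor actually submits look like this: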
def the_square(self, record, session, redis_client):
        # API call to our Flask service that returns the square ('url' is a placeholder for that endpoint)
value_squared = session.post(url, json={'value': record['value']}).text
redis_client.set("the_key", value_squared)
return value_squared
def is_prime(self, record, session, redis_client):
        # API call to our Flask service that checks if the number is prime ('url' is again the endpoint placeholder)
is_prime = session.post(url, json={'value': record['value']}).text
redis_client.set("prime_key", is_prime)
return is_prime
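On the other end of those session.post calls sits the Flask service. It isn't part of this post's code, but a minimal sketch of what such endpoints could look like is below; the routes, port, and plain-text responses are all assumptions, not the real service.
# Hypothetical Flask service behind 'url'; routes and response format are assumptions.
from flask import Flask, request

app = Flask(__name__)

@app.route('/square', methods=['POST'])
def square():
    value = int(request.get_json()['value'])
    return str(value * value)

@app.route('/is_prime', methods=['POST'])
def is_prime():
    n = int(request.get_json()['value'])
    if n < 2:
        return 'False'
    return str(all(n % d != 0 for d in range(2, int(n ** 0.5) + 1)))

if __name__ == '__main__':
    app.run(port=5000)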
Thanks to the thread-pool async I/O, both the_square and is_prime run simultaneously and push their results to Redis as soon as they finish. For production jobs, also look into checkpointing and job restarts, and be sure to properly set up your Kafka consumer and group IDs.
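Since checkpointing and restarts keep coming up, here is roughly what that configuration can look like on the environment; the interval and retry values below are purely illustrative, not from the original job.
# Example checkpoint/restart configuration (values are illustrative only).
from pyflink.common import RestartStrategies
from pyflink.datastream import StreamExecutionEnvironment, CheckpointingMode

env = StreamExecutionEnvironment.get_execution_environment()
env.enable_checkpointing(60000)  # checkpoint every 60 seconds
env.get_checkpoint_config().set_checkpointing_mode(CheckpointingMode.EXACTLY_ONCE)
env.set_restart_strategy(RestartStrategies.fixed_delay_restart(3, 10000))  # 3 attempts, 10 s apart
With that said, here's the full code. Obviously some code is missing, like self.clear_redis, but you get the idea: implement your own logic and let Flink carry the workload.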
from pyflink.common.typeinfo import Types
from pyflink.common import Configuration
from pyflink.datastream import StreamExecutionEnvironment, RuntimeContext, MapFunction
from pyflink.datastream.functions import KeyedProcessFunction
from pyflink.datastream.state import ValueStateDescriptor, ListStateDescriptor, MapStateDescriptor
from pyflink.datastream.connectors import FlinkKafkaConsumer, KafkaSource
from pyflink.datastream.connectors.kafka import KafkaOffsetsInitializer, KafkaOffsetResetStrategy
from pyflink.common.serialization import SimpleStringSchema
import json
import requests
from redis import StrictRedis
from concurrent.futures import ThreadPoolExecutor
class PowerfulStuff(KeyedProcessFunction):
def __init__(self):
self.speed_state = None
self.include_fields = ['value']
def open(self, runtime_context: RuntimeContext):
# keep the memory of the records here
        descriptor = MapStateDescriptor("speed_state", Types.STRING(), Types.MAP(Types.STRING(), Types.PICKLED_BYTE_ARRAY()))
self.speed_state = runtime_context.get_map_state(descriptor)
def process_element(self, record, ctx: 'KeyedProcessFunction.Context'):
current_key = ctx.get_current_key()
previous_record = self.speed_state.get(current_key)
# put current record into current memory
self.speed_state.put(current_key, record)
with requests.Session() as session, StrictRedis(host='server123', port=1234, decode_responses=True) as redis_client:
if (record.get('value') == 4):
self.clear_redis(record, redis_client)
yield record
elif self.was_record_updated(previous_record, record):
                # Let's do some thread-pool async I/O logic
                with ThreadPoolExecutor(max_workers=2) as executor:
                    future_square = executor.submit(self.the_square, record, session, redis_client)
                    future_is_prime = executor.submit(self.is_prime, record, session, redis_client)
                    record['square'] = future_square.result()
                    record['is_prime'] = future_is_prime.result()
yield record
else:
#implement other logic
yield record
def was_record_updated(self, previous_record, current_record):
# Return true if brand new record
if previous_record is None:
return True
# see if a field we care about has been updated
for key in current_record:
if key in self.include_fields and current_record[key] != previous_record[key]:
return True
return False
def the_square(self, record, session, redis_client):
        # API call to our Flask service that returns the square ('url' is a placeholder for that endpoint)
value_squared = session.post(url, json={'value': record['value']}).text
redis_client.set("the_key", value_squared)
return value_squared
def is_prime(self, record, session, redis_client):
        # API call to our Flask service that checks if the number is prime ('url' is again the endpoint placeholder)
is_prime = session.post(url, json={'value': record['value']}).text
redis_client.set("prime_key", is_prime)
return is_prime
env = StreamExecutionEnvironment.get_execution_environment()
# the Kafka connector JAR needs to be available to the job, e.g. via env.add_jars(...)
properties = {'bootstrap.servers': 'server01:1111', 'group.id': 'test-topic'}
# Create Kafka consumer
source = FlinkKafkaConsumer("flink-major-incs", SimpleStringSchema(), properties).set_start_from_earliest()
topic_stream = env.add_source(source).map(json.loads)
keyed_ds = topic_stream.key_by(lambda i: i["color"])
result_stream = keyed_ds.process(PowerfulStuff())
result_stream.print()
env.execute()
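And since the whole point was pushing results to Redis immediately, you can read them back from any other process. A tiny verification sketch (host, port, and key names are the placeholders used in the job above):
# Read back what the job wrote (placeholder host/port/keys from the job above).
from redis import StrictRedis

r = StrictRedis(host='server123', port=1234, decode_responses=True)
print(r.get('the_key'))    # squared value written by the_square
print(r.get('prime_key'))  # is_prime result written by is_prime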