Async IO in PyFlink with Flask and Redis

Buckle in and focus, because this is a heavy example that shows the power of PyFlink. First things first: you may have noticed that Flink's Async I/O operator is not available in PyFlink. Fortunately, I'll be showing you a way around this. In this code we will see:

  1. Using state to determine logic
  2. Async IO using ThreadPoolExecutor
  3. Making requests to Flask
  4. Pushing results immediately to Redis

Here are the sample records we'll be reading from the Kafka topic:

{'color': 'red', 'value': 3}
{'color': 'red', 'value': 7}
{'color': 'red', 'value': 7}
{'color': 'red', 'value': 4}

Okay, this function will get pretty big, so I'll explain it block by block. First, the imports:

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment, RuntimeContext
from pyflink.datastream.functions import KeyedProcessFunction
from pyflink.datastream.state import MapStateDescriptor
from pyflink.datastream.connectors import FlinkKafkaConsumer
from pyflink.common.serialization import SimpleStringSchema

import json
import requests
from redis import StrictRedis
from concurrent.futures import ThreadPoolExecutor

Okay, now that we've imported everything, we need to wire up the job: read from Kafka, key the stream by a specific field, in our case 'color', and hand it to our KeyedProcessFunction (defined below), which keeps memory of each group.

env = StreamExecutionEnvironment.get_execution_environment()

properties = {'bootstrap.servers': 'server01:1111', 'group.id': 'test-topic'}
source = FlinkKafkaConsumer("flink-major-incs", SimpleStringSchema(), properties).set_start_from_earliest()
topic_stream = env.add_source(source).map(json.loads)
keyed_ds = topic_stream.key_by(lambda i: i["color"])
result_stream = keyed_ds.process(PowerfulStuff())
result_stream.print()

Okay, now the fun stuff: the KeyedProcessFunction. The __init__ method is where you can define some variables that do not depend on Flink’s runtime context. The open method, however, is where we need to define variables related to Flink’s runtime context.

We create speed_state with a key type of Types.STRING(); its value is a map (or dictionary) whose keys are strings and whose values are PICKLED_BYTE_ARRAY, a type that can store just about any Python object.

class PowerfulStuff(KeyedProcessFunction):
    def __init__(self):
        self.speed_state = None
        self.include_fields = ['value']

    def open(self, runtime_context: RuntimeContext):
        # keep the memory of the records here
        descriptor = MapStateDescriptor("speed_state", Types.STRING(), Types.MAP(Types.STRING(), Types.PICKLED_BYTE_ARRAY()))
        self.speed_state = runtime_context.get_map_state(descriptor)

    def process_element(self, record, ctx: 'KeyedProcessFunction.Context'):
        current_key = ctx.get_current_key()
        previous_record = self.speed_state.get(current_key)
        # put current record into current memory
        self.speed_state.put(current_key, record)

Since this can be a bit confusing, let’s walk through the data step-by-step as it gets put into memory (or state) from the Kafka topic:

First Record: {'color': 'red', 'value': 3}

  • This first goes into process_element.
  • We get current_key = ctx.get_current_key(), which is 'red'.
  • We then get the previous_record, which is None for now.
  • We take the current record and store it in the state.

Next Record: {'color': 'red', 'value': 7}

  • We get the current_key, which is 'red'.
  • When we check the state, we see the previous record stored there!
  • Now, in our function, we have access to both the previous record and the current record.

This is a simple example of state management, but you can go crazy with it.
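
To make the walkthrough concrete, here is a minimal sketch that simulates the same behavior outside of Flink. The plain dict fake_state is only a stand-in for the MapState, and the record list mirrors our sample data:

fake_state = {}   # stand-in for self.speed_state; NOT real Flink state

records = [
    {'color': 'red', 'value': 3},
    {'color': 'red', 'value': 7},
    {'color': 'red', 'value': 7},
    {'color': 'red', 'value': 4},
]

for record in records:
    current_key = record['color']
    previous_record = fake_state.get(current_key)  # None for the very first 'red' record
    fake_state[current_key] = record               # store the current record for next time
    print(previous_record, '->', record)

# Prints:
# None -> {'color': 'red', 'value': 3}
# {'color': 'red', 'value': 3} -> {'color': 'red', 'value': 7}
# {'color': 'red', 'value': 7} -> {'color': 'red', 'value': 7}
# {'color': 'red', 'value': 7} -> {'color': 'red', 'value': 4}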

Anyway, let's continue with the rest of process_element:

    def process_element(self, record, ctx: 'KeyedProcessFunction.Context'):
        current_key = ctx.get_current_key()
        previous_record = self.speed_state.get(current_key)
        # put current record into current memory
        self.speed_state.put(current_key, record)
        with requests.Session() as session, StrictRedis(host='server123', port=1234, decode_responses=True) as redis_client:
            if record.get('value') == 4:
                self.clear_redis(record, redis_client)
                yield record
            elif self.was_record_updated(previous_record, record):
                # Let's do some thread-pool async IO logic
                with ThreadPoolExecutor(max_workers=2) as executor:
                    future_square = executor.submit(self.the_square, record, session, redis_client)
                    future_is_prime = executor.submit(self.is_prime, record, session, redis_client)
                    # .result() blocks until each call finishes and returns its value
                    record['square'] = future_square.result()
                    record['is_prime'] = future_is_prime.result()
                yield record

    def was_record_updated(self, previous_record, current_record):
        # Return True if this is a brand new record
        if previous_record is None:
            return True
        # see if a field we care about has been updated
        for key in current_record:
            if key in self.include_fields and current_record[key] != previous_record[key]:
                return True
        return False

We create the API session and the Redis client and hit our first if statement. Our current record's value is 3, not 4, so we move to the elif: elif self.was_record_updated(previous_record, record):. This function decides whether we consider a record updated. For this first record, previous_record is None, so it returns True straight away; for later records it returns True only when a field listed in self.include_fields, here 'value', has changed. Either way, we enter our thread-pool async IO.
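
Since was_record_updated is plain Python and never touches Flink state, you can sanity-check it directly. Here's a quick sketch against our sample records (pf is just a throwaway instance for illustration):

pf = PowerfulStuff()  # include_fields = ['value']

pf.was_record_updated(None, {'color': 'red', 'value': 3})                          # True  (brand new record)
pf.was_record_updated({'color': 'red', 'value': 3}, {'color': 'red', 'value': 7})  # True  ('value' changed)
pf.was_record_updated({'color': 'red', 'value': 7}, {'color': 'red', 'value': 7})  # False (nothing we care about changed)

Now for the two helpers that the thread pool actually runs: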

    def the_square(self, record, session, redis_client):
        # API call (to your Flask app; url is defined elsewhere) that returns the square
        value_squared = session.post(url, json={'value': record['value']}).text
        redis_client.set("the_key", value_squared)
        return value_squared

    def is_prime(self, record, session, redis_client):
        # API call (to your Flask app; url is defined elsewhere) that checks if the number is prime
        is_prime = session.post(url, json={'value': record['value']}).text
        redis_client.set("prime_key", is_prime)
        return is_prime
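
For reference, here is what the Flask side could look like. This is purely a sketch: the routes, port, and response format are assumptions rather than part of the original job, and in the job above url would point at one of these endpoints (in practice you'd give the_square and is_prime their own URLs):

from flask import Flask, request

app = Flask(__name__)

# Hypothetical endpoints for illustration only
@app.route('/square', methods=['POST'])
def square():
    value = request.get_json()['value']
    return str(value * value)

@app.route('/is_prime', methods=['POST'])
def is_prime():
    n = request.get_json()['value']
    prime = n > 1 and all(n % d for d in range(2, int(n ** 0.5) + 1))
    return str(prime)

if __name__ == '__main__':
    app.run(port=5000)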

Thanks to the ThreadPoolExecutor, the_square and is_prime run concurrently and push their results to Redis as soon as they finish. Here's the full code. Obviously some code is missing, like self.clear_redis, but you get the idea: implement your own logic and let Flink carry the workload. For production jobs, also look into checkpointing and job restarts, and be sure to properly set up your Kafka consumer and group IDs.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment, RuntimeContext
from pyflink.datastream.functions import KeyedProcessFunction
from pyflink.datastream.state import MapStateDescriptor
from pyflink.datastream.connectors import FlinkKafkaConsumer
from pyflink.common.serialization import SimpleStringSchema

import json
import requests
from redis import StrictRedis
from concurrent.futures import ThreadPoolExecutor

class PowerfulStuff(KeyedProcessFunction):
    def __init__(self):
        self.speed_state = None
        self.include_fields = ['value']

    def open(self, runtime_context: RuntimeContext):
        # keep the memory of the records here
        descriptor = MapStateDescriptor("speed_state", Types.STRING(), Types.MAP(Types.STRING(), Types.PICKLED_BYTE_ARRAY()))
        self.speed_state = runtime_context.get_map_state(descriptor)

    def process_element(self, record, ctx: 'KeyedProcessFunction.Context'):
        current_key = ctx.get_current_key()
        previous_record = self.speed_state.get(current_key)
        # put current record into current memory
        self.speed_state.put(current_key, record)

        with requests.Session() as session, StrictRedis(host='server123', port=1234, decode_responses=True) as redis_client:
            if record.get('value') == 4:
                self.clear_redis(record, redis_client)
                yield record

            elif self.was_record_updated(previous_record, record):
                # Let's do some thread-pool async IO logic
                with ThreadPoolExecutor(max_workers=2) as executor:
                    future_square = executor.submit(self.the_square, record, session, redis_client)
                    future_is_prime = executor.submit(self.is_prime, record, session, redis_client)
                    # .result() blocks until each call finishes and returns its value
                    record['square'] = future_square.result()
                    record['is_prime'] = future_is_prime.result()
                yield record
            else:
                # implement other logic
                yield record

    def was_record_updated(self, previous_record, current_record):
        # Return True if this is a brand new record
        if previous_record is None:
            return True
        # see if a field we care about has been updated
        for key in current_record:
            if key in self.include_fields and current_record[key] != previous_record[key]:
                return True
        return False

    def the_square(self, record, session, redis_client):
        # API call (to your Flask app; url is defined elsewhere) that returns the square
        value_squared = session.post(url, json={'value': record['value']}).text
        redis_client.set("the_key", value_squared)
        return value_squared

    def is_prime(self, record, session, redis_client):
        # API call (to your Flask app; url is defined elsewhere) that checks if the number is prime
        is_prime = session.post(url, json={'value': record['value']}).text
        redis_client.set("prime_key", is_prime)
        return is_prime

env = StreamExecutionEnvironment.get_execution_environment()

properties = {'bootstrap.servers': 'server01:1111', 'group.id': 'test-topic'}

# Create Kafka consumer
source = FlinkKafkaConsumer("flink-major-incs", SimpleStringSchema(), properties).set_start_from_earliest()
topic_stream = env.add_source(source).map(json.loads)
keyed_ds = topic_stream.key_by(lambda i: i["color"])
result_stream = keyed_ds.process(PowerfulStuff())
result_stream.print()

env.execute()
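
Since the wrap-up mentions checkpointing and restarts, here is a minimal sketch of enabling them on the environment above, placed before env.execute(). The interval and restart counts are arbitrary placeholders, not recommendations:

from pyflink.common import RestartStrategies

env.enable_checkpointing(60000)  # checkpoint every 60 seconds (value is in milliseconds)
env.set_restart_strategy(RestartStrategies.fixed_delay_restart(3, 10000))  # 3 attempts, 10s between attempts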