Streaming Kafka Joins in PyFlink

You are an almighty stock trader. Your machine learning model absolutely owns. But here's the thing: if you don't execute fast, you're left behind and everyone's already gone to the moon. So, with PyFlink, I'll show you how to get your data all in one spot, real quick, for real-time MLOps. Say we have two Kafka topics:

stock_prices:
{'symbol': 'AAPL', 'price': 150.0, 'timestamp': 1633072800}
{'symbol': 'AAPL', 'price': 151.0, 'timestamp': 1633072810}
{'symbol': 'AAPL', 'price': 152.0, 'timestamp': 1633072820}
trading_volumes:
{'symbol': 'AAPL', 'volume': 10000, 'timestamp': 1633072800}
{'symbol': 'AAPL', 'volume': 15000, 'timestamp': 1633072810}
{'symbol': 'AAPL', 'volume': 20000, 'timestamp': 1633072820}
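If you want to follow along at home, here is a minimal, hypothetical producer that populates both topics using kafka-python. The broker address and the choice of kafka-python are assumptions for testing, not part of the pipeline itself:

# Hypothetical producer to fill the two topics (assumes a broker on localhost:9092).
import json
import time
from kafka import KafkaProducer

producer = KafkaProducer(
    bootstrap_servers='localhost:9092',
    value_serializer=lambda v: json.dumps(v).encode('utf-8')
)

now = int(time.time())
for i, (price, volume) in enumerate([(150.0, 10000), (151.0, 15000), (152.0, 20000)]):
    ts = now + i * 10
    producer.send('stock_prices', {'symbol': 'AAPL', 'price': price, 'timestamp': ts})
    producer.send('trading_volumes', {'symbol': 'AAPL', 'volume': volume, 'timestamp': ts})

producer.flush()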

Set up the Kafka streams:

from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors import FlinkKafkaConsumer
from pyflink.common.serialization import SimpleStringSchema
from pyflink.common.typeinfo import Types
import json

env = StreamExecutionEnvironment.get_execution_environment()

kafka_props = {
    'bootstrap.servers': 'localhost:9092',
    'group.id': 'flink_consumer'
}

stock_price_consumer = FlinkKafkaConsumer(
    topics='stock_prices',
    deserialization_schema=SimpleStringSchema(),
    properties=kafka_props
)

trading_volume_consumer = FlinkKafkaConsumer(
    topics='trading_volumes',
    deserialization_schema=SimpleStringSchema(),
    properties=kafka_props
)

# The records mix strings, floats, and ints, so a typed MAP won't hold them;
# pickle the whole dict instead.
stock_price_stream = env.add_source(stock_price_consumer).map(
    json.loads, output_type=Types.PICKLED_BYTE_ARRAY())
trading_volume_stream = env.add_source(trading_volume_consumer).map(
    json.loads, output_type=Types.PICKLED_BYTE_ARRAY())
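One gotcha before any of this runs: PyFlink doesn't ship the Kafka connector classes, so the connector JAR has to be registered with the environment (or dropped into Flink's lib directory). The path and filename below are placeholders for whatever version matches your Flink install:

# The JAR path below is a placeholder; point it at your actual connector JAR.
env.add_jars("file:///path/to/flink-sql-connector-kafka.jar")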

Union the Streams: This combines the two streams into a single stream. In Flink, union returns a plain DataStream, so the keying has to happen after it, not before.

unioned_stream = stock_price_stream.union(trading_volume_stream)

Key the Stream: This partitions the unioned stream by the stock symbol, ensuring that all records with the same symbol are processed together. It also gives us the KeyedStream that the windowing step requires.

keyed_stream = unioned_stream.key_by(lambda x: x['symbol'])

Windowing: This creates a tumbling window of 5 seconds. Basically you need domain knowledge of how fast your data is coming in. There are more robust strategies, such as event time with watermarks and allowed lateness, but we are going to keep this relatively simple and just batch together everything that arrives within each 5-second window. Note that we use processing-time windows here: event-time windows only fire once watermarks advance, and we haven't set any up (a sketch of that approach follows the code).

from pyflink.datastream.window import TumblingProcessingTimeWindows
from pyflink.common.time import Time

windowed_stream = keyed_stream.window(TumblingProcessingTimeWindows.of(Time.seconds(5)))
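For reference, here is a sketch of the event-time alternative mentioned above, assuming the 'timestamp' field in our records is epoch seconds and that a couple of seconds of out-of-orderness is acceptable (both assumptions):

# Sketch only: event-time windows driven by each record's own 'timestamp' field.
from pyflink.common import Duration, WatermarkStrategy
from pyflink.common.watermark_strategy import TimestampAssigner
from pyflink.datastream.window import TumblingEventTimeWindows

class RecordTimestampAssigner(TimestampAssigner):
    def extract_timestamp(self, value, record_timestamp):
        return value['timestamp'] * 1000  # Flink expects milliseconds

watermark_strategy = (
    WatermarkStrategy
    .for_bounded_out_of_orderness(Duration.of_seconds(2))  # tolerate 2s of disorder
    .with_timestamp_assigner(RecordTimestampAssigner())
)

event_time_windowed = (
    unioned_stream
    .assign_timestamps_and_watermarks(watermark_strategy)
    .key_by(lambda x: x['symbol'])
    .window(TumblingEventTimeWindows.of(Time.seconds(5)))
)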

Reduce Operation: how do we combine the data inside each 5-second window? Average it! Price records and volume records arrive interleaved, so the reduce below keeps a running average of the prices and a running sum of the volumes, whether the window catches a lot of data or a little.

from pyflink.datastream.functions import ReduceFunction

class AverageReduceFunction(ReduceFunction):
    def reduce(self, value1, value2):
        # 'count' tracks how many *price* records have been folded in so far.
        # Volume-only records get a count of 0 so they don't drag the running
        # price average toward zero.
        count1 = value1.get('count', 1 if 'price' in value1 else 0)
        count2 = value2.get('count', 1 if 'price' in value2 else 0)
        total_count = count1 + count2

        total_price = value1.get('price', 0.0) * count1 + value2.get('price', 0.0) * count2
        average_price = total_price / total_count if total_count else 0.0

        total_volume = value1.get('volume', 0) + value2.get('volume', 0)

        return {
            'symbol': value1['symbol'],
            'price': average_price,
            'volume': total_volume,
            'count': total_count
        }

reduced_stream = windowed_stream.reduce(AverageReduceFunction())
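To sanity-check what the reduce computes, you can fold the sample records by hand with plain Python, no Flink required (this assumes all six sample records land in the same window):

# Folding the six sample records through the same logic, outside Flink.
from functools import reduce

window = [
    {'symbol': 'AAPL', 'price': 150.0, 'timestamp': 1633072800},
    {'symbol': 'AAPL', 'volume': 10000, 'timestamp': 1633072800},
    {'symbol': 'AAPL', 'price': 151.0, 'timestamp': 1633072810},
    {'symbol': 'AAPL', 'volume': 15000, 'timestamp': 1633072810},
    {'symbol': 'AAPL', 'price': 152.0, 'timestamp': 1633072820},
    {'symbol': 'AAPL', 'volume': 20000, 'timestamp': 1633072820},
]

result = reduce(AverageReduceFunction().reduce, window)
print(result)
# {'symbol': 'AAPL', 'price': 151.0, 'volume': 45000, 'count': 3}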

Load the Random Forest Model and Predict: load heavy objects like models inside open(), which Flink calls once per parallel task instance when the job starts. The model is read into memory once and reused for every element, instead of being re-loaded on each record.

import pickle
from pyflink.datastream.functions import ProcessFunction, RuntimeContext

class DiamondHands(ProcessFunction):
    def open(self, runtime_context: RuntimeContext):
        # open() runs once per parallel task instance when the job starts,
        # so the model is loaded into memory a single time.
        with open('random_forest_model.pkl', 'rb') as f:
            self.model = pickle.load(f)

    def process_element(self, value, ctx: ProcessFunction.Context):
        # One aggregated [average_price, total_volume] feature vector per window.
        features = [[value['price'], value['volume']]]
        probability = self.model.predict_proba(features)[0][1]
        decision = 'buy' if probability > 0.5 else 'sell'
        yield {
            'symbol': value['symbol'],
            'decision': decision,
            'probability': probability
        }

decision_stream = reduced_stream.process(DiamondHands())
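One assumption worth making explicit: random_forest_model.pkl has to exist already and expect exactly the two features we feed it. As a purely hypothetical example, it could have been produced by something like this:

# Hypothetical training script for random_forest_model.pkl.
# Feature rows mirror what the pipeline produces: [average_price, total_volume].
import pickle
from sklearn.ensemble import RandomForestClassifier

X = [[150.5, 25000], [151.0, 45000], [149.8, 12000]]  # toy feature rows
y = [1, 1, 0]                                          # 1 = buy, 0 = sell

model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X, y)

with open('random_forest_model.pkl', 'wb') as f:
    pickle.dump(model, f)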

That's it! If your model needs more data you can always pull from a database or read from Redis (there's a sketch of that after the full code), or wire in basically anything else thanks to PyFlink. Full code:

from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors import FlinkKafkaConsumer
from pyflink.common.serialization import SimpleStringSchema
from pyflink.common.typeinfo import Types
from pyflink.datastream.window import TumblingProcessingTimeWindows
from pyflink.common.time import Time
from pyflink.datastream.functions import ProcessFunction, ReduceFunction, RuntimeContext
import json
import pickle

# Step 1: Set up the Kafka Streams
env = StreamExecutionEnvironment.get_execution_environment()
# Reminder: the Kafka connector JAR must be registered
# (see the env.add_jars() note earlier in the post).

kafka_props = {
    'bootstrap.servers': 'localhost:9092',
    'group.id': 'flink_consumer'
}

stock_price_consumer = FlinkKafkaConsumer(
    topics='stock_prices',
    deserialization_schema=SimpleStringSchema(),
    properties=kafka_props
)

trading_volume_consumer = FlinkKafkaConsumer(
    topics='trading_volumes',
    deserialization_schema=SimpleStringSchema(),
    properties=kafka_props
)

# The records mix strings, floats, and ints, so pickle the whole dict.
stock_price_stream = env.add_source(stock_price_consumer).map(
    json.loads, output_type=Types.PICKLED_BYTE_ARRAY())
trading_volume_stream = env.add_source(trading_volume_consumer).map(
    json.loads, output_type=Types.PICKLED_BYTE_ARRAY())

# Step 2: Union the Streams (union returns a plain DataStream)
unioned_stream = stock_price_stream.union(trading_volume_stream)

# Step 3: KeyBy Operation (keying must come after the union)
keyed_stream = unioned_stream.key_by(lambda x: x['symbol'])

# Step 4: Windowing (processing time, so no watermarks are needed)
windowed_stream = keyed_stream.window(TumblingProcessingTimeWindows.of(Time.seconds(5)))

# Step 5: Reduce Operation
class AverageReduceFunction(ReduceFunction):
    def reduce(self, value1, value2):
        # 'count' tracks how many *price* records have been folded in so far.
        # Volume-only records get a count of 0 so they don't drag the running
        # price average toward zero.
        count1 = value1.get('count', 1 if 'price' in value1 else 0)
        count2 = value2.get('count', 1 if 'price' in value2 else 0)
        total_count = count1 + count2

        total_price = value1.get('price', 0.0) * count1 + value2.get('price', 0.0) * count2
        average_price = total_price / total_count if total_count else 0.0

        total_volume = value1.get('volume', 0) + value2.get('volume', 0)

        return {
            'symbol': value1['symbol'],
            'price': average_price,
            'volume': total_volume,
            'count': total_count
        }

reduced_stream = windowed_stream.reduce(AverageReduceFunction())


# Step 6: Load the Random Forest Model and Predict
class DiamondHands(ProcessFunction):
    def open(self, runtime_context: RuntimeContext):
        # open() runs once per parallel task instance: load the model once.
        with open('random_forest_model.pkl', 'rb') as f:
            self.model = pickle.load(f)

    def process_element(self, value, ctx: ProcessFunction.Context):
        # One aggregated [average_price, total_volume] feature vector per window.
        features = [[value['price'], value['volume']]]
        probability = self.model.predict_proba(features)[0][1]
        decision = 'buy' if probability > 0.5 else 'sell'
        yield {
            'symbol': value['symbol'],
            'decision': decision,
            'probability': probability
        }

decision_stream = reduced_stream.process(DiamondHands())

# Step 7: Execute the Environment
decision_stream.print()
env.execute("DiamondHands")