Edition 16: A PySpark Idiom for Efficient Model Inference
Parallel text encoding using Apache Spark.
This post was written by one of my great colleagues, AI expert Sujit Pal. He has a tremendous collection of AI deep-dive posts in his libraries; please do visit.
I recently needed to build an Apache Spark (PySpark) job whose tasks included using a Language Model (LM) to encode text into vectors. This is an embarrassingly parallel job, because each text maps to exactly one encoding, so something like Spark works very well here. In theory at least, we could achieve an N-fold performance improvement by horizontally partitioning the data into N splits and encoding them with N parallel workers.
However, LMs (and Machine Learning (ML) models in general) usually take some time to initialize before they are ready for use. This initialization step loads the model's parameters (multi-dimensional tensors of weights learned during training) into memory. So it is not really feasible to initialize the model inside a plain map call that processes one row at a time.
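A minimal sketch of that anti-pattern follows. `ToyEncoder` and `initialize_model` are hypothetical stand-ins for a real LM (the "expensive" initialization is just a counter), and the rows are a plain Python list so the sketch runs without a cluster; on Spark the same function would be passed to `rdd.map`.

```python
class ToyEncoder:
    """Hypothetical stand-in for a language model; not a real library class."""
    instances = 0     # how many times the (simulated) expensive initialization ran
    batch_sizes = []  # size of every encode() call, for illustration

    def __init__(self):
        ToyEncoder.instances += 1  # a real model would load its weight tensors here

    def encode(self, texts):
        ToyEncoder.batch_sizes.append(len(texts))
        # Toy "embedding": character count and word count of each text.
        return [[float(len(t)), float(len(t.split()))] for t in texts]


def initialize_model():
    """Hypothetical factory standing in for the post's model.initialize_model."""
    return ToyEncoder()


def encode_row_naive(text):
    # Anti-pattern: a brand-new model is built for every single row.
    model = initialize_model()
    return model.encode([text])[0]


# On Spark this would be: vectors = texts_rdd.map(encode_row_naive)
texts = ["the quick brown fox", "jumps over", "the lazy dog"]
vectors = list(map(encode_row_naive, texts))
print(ToyEncoder.instances)  # 3, one initialization per row
```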
This approach would require the model to be initialized for each row in our RDD, which can be very time-consuming. One way to address this is to initialize the model once on the master (driver) and broadcast it to all the workers, something I have done in the past.
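The broadcast alternative can be sketched as follows, assuming an existing SparkContext `sc` and an RDD of strings `texts_rdd`. `ToyEncoder`, `initialize_model`, and `encode_with_broadcast` are hypothetical names. `sc.broadcast()` and `Broadcast.value` are standard PySpark calls, but broadcasting only works if the model object can be pickled, which is not true of every model.

```python
class ToyEncoder:
    """Hypothetical stand-in for a language model; not a real library class."""
    instances = 0  # how many times the (simulated) expensive initialization ran

    def __init__(self):
        ToyEncoder.instances += 1  # a real model would load its weight tensors here

    def encode(self, texts):
        # Toy "embedding": character count and word count of each text.
        return [[float(len(t)), float(len(t.split()))] for t in texts]


def initialize_model():
    """Hypothetical factory standing in for the post's model.initialize_model."""
    return ToyEncoder()


def encode_with_broadcast(sc, texts_rdd):
    """Build the model once on the driver and ship it to the workers as a broadcast.

    `sc` is assumed to be a SparkContext and `texts_rdd` an RDD of strings.
    """
    model_bc = sc.broadcast(initialize_model())  # the model must be picklable
    # The lambda captures only the broadcast handle; each task reads the
    # worker-side copy through .value instead of initializing a new model.
    return texts_rdd.map(lambda text: model_bc.value.encode([text])[0])
```

On a cluster you would call something like `encode_with_broadcast(spark.sparkContext, texts_rdd).collect()`, where `spark` is an existing SparkSession.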
But Spark provides a higher-order function (HOF) specifically for this use case, called mapPartitions. It lets you specify code that creates some heavyweight object(s) once per partition, and then applies some processing (using these heavyweight objects) to all rows in that partition. With this idiom, the model is initialized once per partition instead of once per row. You could also broadcast the model from the master instead of initializing it in the workers, which saves you the initialization time on each worker. Regardless, you can think of model.initialize_model as a wrapper around either approach.
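A minimal sketch of the mapPartitions idiom, still encoding one row at a time: the function receives an iterator over one partition's rows and yields one vector per row. `ToyEncoder`, `initialize_model`, and `encode_partition` are hypothetical; the two Python lists simulate two Spark partitions.

```python
class ToyEncoder:
    """Hypothetical stand-in for a language model; not a real library class."""
    instances = 0     # how many times the (simulated) expensive initialization ran
    batch_sizes = []  # size of every encode() call, for illustration

    def __init__(self):
        ToyEncoder.instances += 1  # a real model would load its weight tensors here

    def encode(self, texts):
        ToyEncoder.batch_sizes.append(len(texts))
        # Toy "embedding": character count and word count of each text.
        return [[float(len(t)), float(len(t.split()))] for t in texts]


def initialize_model():
    """Hypothetical factory standing in for the post's model.initialize_model."""
    return ToyEncoder()


def encode_partition(rows):
    # Called once per partition; `rows` is an iterator over that partition's texts.
    model = initialize_model()  # one initialization per partition, not per row
    for text in rows:
        yield model.encode([text])[0]


# On Spark this would be: vectors = texts_rdd.mapPartitions(encode_partition)
partitions = [["a b", "cde"], ["fgh ij"]]
vectors = [v for part in partitions for v in encode_partition(iter(part))]
print(ToyEncoder.instances)  # 2, one initialization per partition
```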
However, LMs (and ML models in general) are designed to process input in batches. Inference, at least for neural models, involves a lot of matrix multiplications, which the underlying tensor library performs in parallel when you feed the model batches (or larger sets) of records rather than one record at a time. Assuming the model's encode method (or equivalent) works on batches of size B, usually indicated by the default value of its batch_size parameter, feeding it inputs of size >= B could translate into roughly a B-fold performance improvement. The model internally splits the input into batches of B records each, then processes the batches sequentially and the records within each batch in parallel.
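To make the batching behavior concrete, here is a toy sketch of an `encode` method that splits its input into batches of size B and encodes each batch with a single matrix multiplication. `ToyBatchEncoder` is hypothetical: its random weights and two hand-made features stand in for a real network, and this is not how any particular library implements `encode`.

```python
import numpy as np


class ToyBatchEncoder:
    """Hypothetical model: random weights stand in for learned parameters."""

    def __init__(self, out_dim=3, batch_size=4, seed=0):
        rng = np.random.default_rng(seed)
        self.weights = rng.standard_normal((2, out_dim))  # (n_features, out_dim)
        self.batch_size = batch_size  # B, the default batch size
        self.batch_sizes_seen = []

    def _featurize(self, texts):
        # Toy input features per text: character count and word count.
        return np.array([[len(t), len(t.split())] for t in texts], dtype=float)

    def encode(self, texts, batch_size=None):
        batch_size = batch_size or self.batch_size
        outputs = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]  # batches run sequentially
            self.batch_sizes_seen.append(len(batch))
            # One (len(batch) x 2) @ (2 x out_dim) product encodes the whole batch.
            outputs.append(self._featurize(batch) @ self.weights)
        if not outputs:
            return np.empty((0, self.weights.shape[1]))
        return np.vstack(outputs)


model = ToyBatchEncoder()
texts = [f"document number {i}" for i in range(10)]
vectors = model.encode(texts)
print(vectors.shape)           # (10, 3)
print(model.batch_sizes_seen)  # [4, 4, 2]
```

Each row's vector is the same whether it is encoded alone or inside a batch; batching only changes how many rows share one multiplication.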
So, to let the model consume rows in batches, we could collect all the rows of a partition into a list and pass that list to the model's encode method in a single call.
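A minimal sketch of that change, again with hypothetical `ToyEncoder`, `initialize_model`, and `encode_partition_batched` names and lists simulating partitions. PySpark's mapPartitions accepts any function that returns an iterable, so returning the list from `encode` is enough.

```python
class ToyEncoder:
    """Hypothetical stand-in for a language model; not a real library class."""
    instances = 0     # how many times the (simulated) expensive initialization ran
    batch_sizes = []  # size of every encode() call, for illustration

    def __init__(self):
        ToyEncoder.instances += 1  # a real model would load its weight tensors here

    def encode(self, texts):
        ToyEncoder.batch_sizes.append(len(texts))
        # Toy "embedding": character count and word count of each text.
        return [[float(len(t)), float(len(t.split()))] for t in texts]


def initialize_model():
    """Hypothetical factory standing in for the post's model.initialize_model."""
    return ToyEncoder()


def encode_partition_batched(rows):
    model = initialize_model()  # once per partition
    # Materializes every text in the partition: this is where OOM can strike.
    texts = list(rows)
    return model.encode(texts)  # a single encode() call for the whole partition


# On Spark this would be: vectors = texts_rdd.mapPartitions(encode_partition_batched)
partitions = [["a b", "cde", "fgh ij"], ["klm"]]
vectors = [v for part in partitions for v in encode_partition_batched(iter(part))]
print(ToyEncoder.batch_sizes)  # [3, 1], one encode() call per partition
```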
Obviously, the approach above assumes that you have enough memory per partition to hold the text of all the documents in the partition. If the texts in your partition are too large, you will get an Out of Memory (OOM) error and the job will abort. So, depending on your data and your architecture, the simplest (and probably slightly brute-force) approach is to repartition your RDD into a larger number of smaller partitions whose texts fit in memory, for example by calling repartition with a higher partition count before the mapPartitions call.
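One way to sketch the repartition step, assuming you can express your memory budget as a maximum number of rows per partition. `max_rows_per_partition`, `choose_num_partitions`, and `encode_with_repartition` are hypothetical; `count`, `repartition`, and `mapPartitions` are standard RDD methods. Note that `count()` runs an extra pass over the data, and `repartition()` shuffles rows only roughly evenly, so leave some headroom in the budget.

```python
import math


def choose_num_partitions(total_rows, max_rows_per_partition):
    """Smallest partition count that keeps each partition near the row budget,
    assuming rows end up roughly evenly spread across partitions."""
    return max(1, math.ceil(total_rows / max_rows_per_partition))


def encode_with_repartition(texts_rdd, max_rows_per_partition, encode_partition_fn):
    # texts_rdd is assumed to be an RDD of strings; encode_partition_fn is a
    # mapPartitions function such as the batched one sketched earlier.
    total_rows = texts_rdd.count()  # triggers a Spark job over the data
    num_partitions = choose_num_partitions(total_rows, max_rows_per_partition)
    return texts_rdd.repartition(num_partitions).mapPartitions(encode_partition_fn)


print(choose_num_partitions(1000, 300))  # 4
```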
However, this can lead to many small partitions, which adds overhead for Spark, since it now has to schedule and coordinate more tasks. Also, if you initialize the model inside the mapPartitions call, more partitions means more model initializations, so the job would spend more time on that as well. Another way (and basically the idiom I am trying to build up to in this post) is to leave the partitions intact and use itertools.islice to batch up rows within each partition in code, instead of relying on the side effect of partition size.
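A minimal sketch of the islice idiom: the partition iterator is consumed lazily in chunks of `batch_size`, so at most that many input texts are materialized at once, while the model is still initialized only once per partition. `ToyEncoder`, `initialize_model`, `iter_batches`, and `encode_partition_in_batches` are hypothetical names. On Python 3.12+, `itertools.batched` offers similar chunking (it yields tuples rather than lists).

```python
from itertools import islice


class ToyEncoder:
    """Hypothetical stand-in for a language model; not a real library class."""
    instances = 0     # how many times the (simulated) expensive initialization ran
    batch_sizes = []  # size of every encode() call, for illustration

    def __init__(self):
        ToyEncoder.instances += 1  # a real model would load its weight tensors here

    def encode(self, texts):
        ToyEncoder.batch_sizes.append(len(texts))
        # Toy "embedding": character count and word count of each text.
        return [[float(len(t)), float(len(t.split()))] for t in texts]


def initialize_model():
    """Hypothetical factory standing in for the post's model.initialize_model."""
    return ToyEncoder()


def iter_batches(rows, batch_size):
    """Yield successive lists of up to batch_size items pulled lazily from rows."""
    it = iter(rows)
    while batch := list(islice(it, batch_size)):
        yield batch


def encode_partition_in_batches(rows, batch_size=256):
    model = initialize_model()  # once per partition
    for batch in iter_batches(rows, batch_size):
        # Only this batch of input texts is held in memory at a time.
        yield from model.encode(batch)


# On Spark (batch size is a value to tune, assumed here):
#   texts_rdd.mapPartitions(lambda rows: encode_partition_in_batches(rows, 256))
texts = ["a", "b c", "d", "e f g", "h"]
vectors = list(encode_partition_in_batches(iter(texts), batch_size=2))
print(ToyEncoder.batch_sizes)  # [2, 2, 1]
```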
I am curious what people think of this approach. Using Spark to run ML inference at scale cannot be a new problem, but I couldn't find any information or best practices about it on the Internet. I did consider that my Google-fu may not be as strong as I think, so I also tried Bard, but it didn't give me much to go on either. I am sure many Data Engineers before me have looked at this problem and have their own favorite solutions. Please share them in the comments if you can!
**
I will publish the next Edition on Sunday.
If you have any feedback, please don’t hesitate to share it with me. And if you love my work, do share it with your colleagues.
It takes time to research and document this, so please consider becoming a paid subscriber to support my work.
Cheers!!
Raahul
**