Implementing multiprocessing.pool.ThreadPool from Python in Rust

In this post, we will implement multiprocessing.pool.ThreadPool from Python in Rust. It represents a thread-oriented version of multiprocessing.Pool, which offers a convenient means of parallelizing the execution of a function across multiple input values by distributing the input data across processes. We will use an existing thread-pool implementation and focus on adjusting its interface to match that of multiprocessing.pool.ThreadPool.

Goal

Our objective is to implement the following Python code in Rust:

# Python
import multiprocessing.pool

with multiprocessing.pool.ThreadPool() as pool:
    for result in pool.imap(lambda n: n * 2, [1, 2, 3, 4, 5]):
        print(result)

We will aim for a similar interface:

// Rust
use multiprocessing::ThreadPool;

let pool = ThreadPool::new();
for result in pool.imap(|n| n * 2, &[1, 2, 3, 4, 5]) {
    println("{}", result);
}

Several notes before we start:

  • Do not bother searching for the documentation of multiprocessing.pool.ThreadPool as this class is not documented (see issue #17140). Instead, use the documentation for multiprocessing.Pool. Just mentally replace all occurrences of process with thread.
  • We will implement a thread pool instead of a process pool due to the lack of a multiprocessing.Process equivalent in Rust. Rust has processes, but you can only spawn them by running external commands, akin to the subprocess module in Python.

Alright. Let’s start with the representation of our thread pool.

Representing a Thread Pool

Since the goal of the present post is to implement an interface similar to the one in Python, we will use an already existing implementation of a thread pool. Otherwise, we would have to implement the complete thread pool ourselves, which would require an entire blog post on its own.

We will use the threadpool crate, which provides a simple thread pool for parallel task execution in Rust. Our representation will wrap the provided thread pool:

extern crate threadpool;

pub struct ThreadPool {
    pool: threadpool::ThreadPool,
}

impl ThreadPool {
    // Here, we will implement methods for our thread pool.
    // ...
}

Creating a New Thread Pool

There will be two ways of creating a new thread pool. The first one will detect the number of worker threads automatically from the number of available CPUs (cores) in the system:

let pool = ThreadPool::new();

The second way will allow the user to specify the number of workers:

let pool = ThreadPool::with_workers(8);

We have to choose different method names due to the lack of default arguments and overloading in Rust. Here, it is not really a big deal as we can choose a more descriptive name for the second version.

To detect the number of CPUs (cores) in the system, we use the num_cpus crate:

extern crate num_cpus;

This allows us to implement both of the methods above as follows:

impl ThreadPool {
    pub fn new() -> Self {
        let worker_count = num_cpus::get();
        ThreadPool::with_workers(worker_count)
    }

    pub fn with_workers(count: usize) -> Self {
        assert!(count > 0, "worker count cannot be {}", count);

        ThreadPool {
            pool: threadpool::ThreadPool::new(count),
        }
    }
}

As you can see, to implement new(), we have utilized with_workers(). Moreover, when the number of workers is zero, we panic (count is a usize, so it cannot even be negative). We could have returned a Result instead, but passing an invalid worker count can be considered a pre-condition violation, so an assert is fine in this case.
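
If recoverable errors were preferred, a fallible constructor could look as follows (a sketch only; try_with_workers() is an illustrative name, not part of our interface):

pub fn try_with_workers(count: usize) -> Result<Self, String> {
    if count == 0 {
        // Report the violation to the caller instead of panicking.
        return Err("worker count cannot be 0".to_string());
    }
    Ok(ThreadPool {
        pool: threadpool::ThreadPool::new(count),
    })
}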

Destroying the Pool

In the Python code, the pool is destroyed as soon as we get out of the with block:

with multiprocessing.pool.ThreadPool() as pool:
    # Here, you can work with the pool.
    # ...

# Here, the pool has been terminated, so we cannot use it anymore.

Note that when exiting the block, ThreadPool.__exit__() calls ThreadPool.terminate() (see pool.py#L622), which stops the worker threads immediately without completing outstanding work. This matches the behavior of multiprocessing.Pool. Such behavior is sometimes not what you would expect (see this or this bug report), but this is how it has been implemented, so we have to live with it. From my point of view, calling close() instead of terminate() would be better as it would cause the pool to wait until all workers have finished their jobs. Anyway, let’s move on.

In Rust, the pool is automatically destroyed (dropped in Rust parlance) once it goes out of scope:

{
    let pool = ThreadPool::new();
    // ...

} // Here, 'pool' goes out of scope, so it is dropped.

One can specify what should happen upon the drop of a struct by implementing the Drop trait for that struct. If you do not implement the trait, Rust will simply drop all fields of the struct. Since our struct utilizes an existing thread-pool implementation that implements Drop, we can omit the implementation of Drop on our thread pool, which is nice.
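
Just for illustration, if we wanted custom cleanup, such as waiting for all queued jobs to finish before dropping the pool, we could implement Drop ourselves (a hypothetical sketch; our implementation does not need it):

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Block until all jobs in the underlying pool have finished.
        self.pool.join();
    }
}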

Now, let’s get to the meat of our implementation.

Implementing imap()

We start by implementing the imap() method. It takes a function and an iterable, schedules an application of the function to every item of the iterable, and returns an iterator yielding the results. Additionally, the order of the returned results is preserved. This means that if you pass it the increment function and [1, 2, 3, 4], you will get 2, 3, 4, 5 and not e.g. 4, 3, 2, 5.

The imap() method will have the following interface (the commented-out parts will be filled later):

pub fn imap<F, I, T, R>(&self, f: F, inputs: I) -> IMapIterator<R>
    where F: Fn(T) -> R /* ... */,
          I: IntoIterator<Item = T>,
          T: /* ... */,
          R: /* ... */,
{

As you can see, it is generic, because the types of the inputs and results may differ between subsequent imap() calls. The first parameter is a unary function from T to R. The second parameter is anything that can be converted into an iterator yielding items of type T. This allows us to e.g. pass vectors directly into the function, without the need to call into_iter() on them first. Our imap() method returns a custom type, called IMapIterator, parametrized by the type of results that f produces.
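
For illustration, once imap() is implemented, the following usage sketch will work:

let pool = ThreadPool::new();

// A vector can be passed directly because Vec<T> implements IntoIterator.
let results: Vec<i32> = pool.imap(|n| n + 1, vec![1, 2, 3, 4]).collect();

// The order of the inputs is preserved in the results.
assert_eq!(results, vec![2, 3, 4, 5]);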

Next, we will implement the body of imap(). Then, we will show how to implement IMapIterator.

Implementing the Body of imap()

Before diving into details, let me give you a gist of how the method will work. We will iterate over inputs and submit f(input) for computation in the underlying thread pool. To transfer results between the threads and the caller of imap(), we will use a channel. The sending end of the channel will be passed to the threads. The receiving end will be stored in the returned IMapIterator so the results can be retrieved when iterating over it.
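
If you have never used channels, here is a minimal, self-contained illustration of the mechanism (unrelated to our pool):

use std::sync::mpsc;
use std::thread;

let (tx, rx) = mpsc::channel();
thread::spawn(move || {
    // The sending end has been moved into the spawned thread.
    tx.send(42).unwrap();
});
// The receiving end stays with the caller; recv() blocks until a value arrives.
assert_eq!(rx.recv().unwrap(), 42);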

Here is the implementation. Both Arc and the mpsc channel used below come from the standard library, so we bring them into scope at the top of the file:

use std::sync::{mpsc, Arc};

First, we need to wrap the function in an Arc (a thread-safe reference-counted pointer) so it can be passed safely to multiple threads:

let f = Arc::new(f);

Then, we need to create the channel to transfer results between the threads and the caller:

let (tx, rx) = mpsc::channel();

Finally, we iterate over the inputs and pass f(input) for computation in the underlying thread pool. Note that we will need to keep track of how many inputs we have submitted to implement IMapIterator later. Here is the loop:

let mut total = 0;
for (i, input) in inputs.into_iter().enumerate() {
    total += 1;
    let f = f.clone();
    let tx = tx.clone();
    self.pool.execute(move || {
        let result = f(input);
        if let Err(_) = tx.send((i, result)) {
            // Sending of the result has failed, which means that the
            // receiving side has hung up. There is nothing to do but
            // to ignore the result.
        }
    });
}

As you can see, we send the “index” of each input alongside the result. This will help us preserve the order when yielding the results to the caller.

Finally, we return an iterator over the results, whose implementation will be given shortly:

IMapIterator::new(rx, total)

The above implementation gives rise to the following complete signature of imap():

pub fn imap<F, I, T, R>(&self, f: F, inputs: I) -> IMapIterator<R>
    where F: Fn(T) -> R + Send + Sync + 'static,
          I: IntoIterator<Item = T>,
          T: Send + 'static,
          R: Send + 'static,
{

As you can see, all of F, T, and R have to be sendable between threads. Additionally, if any of them contains references, those references have to be valid for the entire duration of the program (the 'static lifetime). This stems from the fact that the threads in the pool may outlive the scope in which they were given work. To explain, consider the following scenario. You create the pool inside a function and pass it a value containing a reference to a local variable. When the local scope ends, the referenced variable is dropped, so a thread could end up accessing a dangling reference. Rust protects us from this via the concept of lifetimes.
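
For example, the compiler rejects the following sketch because the references borrowed from the local vector are not valid for 'static:

let pool = ThreadPool::new();
let data = vec![1, 2, 3];
// error: `data` does not live long enough (the items are references into it)
let results: Vec<i32> = pool.imap(|r| *r + 1, data.iter()).collect();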

Lastly, the function that we compute has to be Sync as we are sharing it between threads via Arc.

Implementing IMapIterator

Now, it’s time to implement the iterator returned from imap(). Internally, it will need to contain the following data:

  • The receiving end of the channel, which produces pairs of the form (index of the input, corresponding result).
  • A mapping of indexes into the corresponding results. This is needed to preserve the order when yielding the results to the caller. To explain, consider a situation in which the result for the second input item is computed before the result for the first item. When consuming results from the channel, we need to remember this second result and continue reading results from the channel until we receive a result for the first item. After we yield the first result, we can move forward.
  • The index of the input item for which we should yield its result in the next step.
  • The total number of results to consume from the channel. We have to know this number a priori as reading from a channel blocks the caller until there is something in that channel. If we did not count the number of results we have received, we would get stuck waiting for something that will never arrive.

In Rust, we will use the following representation:

pub struct IMapIterator<T> {
    rx: mpsc::Receiver<(usize, T)>,
    results: BTreeMap<usize, T>,
    next: usize,
    total: usize,
}

Our struct will have the following method, which creates a new iterator from the passed receiving-end of a channel (rx) and expected number of results (total):

impl<T> IMapIterator<T> {
    fn new(rx: mpsc::Receiver<(usize, T)>, total: usize) -> Self {
        IMapIterator {
            rx: rx,
            results: BTreeMap::new(),
            next: 0,
            total: total,
        }
    }
}

To enable iteration over IMapIterator instances, we have to implement the Iterator trait:

impl<T> Iterator for IMapIterator<T> {
    // Type of items that the iterator yields.
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Yield Some(item) or None when there are no more items to yield.
        // ...
    }
}

In next(), we will keep looping until we have received all results, yielding them in the process. When we finish, we start yielding None, which is a way of saying: there are no more items to yield, so stop iterating.

while self.next < self.total {
    // Yield the next item and increment self.next.
    // ...
}
None

To yield the next item, we have to consider two scenarios. The first of them happens when we have already received a result for the next index:

if let Some(result) = self.results.remove(&self.next) {
    self.next += 1;
    return Some(result);
}

Note that when we have the result, we can remove it from the map via BTreeMap::remove() as it is no longer needed. This lowers the memory footprint of the iterator.

The second scenario happens when we have not yet received the result we are waiting for. In such a case, we receive the next result, store it in the map, and repeat the loop:

let (i, result) = match self.rx.recv() {
    Ok((i, result)) => (i, result),
    Err(_) => {
        // Receiving has failed, which means that the sending side
        // has hung up. There will be no more results.
        self.next = self.total;
        break;
    },
};
self.results.insert(i, result);
continue;
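
For reference, the pieces above assemble into the following complete next():

fn next(&mut self) -> Option<Self::Item> {
    while self.next < self.total {
        // We already have the result for the next index, so yield it.
        if let Some(result) = self.results.remove(&self.next) {
            self.next += 1;
            return Some(result);
        }

        // Otherwise, wait for another result to arrive over the channel.
        let (i, result) = match self.rx.recv() {
            Ok((i, result)) => (i, result),
            Err(_) => {
                // The sending side has hung up; no more results will come.
                self.next = self.total;
                break;
            },
        };
        self.results.insert(i, result);
    }
    None
}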

And that’s it. We have successfully finished the implementation of imap().

Implementing Other Methods

Apart from imap(), multiprocessing.Pool also provides other methods. Now that you know how to write imap(), two of the methods can be implemented analogously. Thus, I will only give the gist behind their implementation:

  • map(): In contrast to imap(), map() blocks until all results are computed and returns them all at once in a vector. This method can be implemented internally via imap() as follows:

    pub fn map<F, I, T, R>(&self, f: F, inputs: I) -> Vec<R>
        where /* ... */
    {
        self.imap(f, inputs).collect::<Vec<_>>()
    }

  • imap_unordered(): The name says it all. Instead of preserving the order like we did for imap(), we yield each result as soon as we receive it. This greatly simplifies the implementation as we no longer have to keep track of out-of-order results (the BTreeMap instance in the case of imap()); we can simply keep receiving from the channel and immediately yielding the results without any bookkeeping. This makes imap_unordered() faster and more memory-efficient; see the sketch after this list.
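
A sketch of such an iterator could look as follows (IMapUnorderedIterator is an illustrative name, not part of the source code):

pub struct IMapUnorderedIterator<T> {
    rx: mpsc::Receiver<(usize, T)>,
    remaining: usize,
}

impl<T> Iterator for IMapUnorderedIterator<T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        if self.remaining == 0 {
            return None;
        }
        match self.rx.recv() {
            // Yield each result as soon as it arrives; ignore its index.
            Ok((_, result)) => {
                self.remaining -= 1;
                Some(result)
            },
            Err(_) => {
                // The sending side has hung up; no more results will come.
                self.remaining = 0;
                None
            },
        }
    }
}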

Full disclosure: there are still some other methods and parameters that we have skipped (such as apply() or the chunksize parameter).

Complete Source Code

The complete source code is available on GitHub.

Discussion

Apart from comments below, you can also discuss this post at /r/rust.
