Source code for pyblaze.multiprocessing.vectorize

import os
import functools
import numpy as np
import torch.multiprocessing as mp
from pyblaze.utils.stdmp import terminate

class Vectorizer:
    """
    The Vectorizer class ought to be used in cases where a result tensor of size N is filled
    with values computed in some complex way. These N computations can then be parallelized
    over multiple processes. An illustrative usage sketch can be found at the end of this
    module.
    """
    def __init__(self, worker_func, worker_init=None, callback_func=None, num_workers=-1,
                 **kwargs):
        """
        Initializes a new vectorizer.

        Parameters
        ----------
        worker_func: callable
            The function which receives an item of the input to process and outputs the value
            which ought to be returned for that item.
        worker_init: callable, default: None
            A function which receives the rank of the worker as input (i.e. every time this
            function is called, it is called with a different integer as parameter). Its return
            values are passed as *last* parameters to `worker_func` upon every invocation.
        callback_func: callable, default: None
            A function to call after every item that has been processed. It does not need to
            be a free function as it is called on the main thread.
        num_workers: int, default: -1
            The number of processes to use. If set to -1, it defaults to the number of
            available cores. If set to 0, everything is executed on the main thread.
        kwargs: keyword arguments
            Additional arguments passed to the worker initialization function.
        """
        self.num_workers = os.cpu_count() if num_workers == -1 else num_workers
        self.worker_func = worker_func
        self.worker_init = worker_init
        self.callback_func = callback_func
        self.init_kwargs = kwargs
        self._shutdown_fn = None
    def process(self, items, *args):
        """
        Uses the vectorizer's worker function to process all items in parallel.

        Parameters
        ----------
        items: list of object or iterable of object
            The items to be processed by the workers. If given only as an iterable (i.e. it
            does not support index access), performance might suffer slightly due to an
            increased number of synchronizations.
        args: variadic arguments
            Additional arguments passed directly to the worker function.

        Returns
        -------
        list of object
            The outputs generated by the worker function for each of the input items.
        """
        if self.num_workers == 0:
            # Execute sequentially on the main thread.
            result = []
            init = _init_if_needed(self.worker_init, 0, **self.init_kwargs)
            all_args = _combine_args(args, init)
            for item in items:
                result.append(self.worker_func(item, *all_args))
            return result

        # Items supporting index access are split into batches up front; plain iterables
        # are streamed to consumer processes instead.
        process_batches = hasattr(items, '__getitem__')
        if process_batches:
            result = self._process_batches(items, *args)
        else:
            result = self._process_consumers(items, *args)

        self._shutdown_fn()
        self._shutdown_fn = None

        return result
    def _process_batches(self, items, *args):
        num_items = len(items)
        # Partition the item indices into (roughly) equal contiguous batches, one per
        # worker. Cumulative split lengths are used so that empty batches (fewer items
        # than workers) are handled correctly.
        splits = np.array_split(np.arange(num_items), self.num_workers)
        splits = [0] + [int(n) for n in np.cumsum([len(a) for a in splits])]

        result = []
        processes = []
        queues = []
        done = mp.Event()

        if self.callback_func is not None:
            tick_queue = mp.Queue()
        else:
            tick_queue = None

        for i in range(self.num_workers):
            queue = mp.Queue()
            process = mp.Process(
                target=_batch_worker,
                args=(
                    queue, done, tick_queue, i, self.worker_func, self.worker_init,
                    self.init_kwargs, items[splits[i]:splits[i+1]], *args
                )
            )
            process.daemon = True
            process.start()
            processes.append(process)
            queues.append(queue)

        self._shutdown_fn = functools.partial(
            self._shutdown_batches, processes, done
        )

        if self.callback_func is not None:
            # Each worker sends one tick per processed item.
            for _ in range(num_items):
                tick_queue.get()
                self.callback_func()

        for queue in queues:
            result.extend(queue.get())
            queue.close()

        return result

    def _shutdown_batches(self, processes, done):
        done.set()
        terminate(*processes)

    def _process_consumers(self, items, *args):
        result = []
        processes = []
        # Queue names are from the main process's perspective: items are pushed onto
        # `push_queue` and (index, result) pairs are pulled from `pull_queue`. The workers
        # receive the same queues under their own perspective's names.
        push_queue = mp.Queue()
        pull_queue = mp.Queue()

        for i in range(self.num_workers):
            process = mp.Process(
                target=_consumer_worker,
                args=(push_queue, pull_queue, i, self.worker_func, self.worker_init,
                      self.init_kwargs, *args)
            )
            process.daemon = True
            process.start()
            processes.append(process)

        self._shutdown_fn = functools.partial(
            self._shutdown_consumers, processes, pull_queue, push_queue
        )

        iterator = iter(items)
        index = 0
        expect = 0

        try:
            # Prime every worker with one item, then keep one new item in flight for
            # every result received until the iterator is exhausted.
            for _ in range(self.num_workers):
                item = next(iterator)
                expect += 1
                push_queue.cancel_join_thread()
                push_queue.put((index, item))
                index += 1

            while True:
                result.append(pull_queue.get())
                if self.callback_func is not None:
                    self.callback_func()
                expect -= 1
                item = next(iterator)
                expect += 1
                push_queue.cancel_join_thread()
                push_queue.put((index, item))
                index += 1
        except StopIteration:
            # Collect the results that are still outstanding.
            for _ in range(expect):
                result.append(pull_queue.get())
                if self.callback_func is not None:
                    self.callback_func()

        # Restore the original order of the items.
        return [r[1] for r in sorted(result, key=lambda r: r[0])]

    def _shutdown_consumers(self, processes, pull_queue, push_queue):
        pull_queue.close()
        # One sentinel per worker tells the consumers to stop.
        for _ in range(len(processes)):
            push_queue.cancel_join_thread()
            push_queue.put(None)
        push_queue.close()
        terminate(*processes)

    def __del__(self):
        if self._shutdown_fn is not None:
            self._shutdown_fn()
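
# Illustrative sketch, not part of the original module: demonstrates the boundary
# computation used by `_process_batches` above (the helper name is hypothetical). For
# 10 items split across 3 workers, `np.array_split` yields index groups of sizes 4, 3
# and 3, so the cumulative boundaries become [0, 4, 7, 10] and worker i handles
# items[splits[i]:splits[i + 1]].
def _example_split_boundaries(num_items=10, num_workers=3):
    parts = np.array_split(np.arange(num_items), num_workers)
    # Cumulative lengths give the exclusive upper bound of each worker's slice.
    return [0] + [int(n) for n in np.cumsum([len(p) for p in parts])]
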
def _batch_worker(push_queue, done, tick_queue, rank, worker_func, worker_init, init_kwargs,
                  items, *args):
    init = _init_if_needed(worker_init, rank, **init_kwargs)
    all_args = _combine_args(args, init)

    result = []
    for item in items:
        result.append(worker_func(item, *all_args))
        if tick_queue is not None:
            # Signal the main process that one more item has been processed.
            tick_queue.cancel_join_thread()
            tick_queue.put(None)

    push_queue.cancel_join_thread()
    push_queue.put(result)
    # Stay alive until the main process has collected all results.
    done.wait()


def _consumer_worker(pull_queue, push_queue, rank, worker_func, worker_init, init_kwargs,
                     *args):
    init = _init_if_needed(worker_init, rank, **init_kwargs)
    all_args = _combine_args(args, init)

    while True:
        item = pull_queue.get()
        if item is None:  # sentinel sent by the main process upon shutdown
            break
        idx, item = item
        result = worker_func(item, *all_args)
        push_queue.cancel_join_thread()
        push_queue.put((idx, result))


def _init_if_needed(init, rank, **kwargs):
    if init is None:
        return None
    return init(rank, **kwargs)


def _combine_args(a, b):
    if b is None:
        return a
    if isinstance(b, tuple):
        return a + b
    return a + (b,)
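
# Minimal usage sketch, assuming only the `Vectorizer` API defined in this module; the
# functions `_square` and `_make_offset` below are hypothetical stand-ins for real
# workloads. `worker_init` runs once per worker and its return value is appended to the
# arguments of every `worker_func` call; passing a list (which supports index access)
# selects the batch strategy, while a plain generator would select the consumer strategy.
def _square(item, offset):
    return item * item + offset


def _make_offset(rank, base=0):
    # Called once per worker with its rank; `base` arrives via the Vectorizer's kwargs.
    return base


if __name__ == '__main__':
    vectorizer = Vectorizer(_square, worker_init=_make_offset, num_workers=2, base=1)
    print(vectorizer.process(list(range(8))))  # [1, 2, 5, 10, 17, 26, 37, 50]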