Source code for smpl_parallel.parallel

"""Simplified parallelization."""
import multiprocessing
from multiprocessing import Process, Queue

import numpy as np

from smpl.doc import append_doc


[docs]def queued(q, f, *args, **kwargs): q.put(f(*args, **kwargs))
[docs]def res(a): """ Parallel evaluation of the list generator from :func:`gen`. Return parallel executed values. Examples -------- >>> def twice(x): ... return x+x >>> for r in [calc(twice,i) for i in range(0,5)]: ... print(res(r)) 0 2 4 6 8 >>> res([calc(lambda x : x**3, i) for i in range(0,5)]) [0, 1, 8, 27, 64] """ if isinstance(a, list): return [next(k) for k in a] return next(a)
[docs]def gen(f, *args, **kwargs): """Generates parallel execution list generator.""" q = Queue() p = Process(target=queued, args=(q, f, *args), kwargs=kwargs) yield p.start() yield [q.get(), p.join()][0]
[docs]@append_doc(res) def calc(f, *args, **kwargs): g = gen(f, *args, **kwargs) next(g) return g
[docs]def par(f, *args, **kwargs): """ Parallel execution of f on each element of args and kwargs Examples -------- >>> par(lambda x : x**2, range(0,5)) [0, 1, 4, 9, 16] """ return res( [ calc( f, *[args[k][i] for k in range(len(args))], **{k: v[i] for k, v in kwargs.items()} ) for i in range( len(args[0]) if len(args) > 0 else len(next(iter(kwargs.values()))) ) ] )
# TODO rework to same properties as par
[docs]def partitioned_parallel(f, arr, n_jobs=None): """ Parallel execution of f on each element of args Examples -------- >>> partitioned_parallel(lambda x : x**2, range(0,5)) [0, 1, 4, 9, 16] """ n_jobs = n_jobs or multiprocessing.cpu_count() sa = np.array_split(np.array(arr), (len(arr) - 1) // n_jobs + 1) res = [] for i, _ in enumerate(sa): res += par(f, sa[i]) return res
parallel = partitioned_parallel if __name__ == "__main__": import doctest doctest.testmod()