Example 1
    def InitTestNets(self):
        """Initialize the test nets.

        Returns
        -------
        None

        References
        ----------
        The implementation of `InitTestNets(solver.cpp, L104)`_.

        """
        if mpi.Is_Init():
            idx, group = mpi.AllowParallel()
            # only the root in a parallel group can test
            if idx != -1 and mpi.Rank() != group[0]: return

        num_test_net = len(self._param.test_iter)
        if num_test_net > 0:
            if self._param.test_interval <= 0:
                raise RuntimeError(
                    'The value of test_interval: {} is invalid.'
                    .format(self._param.test_interval))

        if len(self._param.test_net) > 0:
            for test_net in self._param.test_net:
                self._test_nets.append(Net(test_net, "TEST"))
            num_test_net -= len(self._param.test_net)

        # consider generic_net
        if num_test_net > 0:
            self._test_nets.append(Net(self._param.net, "TEST"))

        # share with training net
        for test_net in self._test_nets: test_net.share_with(self._net)
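
Note: the `mpi.Is_Init()` / `mpi.AllowParallel()` guard above recurs throughout the examples below. `AllowParallel()` returns the index of the local parallel group and the list of member ranks, and only the group root (the first rank) runs the tests. A minimal sketch of that check, factored into a helper; the name `_is_parallel_root` is ours and not part of the library:

def _is_parallel_root(mpi):
    # Hypothetical helper summarizing the guard used in InitTestNets above.
    # Assumes `mpi` exposes Is_Init(), AllowParallel() and Rank() as shown.
    if not mpi.Is_Init():
        return True   # single-process run: always proceed
    idx, group = mpi.AllowParallel()
    if idx == -1:
        return True   # this rank is not inside a parallel group
    return mpi.Rank() == group[0]   # only the group root proceeds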
Example 2
def _inject_update_ops(graph_def, updater):
    """Inject the update ops GraphDef.

    The ``updater`` should generate update targets before.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return
    updater.register_in_workspace()

    grads, update_ops = [], []
    extra_arguments = updater._extra_kwargs
    extra_arguments['slot'] = updater._slot

    # Build update ops according to the updater
    for e in updater._param_group:
        (param, grad), arguments = e
        if _workspace.HasTensor(grad):
            grads.append(grad)
            arguments = dict(arguments, **extra_arguments)
            update_ops.append(
                _proto_utils.MakeOperatorDef(
                    op_type=updater.type(),
                    inputs=[grad],
                    outputs=[param],
                    name=_helper.OperatorHelper.get_name(),
                    **arguments))
        else:
            _logging.info('Skip to update Tensor({}).'.format(param))

    # Check data parallel if necessary
    if _mpi.Is_Init():
        (rank, group), arguments = _mpi.AllowParallel(), {}
        if rank != -1:
            arguments['mode'] = '%s_ALLREDUCE' % _mpi.GetParallelMode()
            arguments['root'], (arguments['comm'], arguments['group']) \
                = group[0], _mpi.CreateGroup(root=group[0], incl=group)
            update_ops.insert(
                0,
                _proto_utils.MakeOperatorDef(
                    op_type='CollectiveUpdate',
                    inputs=grads,
                    outputs=grads,
                    name=_helper.OperatorHelper.get_name(),
                    **arguments))

    graph_def.op.extend(update_ops)
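
For reference, the loop above expects each entry of `updater._param_group` to be a pair of `((param, grad), arguments)`; a hedged illustration of that layout, with placeholder tensor names and argument keys that are not taken from the library:

# Illustrative layout of one `_param_group` entry consumed by the loop above.
entry = (
    ('conv1/weight', 'conv1/weight_grad'),   # (param, grad)
    {'lr_mult': 1.0, 'decay_mult': 1.0},     # per-parameter arguments (placeholders)
)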
Example 3
def _allreduce(grads):
    if not mpi.Is_Init(): return
    if not isinstance(grads, (list, tuple)): grads = [grads]
    ctx = MakeContext(inputs=grads)
    mode = mpi.GetParallelMode() + '_ALLREDUCE'
    key = 'torch/ops/collective/{}:{}/{}'.format(
        ctx[0].lower(), ctx[1], mode.lower())
    module = get_module(Collective, key, ctx, mode=mode)
    return module.forward(grads)
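
A minimal usage sketch for `_allreduce`; `w_grad` and `b_grad` stand in for gradient tensors produced by a backward pass:

# No-op when MPI is not initialized (see the early return above); otherwise the
# gradients are reduced across the parallel group via the Collective module.
grads = [w_grad, b_grad]   # placeholder gradient tensors
_allreduce(grads)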
Example 4
def Snapshot(
    tensors,
    filename,
    prefix='',
    suffix='.bin',
    format='default',
):
    """Snapshot tensors into a binary file.

    Parameters
    ----------
    tensors : list of Tensor or Tensor
        The tensors to be written.
    filename : str
        The name of this binary file.
    prefix : str
        The prefix of this binary file.
    suffix : str
        The suffix of this binary file.
    format : str
        The format of this binary file.

    Returns
    -------
    None

    Notes
    -----
    The full file path will be: ``prefix`` + ``filename`` + ``suffix``.

    Available formats: ['default', 'caffe'].

    """
    file_path = prefix + filename + suffix
    if mpi.Is_Init():
        if not mpi.AllowSnapshot(): return
        file_path = file_path + '.rank.{}'.format(mpi.Rank())

    dir = os.path.split(file_path)[0]
    if len(dir) > 0 and not os.path.exists(dir): os.makedirs(dir)

    if format == 'default':
        state_dict = {}
        for tensor in tensors:
            state_dict[tensor.name] = FetchTensor(tensor)
        with open(file_path, 'wb') as f:
            pickle.dump(state_dict, f, pickle.HIGHEST_PROTOCOL)
        logging.info('Snapshot Model@: ' + file_path)
        logging.info('Model Format: Pickle')
    elif format == 'caffe':
        names = [tensor.name for tensor in tensors]
        _C.Snapshot(file_path, names, 1)
    else:
        raise TypeError('Unknown binary format: {}'.format(format))
Example 5
def Snapshot(
    tensors,
    filename,
    prefix='',
    suffix='.bin',
    format='pickle',
):
    """Serialize tensors into a binary file.

    The filename is formatted as:

        ``prefix`` + ``filename`` + ``suffix``

    Parameters
    ----------
    tensors : list of Tensor or Tensor
        The tensors to be written.
    filename : str
        The name of this binary file.
    prefix : str, optional, default=''
        The prefix of this binary file.
    suffix : str, optional, default='.bin'
        The suffix of this binary file.
    format : {'pickle', 'caffe'}, optional
        The format of this binary file.

    Returns
    -------
    None

    """
    file_path = prefix + filename + suffix

    if _mpi.Is_Init():
        if not _mpi.AllowSnapshot(): return
        file_path = file_path + '.rank.{}'.format(_mpi.Rank())

    dir = os.path.split(file_path)[0]
    if len(dir) > 0 and not os.path.exists(dir): os.makedirs(dir)

    if format == 'pickle':
        state_dict = {}
        for tensor in tensors:
            state_dict[tensor.name] = FetchTensor(tensor)
        with open(file_path, 'wb') as f:
            pickle.dump(state_dict, f, pickle.HIGHEST_PROTOCOL)
        _logging.info('Snapshot Model@: ' + file_path)
        _logging.info('Model Format: Pickle')
    elif format == 'caffe':
        names = [tensor.name for tensor in tensors]
        get_default_workspace().Snapshot(file_path, names, 1)
    else:
        raise TypeError('Unknown binary format: ' + format)
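
A hedged call sketch for the `Snapshot` function above; the tensor list and path pieces are placeholders:

# Writes 'checkpoints/model_10000.bin' (plus a '.rank.N' suffix under MPI),
# serialized with pickle as shown above; `weights` and `bias` are placeholders.
Snapshot(
    [weights, bias],
    'model_10000',
    prefix='checkpoints/',
    suffix='.bin',
    format='pickle',
)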
Example 6
    def snapshot(self):
        if mpi.Is_Init():
            if not mpi.AllowSnapshot(): return
        net = self.solver.net
        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
        filename = (self.solver_param.snapshot_prefix + infix +
                    '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
        filename = os.path.join(self.output_dir, filename)
        net.save(str(filename))
        logger.info('Wrote snapshot to: {:s}'.format(filename))
        return filename
Example 7
def GraphDef_Update(meta_graph, updater):
    """Inject the update targets into GraphDef.

    The ``updater`` should generate update targets before.

    Parameters
    ----------
    meta_graph : dragon_pb2.GraphDef
        The definition of meta graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return

    updater._prefix = meta_graph.name + '_'
    extra_arguments = updater._extra_kwargs
    extra_arguments['domain'] = updater._prefix
    parallel_arguments = {}

    # wrap hyper-parameters as Tensor for CC
    for k, v in updater._hyper_params.items():
        ws.FeedTensor(updater._prefix + k, np.array([v], dtype=np.float32))

    # check data parallel if necessary
    if mpi.Is_Init():
        idx, group = mpi.AllowParallel()
        if idx != -1:
            parallel_arguments['parallel_mode'] = mpi.GetParallelMode()
            parallel_arguments['comm'], parallel_arguments['group'] \
                = mpi.CreateGroup(root=group[0], incl=group)
            parallel_arguments['root'] = group[0]
        for k, v in parallel_arguments.items():
            meta_graph.arg.add().CopyFrom(MakeArgument(k, v))

    for entry in updater._tuples:
        tensors, arguments = entry[0], entry[1]
        kwargs = dict(arguments, **extra_arguments)
        u_target = pb.UpdateTarget()
        u_target.type = updater._type
        _, u_target.name = GetOperatorName()
        for tensor in tensors:
            u_target.tensor.append(tensor)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        meta_graph.u_target.extend([u_target])
Example 8
def GraphDef_Update(meta_graph, updater):
    """Inject the update targets into GraphDef.

    The ``updater`` should generate update targets before.

    Parameters
    ----------
    meta_graph : dragon_pb2.GraphDef
        The definition of meta graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return

    # use graph name if missing slot
    if updater._slot is None:
        updater._slot = meta_graph.name
    extra_arguments = updater._extra_kwargs
    extra_arguments['slot'] = updater._slot
    parallel_arguments = {}

    updater.register_in_workspace()

    # check data parallel if necessary
    if mpi.Is_Init():
        idx, group = mpi.AllowParallel()
        if idx != -1:
            parallel_arguments['parallel_mode'] = mpi.GetParallelMode()
            parallel_arguments['comm'], parallel_arguments['group'] \
                = mpi.CreateGroup(root=group[0], incl=group)
            parallel_arguments['root'] = group[0]
        for k, v in parallel_arguments.items():
            meta_graph.arg.add().CopyFrom(MakeArgument(k, v))

    for e in updater._param_group:
        pair, arguments = e
        kwargs = dict(arguments, **extra_arguments)
        u_target = pb.UpdateTarget()
        u_target.type = updater.type()
        _, u_target.name = GetOperatorName()
        for t in pair:
            u_target.tensor.append(t)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        meta_graph.u_target.extend([u_target])
Example 9
def GraphDef_Update(graph_def, updater):
    """Inject the update targets into GraphDef.

    The ``updater`` should generate update targets before.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return

    extra_arguments = updater._extra_kwargs
    extra_arguments['slot'] = updater._slot
    parallel_arguments = {}

    updater.register_in_workspace()

    # Check data parallel if necessary
    if mpi.Is_Init():
        idx, group = mpi.AllowParallel()
        if idx != -1:
            parallel_arguments['parallel_mode'] = mpi.GetParallelMode()
            parallel_arguments['comm'], parallel_arguments['group'] \
                = mpi.CreateGroup(root=group[0], incl=group)
            parallel_arguments['root'] = group[0]
        for k, v in parallel_arguments.items():
            graph_def.arg.add().CopyFrom(MakeArgument(k, v))

    for e in updater._param_group:
        pair, arguments = e
        kwargs = dict(arguments, **extra_arguments)
        u_target = pb.UpdaterProto()
        u_target.type = updater.type()
        u_target.name = OperatorHelper.get_name()
        u_target.tensor.extend(pair)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        graph_def.updater.extend([u_target])
Example 10
    def __init__(self, params, defaults):
        self.defaults = defaults
        if isinstance(params, _Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Variables or dicts, but got " +
                            str(type(params)))
        self.state = defaultdict(dict)
        self.param_groups = []
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]
        for param_group in param_groups:
            self.add_param_group(param_group)
        self._update_type = None
        self._allow_parallel = False
        if _mpi.Is_Init():
            rank, _ = _mpi.AllowParallel()
            if rank != -1: self._allow_parallel = True
        self._mutable_parameters = {}
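
Given the constructor above, `params` may be a flat iterable of tensors or a list of per-group dicts carrying a `'params'` key. A hedged construction sketch, assuming the class is exposed as `Optimizer` and that `conv_w` and `conv_b` are parameter tensors:

# A flat list is wrapped into a single group, i.e. [{'params': [...]}].
opt = Optimizer([conv_w, conv_b], defaults={'lr': 0.01})

# Per-group dicts allow per-parameter settings (e.g. a different lr per group).
opt = Optimizer(
    [{'params': [conv_w], 'lr': 0.01},
     {'params': [conv_b], 'lr': 0.02}],
    defaults={'lr': 0.01})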
Example 11
    def __init__(self, **kwargs):
        """Construct a ``DataBatch``.

        Parameters
        ----------
        source : str
            The path of database.
        multiple_nodes : boolean
            Whether to split data for multiple parallel nodes. Default is ``False``.
        shuffle : boolean
            Whether to shuffle the data. Default is ``False``.
        num_chunks : int
            The number of chunks to split. Default is ``2048``.
        chunk_size : int
            The size (MB) of each chunk. Default is ``-1`` (refer to ``num_chunks``).
        padding : int
            The zero-padding size. Default is ``0`` (Disabled).
        fill_value : int
            The value to fill when padding is valid. Default is ``127``.
        crop_size : int
            The crop size. Default is ``0`` (Disabled).
        mirror : boolean
            Whether to flip images horizontally. Default is ``False``.
        color_augmentation : boolean
            Whether to distort colors. Default is ``False``.
        min_random_scale : float
            The min scale of the input images. Default is ``1.0``.
        max_random_scale : float
            The max scale of the input images. Default is ``1.0``.
        force_color : boolean
            Set to duplicate channels for gray. Default is ``False``.
        phase : str
            The phase of this operator, ``TRAIN`` or ``TEST``. Default is ``TRAIN``.
        batch_size : int
            The size of a training batch.
        partition : boolean
            Whether to partition batch. Default is ``False``.
        prefetch : int
            The prefetch count. Default is ``5``.

        """
        super(DataBatch, self).__init__()
        # Init mpi
        global_rank = 0
        local_rank = 0
        group_size = 1
        if mpi.Is_Init():
            idx, group = mpi.AllowParallel()
            if idx != -1:  # DataParallel
                global_rank = mpi.Rank()
                group_size = len(group)
                for i, node in enumerate(group):
                    if global_rank == node: local_rank = i
        kwargs['group_size'] = group_size

        # Configuration
        self._prefetch = kwargs.get('prefetch', 5)
        self._num_readers = kwargs.get('num_readers', 1)
        self._num_transformers = kwargs.get('num_transformers', -1)
        self._max_transformers = kwargs.get('max_transformers', 3)
        self._num_fetchers = kwargs.get('num_fetchers', 1)

        # Io-Aware Policy
        if self._num_transformers == -1:
            self._num_transformers = 1
            # Add 1 transformer for color augmentation
            if kwargs.get('color_augmentation', False):
                self._num_transformers += 1
            # Add 1 transformer for random scale
            if kwargs.get('max_random_scale', 1.0) - \
                kwargs.get('min_random_scale', 1.0) != 0:
                self._num_transformers += 1
            # Add 1 transformer for random crop
            if kwargs.get('crop_size', 0) > 0 and \
                kwargs.get('phase', 'TEST') == 'TRAIN':
                self._num_transformers += 1
        self._num_transformers = min(self._num_transformers,
                                     self._max_transformers)

        self._batch_size = kwargs.get('batch_size', 100)
        self._partition = kwargs.get('partition', False)
        if self._partition:
            self._batch_size = int(self._batch_size / kwargs['group_size'])

        # Init queues
        self.Q_level_1 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_2 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_3 = Queue(self._prefetch * self._num_readers)

        # Init readers
        self._readers = []
        for i in range(self._num_readers):
            self._readers.append(DataReader(**kwargs))
            self._readers[-1].Q_out = self.Q_level_1

        for i in range(self._num_readers):
            num_parts = self._num_readers
            part_idx = i

            if self._readers[i]._multiple_nodes or \
                    self._readers[i]._use_shuffle:
                num_parts *= group_size
                part_idx += local_rank * self._num_readers

            self._readers[i]._num_parts = num_parts
            self._readers[i]._part_idx = part_idx
            self._readers[i]._random_seed += part_idx
            self._readers[i].start()
            time.sleep(0.1)

        # Init transformers
        self._transformers = []
        for i in range(self._num_transformers):
            transformer = DataTransformer(**kwargs)
            transformer._random_seed += (i +
                                         local_rank * self._num_transformers)
            transformer.Q_in = self.Q_level_1
            transformer.Q_out = self.Q_level_2
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

        # Init blob fetchers
        self._fetchers = []
        for i in range(self._num_fetchers):
            fetcher = BlobFetcher(**kwargs)
            fetcher.Q_in = self.Q_level_2
            fetcher.Q_out = self.Q_level_3
            fetcher.start()
            self._fetchers.append(fetcher)
            time.sleep(0.1)

        # Prevent echoing from multiple nodes
        if local_rank == 0: self.echo()

        def cleanup():
            def terminate(processes):
                for process in processes:
                    process.terminate()
                    process.join()

            from dragon.config import logger
            terminate(self._fetchers)
            if local_rank == 0: logger.info('Terminating BlobFetcher ......')
            terminate(self._transformers)
            if local_rank == 0:
                logger.info('Terminating DataTransformer ......')
            terminate(self._readers)
            if local_rank == 0: logger.info('Terminating DataReader......')

        import atexit
        atexit.register(cleanup)
Example 12
    def __init__(self, **kwargs):
        """Construct a ``DataBatch``.

        Parameters
        ----------
        source : str
            The path of database.
        multiple_nodes : boolean, optional, default=False
            Whether to split data for multiple parallel nodes.
        shuffle : bool, optional, default=False
            Whether to shuffle the data.
        num_chunks : int, optional, default=2048
            The number of chunks to split.
        padding : int, optional, default=0
            The zero-padding size.
        fill_value : int, optional, default=127
            The value to fill when padding is valid.
        crop_size : int, optional, default=0
            The cropping size.
        cutout_size : int, optional, default=0
            The square size to cutout.
        mirror : bool, optional, default=False
            Whether to mirror (flip horizontally) the images.
        color_augmentation : bool, optional, default=False
            Whether to use color distortion.
        min_random_scale : float, optional, default=1.
            The min scale of the input images.
        max_random_scale : float, optional, default=1.
            The max scale of the input images.
        force_gray : bool, optional, default=False
            Set not to duplicate the channel for gray images.
        phase : {'TRAIN', 'TEST'}, optional
            The optional running phase.
        batch_size : int, optional, default=128
            The size of a mini-batch.
        partition : bool, optional, default=False
            Whether to partition batch for parallelism.
        prefetch : int, optional, default=5
            The prefetch count.

        """
        super(DataBatch, self).__init__()
        # Init mpi
        global_rank, local_rank, group_size = 0, 0, 1
        if _mpi.Is_Init() and kwargs.get(
                'phase', 'TRAIN') == 'TRAIN':
            rank, group = _mpi.AllowParallel()
            if rank != -1: # DataParallel
                global_rank, group_size = _mpi.Rank(), len(group)
                for i, node in enumerate(group):
                    if global_rank == node: local_rank = i
        kwargs['group_size'] = group_size

        # configuration
        self._prefetch = kwargs.get('prefetch', 5)
        self._num_readers = kwargs.get('num_readers', 1)
        self._num_transformers = kwargs.get('num_transformers', -1)
        self._max_transformers = kwargs.get('max_transformers', 3)
        self._num_fetchers = kwargs.get('num_fetchers', 1)

        # io-aware policy
        if self._num_transformers == -1:
            self._num_transformers = 1
            # add 1 transformer for color augmentation
            if kwargs.get('color_augmentation', False):
                self._num_transformers += 1
            # add 1 transformer for random scale
            if kwargs.get('max_random_scale', 1.0) - \
                    kwargs.get('min_random_scale', 1.0) != 0:
                self._num_transformers += 1
        self._num_transformers = min(
            self._num_transformers, self._max_transformers)

        self._batch_size = kwargs.get('batch_size', 128)
        self._partition = kwargs.get('partition', False)
        if self._partition: self._batch_size //= kwargs['group_size']

        # init queues
        self.Q1 = Queue(self._prefetch * self._num_readers * self._batch_size)
        self.Q2 = Queue(self._prefetch * self._num_readers * self._batch_size)
        self.Q3 = Queue(self._prefetch * self._num_readers)

        # init readers
        self._readers = []
        for i in range(self._num_readers):
            self._readers.append(DataReader(**kwargs))
            self._readers[-1].Q_out = self.Q1

        for i in range(self._num_readers):
            part_idx, num_parts = i, self._num_readers
            if self._readers[i]._multi_nodes or \
                    self._readers[i]._use_shuffle:
                num_parts *= group_size
                part_idx += local_rank * self._num_readers
            self._readers[i]._num_parts = num_parts
            self._readers[i]._part_idx = part_idx
            self._readers[i]._rng_seed += part_idx
            self._readers[i].start()
            time.sleep(0.1)

        # Init transformers
        self._transformers = []
        for i in range(self._num_transformers):
            transformer = DataTransformer(**kwargs)
            transformer._rng_seed += (i + local_rank * self._num_transformers)
            transformer.Q_in, transformer.Q_out = self.Q1, self.Q2
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

        # Init blob fetchers
        self._fetchers = []
        for i in range(self._num_fetchers):
            fetcher = BlobFetcher(**kwargs)
            fetcher.Q_in, fetcher.Q_out = self.Q2, self.Q3
            fetcher.start()
            self._fetchers.append(fetcher)
            time.sleep(0.1)

        def cleanup():
            def terminate(processes):
                for process in processes:
                    process.terminate()
                    process.join()
            terminate(self._fetchers)
            if local_rank == 0: _logging.info('Terminate BlobFetcher.')
            terminate(self._transformers)
            if local_rank == 0: _logging.info('Terminate DataTransformer.')
            terminate(self._readers)
            if local_rank == 0: _logging.info('Terminate DataReader.')
        import atexit
        atexit.register(cleanup)
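
A hedged construction sketch for the `DataBatch` above, using only keyword arguments documented in its docstring; the database path is a placeholder:

# Placeholder source path; every keyword below is documented in the docstring.
batch = DataBatch(
    source='/data/train_lmdb',
    shuffle=True,
    batch_size=128,
    crop_size=224,
    mirror=True,
    phase='TRAIN',
    prefetch=5,
)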
Example 13
    def __init__(self, **kwargs):
        """Construct a ``DataBatch``.

        Parameters
        ----------
        source : str
            The path of database.
        shuffle : boolean
            Whether to shuffle the data.
        node_step : boolean
            Whether to split data for multiple parallel nodes.
        num_chunks : int
            The number of chunks to split. Default is ``2048``.
        chunk_size : int
            The size (MB) of each chunk. Default is ``-1`` (refer to ``num_chunks``).
        mean_values : list
            The mean value of each image channel.
        scale : float
            The scale performed after mean subtraction. Default is ``1.0``.
        padding : int
            The zero-padding size. Default is ``0`` (Disabled).
        fill_value : int
            The value to fill when padding is valid. Default is ``127``.
        crop_size : int
            The crop size. Default is ``0`` (Disabled).
        mirror : boolean
            Whether to flip images horizontally. Default is ``False``.
        color_augmentation : boolean
            Whether to distort colors. Default is ``False``.
        min_random_scale : float
            The min scale of the input images. Default is ``1.0``.
        max_random_scale : float
            The max scale of the input images. Default is ``1.0``.
        force_color : boolean
            Set to duplicate channels for gray. Default is ``False``.
        phase : str
            The phase of this operator, ``TRAIN`` or ``TEST``. Default is ``TRAIN``.
        batch_size : int
            The size of a training batch.
        partition : boolean
            Whether to partition batch. Default is ``False``.
        prefetch : int
            The prefetch count. Default is ``5``.

        """
        super(DataBatch, self).__init__()
        # init mpi
        global_rank = 0
        local_rank = 0
        group_size = 1
        if mpi.Is_Init():
            idx, group = mpi.AllowParallel()
            if idx != -1:  # data parallel
                global_rank = mpi.Rank()
                group_size = len(group)
                for i, node in enumerate(group):
                    if global_rank == node: local_rank = i
        kwargs['group_size'] = group_size

        # configuration
        self._prefetch = GetProperty(kwargs, 'prefetch', 5)
        self._num_readers = GetProperty(kwargs, 'num_readers', 1)
        self._num_transformers = GetProperty(kwargs, 'num_transformers', -1)
        self._max_transformers = GetProperty(kwargs, 'max_transformers', 3)
        self._num_fetchers = GetProperty(kwargs, 'num_fetchers', 1)

        # io-aware policy
        if self._num_transformers == -1:
            self._num_transformers = 2
            # add 1 transformer for color augmentation
            if GetProperty(kwargs, 'color_augmentation', False):
                self._num_transformers += 1
            # add 1 transformer for random scale
            if GetProperty(kwargs, 'max_random_scale', 1.0) - \
                    GetProperty(kwargs, 'min_random_scale', 1.0) != 0:
                self._num_transformers += 1
        self._num_transformers = min(self._num_transformers,
                                     self._max_transformers)

        self._batch_size = GetProperty(kwargs, 'batch_size', 100)
        self._partition = GetProperty(kwargs, 'partition', False)
        if self._partition:
            self._batch_size = int(self._batch_size / kwargs['group_size'])

        # init queues
        self.Q_level_1 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_2 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_3 = Queue(self._prefetch * self._num_readers)

        # init readers
        self._readers = []
        for i in xrange(self._num_readers):
            self._readers.append(DataReader(**kwargs))
            self._readers[-1].Q_out = self.Q_level_1

        for i in xrange(self._num_readers):
            num_parts = self._num_readers
            part_idx = i

            if self._readers[i]._use_shuffle \
                    or self._readers[i]._use_step:
                num_parts *= group_size
                part_idx += local_rank * self._num_readers

            self._readers[i]._num_parts = num_parts
            self._readers[i]._part_idx = part_idx
            self._readers[i]._random_seed += part_idx
            self._readers[i].start()
            time.sleep(0.1)

        # init transformers
        self._transformers = []
        for i in xrange(self._num_transformers):
            transformer = DataTransformer(**kwargs)
            transformer._random_seed += (i +
                                         local_rank * self._num_transformers)
            transformer.Q_in = self.Q_level_1
            transformer.Q_out = self.Q_level_2
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

        # init blob fetchers
        self._fetchers = []
        for i in xrange(self._num_fetchers):
            fetcher = BlobFetcher(**kwargs)
            fetcher.Q_in = self.Q_level_2
            fetcher.Q_out = self.Q_level_3
            fetcher.start()
            self._fetchers.append(fetcher)
            time.sleep(0.1)

        self.echo()