Example #1
    def InitTestNets(self):
        """Initialize the test nets.

        Returns
        -------
        None

        References
        ----------
        The implementation of `InitTestNets(solver.cpp, L104)`_.

        """
        if mpi.Is_Init():
            idx, group = mpi.AllowParallel()
            # only the root in a parallel group can test
            if idx != -1 and mpi.Rank() != group[0]: return

        num_test_net = len(self._param.test_iter)
        if num_test_net > 0:
            if self._param.test_interval <= 0:
                raise RuntimeError('The test interval: {} is invalid.'
                                   .format(self._param.test_interval))

        if len(self._param.test_net) > 0:
            for test_net in self._param.test_net:
                self._test_nets.append(Net(test_net, "TEST"))
            num_test_net -= len(self._param.test_net)

        # consider generic_net
        if num_test_net > 0:
            self._test_nets.append(Net(self._param.net, "TEST"))

        # share with training net
        for test_net in self._test_nets: test_net.share_with(self._net)
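To make the bookkeeping above concrete, here is a minimal, self-contained sketch (the ``test_iter`` and ``test_net`` values are hypothetical): explicit ``test_net`` files are consumed first, and any remaining count falls back to the generic ``net``.

# Hypothetical solver settings, mirroring the counting logic above.
test_iter = [100, 100, 100]                      # three test nets requested
test_net = ['net_a.prototxt', 'net_b.prototxt']

num_test_net = len(test_iter)                    # 3
num_test_net -= len(test_net)                    # 2 covered by explicit files -> 1 left
assert num_test_net == 1                         # the last one is built from the generic net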
Example #2
def Snapshot(
    tensors,
    filename,
    prefix='',
    suffix='.bin',
    format='default',
):
    """Snapshot tensors into a binary file.

    Parameters
    ----------
    tensors : list of Tensor or Tensor
        The tensors to be written.
    filename : str
        The name of this binary file.
    prefix : str
        The prefix of this binary file.
    suffix : str
        The suffix of this binary file.
    format : str
        The format of this binary file.

    Returns
    -------
    None

    Notes
    -----
    The full file path will be: ``prefix`` + ``filename`` + ``suffix``.

    Available formats: ['default', 'caffe'].

    """
    file_path = prefix + filename + suffix
    if mpi.Is_Init():
        if not mpi.AllowSnapshot(): return
        file_path = file_path + '.rank.{}'.format(mpi.Rank())

    save_dir = os.path.split(file_path)[0]
    if len(save_dir) > 0 and not os.path.exists(save_dir): os.makedirs(save_dir)

    if format == 'default':
        state_dict = {}
        for tensor in tensors:
            state_dict[tensor.name] = FetchTensor(tensor)
        with open(file_path, 'wb') as f:
            pickle.dump(state_dict, f, pickle.HIGHEST_PROTOCOL)
        logging.info('Snapshot Model@: ' + file_path)
        logging.info('Model Format: Pickle')
    elif format == 'caffe':
        names = [tensor.name for tensor in tensors]
        _C.Snapshot(file_path, names, 1)
    else:
        raise TypeError('Unknown binary format: {}'.format(format))
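A hedged usage sketch (the tensor objects and path pieces are hypothetical; it assumes tensors exposing a ``name`` attribute, as the code above requires): with these arguments the snapshot lands at ``models/net.bin``, or ``models/net.bin.rank.K`` when running under MPI.

# Hypothetical call; writes 'models/net.bin' in the default pickle format.
Snapshot(
    [weights, bias],        # assumed Tensor objects carrying a `.name`
    'net',
    prefix='models/',
    suffix='.bin',
    format='default',
)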
Example #3
def Snapshot(
    tensors,
    filename,
    prefix='',
    suffix='.bin',
    format='pickle',
):
    """Serialize tensors into a binary file.

    The filename is formatted as:

        ``prefix`` + ``filename`` + ``suffix``

    Parameters
    ----------
    tensors : list of Tensor or Tensor
        The tensors to be written.
    filename : str
        The name of this binary file.
    prefix : str, optional, default=''
        The prefix of this binary file.
    suffix : str, optional, default='.bin'
        The suffix of this binary file.
    format : {'pickle', 'caffe'}, optional
        The format of this binary file.

    Returns
    -------
    None

    """
    file_path = prefix + filename + suffix

    if _mpi.Is_Init():
        if not _mpi.AllowSnapshot(): return
        file_path = file_path + '.rank.{}'.format(_mpi.Rank())

    save_dir = os.path.split(file_path)[0]
    if len(save_dir) > 0 and not os.path.exists(save_dir): os.makedirs(save_dir)

    if format == 'pickle':
        state_dict = {}
        for tensor in tensors:
            state_dict[tensor.name] = FetchTensor(tensor)
        with open(file_path, 'wb') as f:
            pickle.dump(state_dict, f, pickle.HIGHEST_PROTOCOL)
        _logging.info('Snapshot Model@: ' + file_path)
        _logging.info('Model Format: Pickle')
    elif format == 'caffe':
        names = [tensor.name for tensor in tensors]
        get_default_workspace().Snapshot(file_path, names, 1)
    else:
        raise TypeError('Unknown binary format: ' + format)
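Because the 'pickle' format writes a plain ``state_dict`` mapping tensor names to fetched values, reading a snapshot back needs nothing from the framework; a minimal sketch (the path is hypothetical):

import pickle

# Load the name -> value mapping written by Snapshot(..., format='pickle').
with open('models/net.bin', 'rb') as f:
    state_dict = pickle.load(f)
for name, value in state_dict.items():
    print(name, getattr(value, 'shape', None))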
Example #4
    def register_op(self):
        idx, group = mpi.AllowParallel()
        if idx == -1:
            raise RuntimeError(
                'The mpi node({}) is not in a parallel group.\n'
                'Set it using mpi.Parallel([..]).'.format(mpi.Rank()))
        mpi_comm, mpi_group = mpi.CreateGroup(root=group[0], incl=group)
        self.op_meta = {
            'op_type': 'CollectiveUpdate',
            'arguments': {
                'mode': self.mode,
                'comm': mpi_comm,
                'group': mpi_group,
                'root': group[0],  # Use the first node of the group as root
            },
        }
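For illustration, with a hypothetical parallel group ``[0, 1, 2, 3]`` and an assumed ``self.mode`` of ``'MPI'``, the method above would leave ``self.op_meta`` looking like this (``mpi_comm`` and ``mpi_group`` are the opaque handles returned by ``mpi.CreateGroup``):

# Hypothetical result for group = [0, 1, 2, 3].
mpi_comm = mpi_group = None   # stand-ins for the opaque CreateGroup handles
op_meta = {
    'op_type': 'CollectiveUpdate',
    'arguments': {
        'mode': 'MPI',        # assumed value of self.mode
        'comm': mpi_comm,
        'group': mpi_group,
        'root': 0,            # the first node of the group
    },
}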
Example #5
    def __init__(self, **kwargs):
        """Construct a ``DataBatch``.

        Parameters
        ----------
        source : str
            The path of database.
        multiple_nodes : boolean
            Whether to split data for multiple parallel nodes. Default is ``False``.
        shuffle : boolean
            Whether to shuffle the data. Default is ``False``.
        num_chunks : int
            The number of chunks to split. Default is ``2048``.
        chunk_size : int
            The size (MB) of each chunk. Default is ``-1`` (see ``num_chunks``).
        padding : int
            The zero-padding size. Default is ``0`` (Disabled).
        fill_value : int
            The value to fill when padding is valid. Default is ``127``.
        crop_size : int
            The crop size. Default is ``0`` (Disabled).
        mirror : boolean
            Whether to flip images horizontally. Default is ``False``.
        color_augmentation : boolean
            Whether to distort colors. Default is ``False``.
        min_random_scale : float
            The min scale of the input images. Default is ``1.0``.
        max_random_scale : float
            The max scale of the input images. Default is ``1.0``.
        force_color : boolean
            Whether to duplicate channels for grayscale images. Default is ``False``.
        phase : str
            The phase of this operator, ``TRAIN`` or ``TEST``. Default is ``TRAIN``.
        batch_size : int
            The size of a training batch.
        partition : boolean
            Whether to partition batch. Default is ``False``.
        prefetch : int
            The prefetch count. Default is ``5``.

        """
        super(DataBatch, self).__init__()
        # Init mpi
        global_rank = 0
        local_rank = 0
        group_size = 1
        if mpi.Is_Init():
            idx, group = mpi.AllowParallel()
            if idx != -1:  # DataParallel
                global_rank = mpi.Rank()
                group_size = len(group)
                for i, node in enumerate(group):
                    if global_rank == node: local_rank = i
        kwargs['group_size'] = group_size

        # Configuration
        self._prefetch = kwargs.get('prefetch', 5)
        self._num_readers = kwargs.get('num_readers', 1)
        self._num_transformers = kwargs.get('num_transformers', -1)
        self._max_transformers = kwargs.get('max_transformers', 3)
        self._num_fetchers = kwargs.get('num_fetchers', 1)

        # Io-Aware Policy
        if self._num_transformers == -1:
            self._num_transformers = 1
            # Add 1 transformer for color augmentation
            if kwargs.get('color_augmentation', False):
                self._num_transformers += 1
            # Add 1 transformer for random scale
            if kwargs.get('max_random_scale', 1.0) - \
                kwargs.get('min_random_scale', 1.0) != 0:
                self._num_transformers += 1
            # Add 1 transformer for random crop
            if kwargs.get('crop_size', 0) > 0 and \
                kwargs.get('phase', 'TEST') == 'TRAIN':
                self._num_transformers += 1
        self._num_transformers = min(self._num_transformers,
                                     self._max_transformers)

        self._batch_size = kwargs.get('batch_size', 100)
        self._partition = kwargs.get('partition', False)
        if self._partition:
            self._batch_size = int(self._batch_size / kwargs['group_size'])

        # Init queues
        self.Q_level_1 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_2 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_3 = Queue(self._prefetch * self._num_readers)

        # Init readers
        self._readers = []
        for i in range(self._num_readers):
            self._readers.append(DataReader(**kwargs))
            self._readers[-1].Q_out = self.Q_level_1

        for i in range(self._num_readers):
            num_parts = self._num_readers
            part_idx = i

            if self._readers[i]._multiple_nodes or \
                    self._readers[i]._use_shuffle:
                num_parts *= group_size
                part_idx += local_rank * self._num_readers

            self._readers[i]._num_parts = num_parts
            self._readers[i]._part_idx = part_idx
            self._readers[i]._random_seed += part_idx
            self._readers[i].start()
            time.sleep(0.1)

        # Init transformers
        self._transformers = []
        for i in range(self._num_transformers):
            transformer = DataTransformer(**kwargs)
            transformer._random_seed += (i +
                                         local_rank * self._num_transformers)
            transformer.Q_in = self.Q_level_1
            transformer.Q_out = self.Q_level_2
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

        # Init blob fetchers
        self._fetchers = []
        for i in range(self._num_fetchers):
            fetcher = BlobFetcher(**kwargs)
            fetcher.Q_in = self.Q_level_2
            fetcher.Q_out = self.Q_level_3
            fetcher.start()
            self._fetchers.append(fetcher)
            time.sleep(0.1)

        # Prevent echoing from multiple nodes
        if local_rank == 0: self.echo()

        def cleanup():
            def terminate(processes):
                for process in processes:
                    process.terminate()
                    process.join()

            from dragon.config import logger
            terminate(self._fetchers)
            if local_rank == 0: logger.info('Terminating BlobFetcher ......')
            terminate(self._transformers)
            if local_rank == 0:
                logger.info('Terminating DataTransformer ......')
            terminate(self._readers)
            if local_rank == 0: logger.info('Terminating DataReader ......')

        import atexit
        atexit.register(cleanup)
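To make the reader partitioning above concrete, a small worked sketch with hypothetical counts: when shuffling (or multi-node splitting) is on, the database is cut into ``num_readers * group_size`` parts and each node's readers take a contiguous slice of them.

# Hypothetical: 2 readers per node, a group of 4 nodes, and local_rank = 1.
num_readers, group_size, local_rank = 2, 4, 1

for i in range(num_readers):
    num_parts = num_readers * group_size       # 8 parts in total
    part_idx = i + local_rank * num_readers    # this node owns parts 2 and 3
    print(i, num_parts, part_idx)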
Example #6
cfg.IMS_PER_BATCH = cfg.IMS_PER_BATCH // len(gpus)

if __name__ == '__main__':

    # fix the random seeds (numpy and caffe) for reproducibility
    np.random.seed(cfg.RNG_SEED)
    caffe.set_random_seed(cfg.RNG_SEED)

    # setup caffe
    caffe.set_mode_gpu()

    # setup mpi
    if len(gpus) != mpi.Size():
        raise ValueError('Expected {} mpi nodes, but got {}.'.format(
            len(gpus), mpi.Size()))
    caffe.set_device(gpus[mpi.Rank()])
    mpi.Parallel([i for i in range(len(gpus))])
    mpi.Snapshot([0])
    if mpi.Rank() != 0:
        caffe.set_root_solver(False)

    # setup database
    cfg.DATABASE = imdb_name
    imdb = get_imdb(imdb_name)
    print('Database({}): {} images will be used to train.'.format(
        cfg.DATABASE, imdb.db_size))
    output_dir = osp.abspath(
        osp.join(cfg.ROOT_DIR, 'output', cfg.EXP_DIR, args.imdb_name))
    print('Output will be saved to `{:s}`'.format(output_dir))

    # train net
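A hedged note on the setup above: the check against ``mpi.Size()`` means the script expects exactly one MPI process per configured GPU (e.g. launched via ``mpirun`` with ``len(gpus)`` processes), after which rank ``k`` drives ``gpus[k]``. A tiny sketch of that mapping with hypothetical values:

# Hypothetical: 4 GPUs and 4 MPI ranks; rank k is pinned to gpus[k].
gpus = [0, 1, 2, 3]
for rank in range(len(gpus)):
    print('rank {} -> GPU {}'.format(rank, gpus[rank]))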
Example #7
    def __init__(self, **kwargs):
        """Construct a ``DataBatch``.

        Parameters
        ----------
        source : str
            The path of database.
        multiple_nodes : bool, optional, default=False
            Whether to split data for multiple parallel nodes.
        shuffle : bool, optional, default=False
            Whether to shuffle the data.
        num_chunks : int, optional, default=2048
            The number of chunks to split.
        padding : int, optional, default=0
            The zero-padding size.
        fill_value : int, optional, default=127
            The value to fill when padding is valid.
        crop_size : int, optional, default=0
            The cropping size.
        cutout_size : int, optional, default=0
            The square size to cutout.
        mirror : bool, optional, default=False
            Whether to mirror (flip horizontally) images.
        color_augmentation : bool, optional, default=False
            Whether to use color distortion.
        min_random_scale : float, optional, default=1.
            The min scale of the input images.
        max_random_scale : float, optional, default=1.
            The max scale of the input images.
        force_gray : bool, optional, default=False
            Whether to keep grayscale images single-channel (no channel duplication).
        phase : {'TRAIN', 'TEST'}, optional
            The optional running phase.
        batch_size : int, optional, default=128
            The size of a mini-batch.
        partition : bool, optional, default=False
            Whether to partition batch for parallelism.
        prefetch : int, optional, default=5
            The prefetch count.

        """
        super(DataBatch, self).__init__()
        # Init mpi
        global_rank, local_rank, group_size = 0, 0, 1
        if _mpi.Is_Init() and kwargs.get(
                'phase', 'TRAIN') == 'TRAIN':
            rank, group = _mpi.AllowParallel()
            if rank != -1: # DataParallel
                global_rank, group_size = _mpi.Rank(), len(group)
                for i, node in enumerate(group):
                    if global_rank == node: local_rank = i
        kwargs['group_size'] = group_size

        # configuration
        self._prefetch = kwargs.get('prefetch', 5)
        self._num_readers = kwargs.get('num_readers', 1)
        self._num_transformers = kwargs.get('num_transformers', -1)
        self._max_transformers = kwargs.get('max_transformers', 3)
        self._num_fetchers = kwargs.get('num_fetchers', 1)

        # io-aware policy
        if self._num_transformers == -1:
            self._num_transformers = 1
            # add 1 transformer for color augmentation
            if kwargs.get('color_augmentation', False):
                self._num_transformers += 1
            # add 1 transformer for random scale
            if kwargs.get('max_random_scale', 1.0) - \
                    kwargs.get('min_random_scale', 1.0) != 0:
                self._num_transformers += 1
        self._num_transformers = min(
            self._num_transformers, self._max_transformers)

        self._batch_size = kwargs.get('batch_size', 128)
        self._partition = kwargs.get('partition', False)
        if self._partition: self._batch_size //= kwargs['group_size']

        # init queues
        self.Q1 = Queue(self._prefetch * self._num_readers * self._batch_size)
        self.Q2 = Queue(self._prefetch * self._num_readers * self._batch_size)
        self.Q3 = Queue(self._prefetch * self._num_readers)

        # init readers
        self._readers = []
        for i in range(self._num_readers):
            self._readers.append(DataReader(**kwargs))
            self._readers[-1].Q_out = self.Q1

        for i in range(self._num_readers):
            part_idx, num_parts = i, self._num_readers
            if self._readers[i]._multi_nodes or \
                    self._readers[i]._use_shuffle:
                num_parts *= group_size
                part_idx += local_rank * self._num_readers
            self._readers[i]._num_parts = num_parts
            self._readers[i]._part_idx = part_idx
            self._readers[i]._rng_seed += part_idx
            self._readers[i].start()
            time.sleep(0.1)

        # init transformers
        self._transformers = []
        for i in range(self._num_transformers):
            transformer = DataTransformer(**kwargs)
            transformer._rng_seed += (i + local_rank * self._num_transformers)
            transformer.Q_in, transformer.Q_out = self.Q1, self.Q2
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

        # init blob fetchers
        self._fetchers = []
        for i in range(self._num_fetchers):
            fetcher = BlobFetcher(**kwargs)
            fetcher.Q_in, fetcher.Q_out = self.Q2, self.Q3
            fetcher.start()
            self._fetchers.append(fetcher)
            time.sleep(0.1)

        def cleanup():
            def terminate(processes):
                for process in processes:
                    process.terminate()
                    process.join()
            terminate(self._fetchers)
            if local_rank == 0: _logging.info('Terminate BlobFetcher.')
            terminate(self._transformers)
            if local_rank == 0: _logging.info('Terminate DataTransformer.')
            terminate(self._readers)
            if local_rank == 0: _logging.info('Terminate DataReader.')
        import atexit
        atexit.register(cleanup)
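A hedged usage sketch (the database path and the consumption pattern are assumptions; ``Q3`` is the output queue filled by the blob fetchers, so ``get()`` blocks until a prefetched batch is ready):

# Hypothetical construction; the kwargs mirror the documented parameters.
batch = DataBatch(
    source='/data/train_lmdb',   # assumed database path
    shuffle=True,
    batch_size=128,
    prefetch=5,
)
blobs = batch.Q3.get()           # a ready mini-batch from the pipeline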
Example #8
    def __init__(self, **kwargs):
        """Construct a ``DataBatch``.

        Parameters
        ----------
        source : str
            The path of database.
        shuffle : boolean
            Whether to shuffle the data.
        node_step : boolean
            Whether to split data for multiple parallel nodes.
        num_chunks : int
            The number of chunks to split. Default is ``2048``.
        chunk_size : int
            The size (MB) of each chunk. Default is ``-1`` (see ``num_chunks``).
        mean_values : list
            The mean value of each image channel.
        scale : float
            The scale performed after mean subtraction. Default is ``1.0``.
        padding : int
            The zero-padding size. Default is ``0`` (Disabled).
        fill_value : int
            The value to fill when padding is valid. Default is ``127``.
        crop_size : int
            The crop size. Default is ``0`` (Disabled).
        mirror : boolean
            Whether to flip images horizontally. Default is ``False``.
        color_augmentation : boolean
            Whether to distort colors. Default is ``False``.
        min_random_scale : float
            The min scale of the input images. Default is ``1.0``.
        max_random_scale : float
            The max scale of the input images. Default is ``1.0``.
        force_color : boolean
            Whether to duplicate channels for grayscale images. Default is ``False``.
        phase : str
            The phase of this operator, ``TRAIN`` or ``TEST``. Default is ``TRAIN``.
        batch_size : int
            The size of a training batch.
        partition : boolean
            Whether to partition batch. Default is ``False``.
        prefetch : int
            The prefetch count. Default is ``5``.

        """
        super(DataBatch, self).__init__()
        # init mpi
        global_rank = 0
        local_rank = 0
        group_size = 1
        if mpi.Is_Init():
            idx, group = mpi.AllowParallel()
            if idx != -1:  # data parallel
                global_rank = mpi.Rank()
                group_size = len(group)
                for i, node in enumerate(group):
                    if global_rank == node: local_rank = i
        kwargs['group_size'] = group_size

        # configuration
        self._prefetch = GetProperty(kwargs, 'prefetch', 5)
        self._num_readers = GetProperty(kwargs, 'num_readers', 1)
        self._num_transformers = GetProperty(kwargs, 'num_transformers', -1)
        self._max_transformers = GetProperty(kwargs, 'max_transformers', 3)
        self._num_fetchers = GetProperty(kwargs, 'num_fetchers', 1)

        # io-aware policy
        if self._num_transformers == -1:
            self._num_transformers = 2
            # add 1 transformer for color augmentation
            if GetProperty(kwargs, 'color_augmentation', False):
                self._num_transformers += 1
            # add 1 transformer for random scale
            if GetProperty(kwargs, 'max_random_scale', 1.0) - \
                    GetProperty(kwargs, 'min_random_scale', 1.0) != 0:
                self._num_transformers += 1
        self._num_transformers = min(self._num_transformers,
                                     self._max_transformers)

        self._batch_size = GetProperty(kwargs, 'batch_size', 100)
        self._partition = GetProperty(kwargs, 'partition', False)
        if self._partition:
            self._batch_size = int(self._batch_size / kwargs['group_size'])

        # init queues
        self.Q_level_1 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_2 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_3 = Queue(self._prefetch * self._num_readers)

        # init readers
        self._readers = []
        for i in range(self._num_readers):
            self._readers.append(DataReader(**kwargs))
            self._readers[-1].Q_out = self.Q_level_1

        for i in range(self._num_readers):
            num_parts = self._num_readers
            part_idx = i

            if self._readers[i]._use_shuffle \
                    or self._readers[i]._use_step:
                num_parts *= group_size
                part_idx += local_rank * self._num_readers

            self._readers[i]._num_parts = num_parts
            self._readers[i]._part_idx = part_idx
            self._readers[i]._random_seed += part_idx
            self._readers[i].start()
            time.sleep(0.1)

        # init transformers
        self._transformers = []
        for i in range(self._num_transformers):
            transformer = DataTransformer(**kwargs)
            transformer._random_seed += (i +
                                         local_rank * self._num_transformers)
            transformer.Q_in = self.Q_level_1
            transformer.Q_out = self.Q_level_2
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

        # init blob fetchers
        self._fetchers = []
        for i in range(self._num_fetchers):
            fetcher = BlobFetcher(**kwargs)
            fetcher.Q_in = self.Q_level_2
            fetcher.Q_out = self.Q_level_3
            fetcher.start()
            self._fetchers.append(fetcher)
            time.sleep(0.1)

        self.echo()
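To illustrate the io-aware policy in this version with hypothetical settings: the base of 2 transformers grows by one for color augmentation and by one for a non-degenerate random-scale range, and the total is then capped at ``max_transformers``:

# Hypothetical kwargs exercising both branches of the policy above.
kwargs = {'color_augmentation': True,
          'min_random_scale': 0.5,
          'max_random_scale': 2.0}

num_transformers = 2                                       # base in this version
if kwargs.get('color_augmentation', False):
    num_transformers += 1                                  # -> 3
if kwargs.get('max_random_scale', 1.0) - \
        kwargs.get('min_random_scale', 1.0) != 0:
    num_transformers += 1                                  # -> 4
num_transformers = min(num_transformers, 3)                # max_transformers = 3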