Example #1
def GraphDef_Update(graph_def, updater):
    """ generate all update targets for CC Graph """
    if updater is None: return

    updater._prefix = graph_def.name + '_'
    extra_kwargs = updater._extra_kwargs
    extra_kwargs['domain'] = updater._prefix

    # wrap hyper-parameters as Tensors for CC
    for k, v in updater._hyper_params.items():
        ws.FeedTensor(updater._prefix + k, np.array([v], dtype=np.float32))

    # check data parallel if necessary
    if mpi.is_init():
        idx, group = mpi.allow_parallel()
        if idx != -1:
            extra_kwargs['comm'], extra_kwargs['group'] \
                = mpi.group(root=group[0], incl=group)
            extra_kwargs['root'] = group[0]
            extra_kwargs['mode'] = mpi.get_parallel_mode()
            extra_kwargs['group_size'] = len(group)

    for tensors, kwargs in updater._tuples:
        kwargs = dict(kwargs, **extra_kwargs)
        u_target = pb.UpdateTarget()
        u_target.type = updater._type
        _, u_target.name = GetOperatorName()
        for tensor in tensors:
            u_target.tensor.append(tensor)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        graph_def.u_target.extend([u_target])
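
The helper above names each hyper-parameter tensor by prefixing it with the graph name before feeding it into the workspace. A minimal, framework-free sketch of that naming scheme, with a plain dict standing in for the CC workspace (all names here are illustrative, not Dragon's real API):

import numpy as np

workspace = {}  # stands in for the store behind ws.FeedTensor

def feed_hyper_params(graph_name, hyper_params):
    prefix = graph_name + '_'
    for k, v in hyper_params.items():
        # each scalar becomes a 1-element float32 tensor keyed by prefix + name
        workspace[prefix + k] = np.array([v], dtype=np.float32)

feed_hyper_params('GraphDef_0', {'base_lr': 0.01, 'momentum': 0.9})
assert 'GraphDef_0_base_lr' in workspace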
Example #2
def Restore(filename, format=0):
    """Restore tensors from a snapshot file."""
    if mpi.is_init():
        if not mpi.allow_snapshot():
            if not mpi.allow_parallel():
                # each rank restores from its own per-rank snapshot
                filename += '.rank.{}'.format(mpi.rank())

    assert os.path.exists(filename), \
        'Model file {} does not exist.'.format(filename)
    if format == 0:
        content = cPickle.load(open(filename, 'rb'))
        logger.info('Restore From Model@: ' + filename)
        logger.info('Model Format: cPickle')
        for key, ndarray in content.items():
            if not HasTensor(key):
                logger.info(
                    '[Warning]: Tensor({}) of the model does not exist in any graph, skip.'
                    .format(key))
            else:
                logger.info('[Info]: Tensor({}) restored.'.format(key))
                FeedTensor(key, ndarray)

    elif format == 1:
        # TODO(PhyscalX): the caffemodel format does not store tensor names,
        # TODO(PhyscalX): so we simply use 'Scope + LayerName + @paramX'
        RestoreCC(filename, '', format)
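
For format 0, the snapshot is simply a pickled dict mapping tensor names to numpy arrays; each entry is then pushed back into the workspace with FeedTensor. A minimal round-trip sketch of that layout (file name and keys are illustrative):

import numpy as np
try:
    import cPickle as pickle  # Python 2, as in the snippet above
except ImportError:
    import pickle             # Python 3 fallback

content = {'conv1/weight': np.zeros((3, 3), dtype=np.float32)}
with open('model.pkl', 'wb') as f:
    pickle.dump(content, f)

with open('model.pkl', 'rb') as f:
    restored = pickle.load(f)
for key, ndarray in restored.items():
    print(key, ndarray.shape)  # each pair would be fed via FeedTensor(key, ndarray)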
Example #3
    def InitTestNets(self):
        if mpi.is_init():
            idx, group = mpi.allow_parallel()
            # only the root in a parallel group can test
            if idx != -1 and mpi.rank() != group[0]: return

        num_test_net = len(self._param.test_iter)
        if num_test_net > 0:
            if self._param.test_interval <= 0:
                raise RuntimeError('Invalid test interval: {}.'.format(
                    self._param.test_interval))

        if len(self._param.test_net) > 0:
            for test_net in self._param.test_net:
            self._test_nets.append(Net(test_net, "TEST"))
            num_test_net -= len(self._param.test_net)

        # consider generic_net
        if num_test_net > 0:
            self._test_nets.append(Net(self._param.net, "TEST"))

        # share with training net
        for test_net in self._test_nets: test_net.share_with(self._net)
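
The early return at the top enforces that only the root of each parallel group runs evaluation, so test results are not duplicated across ranks. A standalone sketch of that check (rank and group values are made up):

def should_run_tests(rank, group):
    # group[0] is the root; every other member skips testing
    return rank == group[0]

group = [4, 5, 6, 7]                 # one data-parallel group of 4 ranks
assert should_run_tests(4, group)
assert not should_run_tests(6, group)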
Example #4
    def __init__(self, **kwargs):
        """DataBatch uses triple buffering to speed up data loading."""
        super(DataBatch, self).__init__()

        # init mpi
        global_rank = 0
        local_rank = 0
        group_size = 1
        if mpi.is_init():
            idx, group = mpi.allow_parallel()
            if idx != -1:  # data parallel
                global_rank = mpi.rank()
                group_size = len(group)
                for i, node in enumerate(group):
                    if global_rank == node: local_rank = i
        kwargs['group_size'] = group_size

        # configuration
        self._prefetch = GetProperty(kwargs, 'prefetch', 5)
        self._num_readers = GetProperty(kwargs, 'num_readers', 1)
        self._num_transformers = GetProperty(kwargs, 'num_transformers', -1)
        self._num_fetchers = GetProperty(kwargs, 'num_fetchers', 1)

        # default policy
        if self._num_transformers == -1:
            self._num_transformers = 1
            # add 1 transformer for color augmentation
            if GetProperty(kwargs, 'color_augmentation', False):
                self._num_transformers += 1
            # add 1 transformer for random scale
            if GetProperty(kwargs, 'max_random_scale', 1.0) != \
                    GetProperty(kwargs, 'min_random_scale', 1.0):
                self._num_transformers += 1

        self._batch_size = GetProperty(kwargs, 'batch_size', 100)
        self._partition = GetProperty(kwargs, 'partition', False)
        if self._partition:
            self._batch_size = int(self._batch_size / kwargs['group_size'])

        # init queues
        self.Q_level_1 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_2 = Queue(self._prefetch * self._num_readers *
                               self._batch_size)
        self.Q_level_3 = Queue(self._prefetch * self._num_readers)
        self.Q_level_4 = Queue2(self._prefetch * self._num_readers)

        # init readers
        self._readers = []
        for i in xrange(self._num_readers):
            self._readers.append(DataReader(**kwargs))
            self._readers[-1].Q_out = self.Q_level_1

        for i in xrange(self._num_readers):
            num_parts = self._num_readers
            part_idx = i

            # with shuffling or stepping, shard the data source across
            # the whole parallel group, not just the local readers
            if self._readers[i]._use_shuffle \
                    or self._readers[i]._use_step:
                num_parts *= group_size
                part_idx += local_rank * self._num_readers

            self._readers[i]._num_parts = num_parts
            self._readers[i]._part_idx = part_idx
            self._readers[i]._random_seed += part_idx
            self._readers[i].start()
            time.sleep(0.1)

        # init transformers
        self._transformers = []
        for i in xrange(self._num_transformers):
            transformer = DataTransformer(**kwargs)
            transformer.Q_in = self.Q_level_1
            transformer.Q_out = self.Q_level_2
            transformer.start()
            self._transformers.append(transformer)
            time.sleep(0.1)

        # init blob fetchers
        self._fetchers = []
        for i in xrange(self._num_fetchers):
            fetcher = BlobFetcher(**kwargs)
            fetcher.Q_in = self.Q_level_2
            fetcher.Q_out = self.Q_level_3
            fetcher.start()
            self._fetchers.append(fetcher)
            time.sleep(0.1)

        self.daemon = True
        self.start()
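
The constructor wires a multi-stage producer/consumer pipeline: readers fill Q_level_1, transformers move items to Q_level_2, and fetchers assemble batches into Q_level_3. A stripped-down sketch of that wiring with plain threads and in-process queues (all names hypothetical; the real stages are separate worker objects):

import threading
try:
    from Queue import Queue    # Python 2, matching the snippet above
except ImportError:
    from queue import Queue    # Python 3 fallback

def stage(q_in, q_out, fn):
    # generic pipeline stage: pull, process, push; None shuts it down
    while True:
        item = q_in.get()
        if item is None:
            q_out.put(None)
            break
        q_out.put(fn(item))

q1, q2, q3 = Queue(8), Queue(8), Queue(8)
workers = [
    threading.Thread(target=stage, args=(q1, q2, lambda x: x * 2)),  # "transformer"
    threading.Thread(target=stage, args=(q2, q3, lambda x: x + 1)),  # "fetcher"
]
for w in workers: w.start()

for i in range(4): q1.put(i)   # the "reader" feeds raw items
q1.put(None)                   # poison pill propagates down the pipeline
results = []
while True:
    out = q3.get()
    if out is None: break
    results.append(out)
for w in workers: w.join()
print(results)                 # [1, 3, 5, 7]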