def allreduce_params(self):
    """Average the first data copy of every trainable parameter across workers.

    Parameters whose ``grad_req`` is ``'null'`` are skipped. Each collective is
    tagged with the parameter's position ``idx`` in ``self._params`` and given
    priority ``-idx``, so earlier parameters get the higher priority.
    """
    trainable = ((idx, p) for idx, p in enumerate(self._params)
                 if p.grad_req != 'null')
    for idx, p in trainable:
        # In-place average over all workers; only the first context copy is reduced.
        hvd.allreduce_(p.list_data()[0], average=True, name=str(idx), priority=-idx)
예제 #2
0
    def test_horovod_allreduce_inplace(self):
        """Test that the in-place allreduce correctly sums 1D, 2D, 3D tensors.

        Every iteration re-seeds the RNG with the same value on each rank, so
        summing the tensor over all ranks must equal ``tensor * size`` up to a
        tolerance that grows with the number of ranks.
        """
        hvd.init()
        size = hvd.size()
        dtypes = self.filter_supported_types(
            ['int32', 'int64', 'float32', 'float64'])
        dims = [1, 2, 3]
        ctx = self._current_context()
        count = 0
        # Indexed by dimensionality. ``(17,)`` must be a 1-tuple: ``(17)``
        # would just be the int 17.
        shapes = [(), (17,), (17, 17), (17, 17, 17)]
        for dtype, dim in itertools.product(dtypes, dims):
            mx.random.seed(1234, ctx=ctx)
            tensor = mx.nd.random.uniform(-100,
                                          100,
                                          shape=shapes[dim],
                                          ctx=ctx)
            tensor = tensor.astype(dtype)
            multiplied = tensor * size
            hvd.allreduce_(tensor, average=False, name=str(count))
            count += 1

            # Threshold for floating point equality depends on number of
            # ranks, since we're comparing against precise multiplication.
            if size <= 3 or dtype in ['int32', 'int64']:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                # No tolerance is defined for this many ranks; stop checking.
                break

            assert almost_equal(tensor.asnumpy(), multiplied.asnumpy(), atol=threshold), \
                f'hvd.allreduce produces incorrect results for self: {hvd.rank()} {count} {dtype} {dim}'
예제 #3
0
 def pre_test(self):
     """Snapshot local parameters, then average them across workers.

     For every parameter with ``grad_req != 'null'``, the first data copy is
     written into ``self._params_cache[idx]`` (presumably so it can be
     restored after evaluation). The copy is then averaged in place over all
     workers, named ``str(idx)`` with priority ``-idx``.
     """
     for idx, p in enumerate(self._params):
         if p.grad_req == 'null':
             continue
         local_copy = p.list_data()[0]
         self._params_cache[idx][:] = local_copy
         hvd.allreduce_(local_copy, average=True, name=str(idx), priority=-idx)
    def _allreduce_params(self):
        """Average trainable parameters across workers and count traffic.

        Each parameter with ``grad_req != 'null'`` has its first data copy
        averaged in place, and ``self._comm_counter`` grows by twice that
        array's element count (presumably modelling send + receive volume).
        """
        for idx, p in enumerate(self._params):
            if p.grad_req == 'null':
                continue
            data = p.list_data()[0]
            hvd.allreduce_(data, average=True, name=str(idx), priority=-idx)
            # communication counter
            self._comm_counter += 2 * data.size
예제 #5
0
    def _do_allreduce(self, index, grad):
        """Sum gradients in place across all workers (no averaging).

        ``index`` is either a single key or a tuple/list of keys. In the
        latter case ``grad`` is indexed in parallel and the tensor at position
        ``pos`` gets priority ``-pos``. Collectives are named
        ``self._prefix + str(key)``. With a single worker this is a no-op.
        """
        if hvd.size() == 1:
            return

        if not isinstance(index, (tuple, list)):
            hvd.allreduce_(grad, average=False, name=self._prefix + str(index))
            return

        for pos, key in enumerate(index):
            hvd.allreduce_(grad[pos],
                           average=False,
                           name=self._prefix + str(key),
                           priority=-pos)
 def allreduce_states(self):
     """Average optimizer state slot 1 across workers, last parameter first.

     For every trainable parameter (walked in reverse order), the array
     ``self._updaters[0].states[pos][1]`` is averaged in place over all
     workers and then copied into ``states[pos][0]``. Collective names are
     offset by ``len(self._params)`` (presumably to avoid clashing with
     parameter allreduces named ``str(pos)``).

     Raises:
         ValueError: if a trainable parameter's ``_stype`` is not ``'default'``.
     """
     n_params = len(self._params)
     for pos, p in reversed(list(enumerate(self._params))):
         if p.grad_req == 'null':
             continue
         slot = self._updaters[0].states[pos][1]
         if p._stype != 'default':
             raise ValueError(
                 "Cannot pull row_sparse parameters for local SGD")
         hvd.allreduce_(slot,
                        average=True,
                        name=str(pos + n_params),
                        priority=pos - n_params * 2)
         self._updaters[0].states[pos][0][:] = slot
예제 #7
0
 def allgather(self, tensor, name, shape, dtype, context):
     """Gather every rank's batch into one tensor via an in-place sum-allreduce.

     Each rank writes ``tensor`` into its own ``self.batch_size``-row slice
     (starting at ``self.rank * self.batch_size``) of a zero-filled buffer
     from ``self.get_ndarray``. Summing the buffers over all ranks then yields
     every rank's slice in one tensor.

     Returns:
         The buffer NDArray holding the gathered result.
     """
     expected_types = ((tensor, nd.NDArray),
                       (name, str),
                       (shape, tuple),
                       (dtype, str),
                       (context, mx.context.Context))
     for value, kind in expected_types:
         assert isinstance(value, kind), type(value)
     gathered = self.get_ndarray(
         context=context, name=name, shape=shape, dtype=dtype)
     # The buffer must be cleared before the sum-reduction; otherwise stale
     # contents (get_ndarray presumably reuses buffers by name) would be summed in.
     gathered[:] = 0
     start = self.rank * self.batch_size
     gathered[start:start + self.batch_size] = tensor
     hvd.allreduce_(gathered, average=False)  # in-place sum across ranks
     return gathered
 def allreduce_params(self):
     """Average parameter values across workers and sync local device copies.

     For every parameter whose ``grad_req`` is not ``'null'``, the copy on the
     first context is averaged in place over all workers with
     ``hvd.allreduce_`` (named ``str(i)``, priority ``-i``). The result is
     then copied to the parameter's remaining context copies, so every local
     device holds the same values.

     Note that this reduces parameter *values*, not gradients.
     """
     for i, param in enumerate(self._params):
         if param.grad_req == 'null':
             continue
         copies = param.list_data()
         primary = copies[0]
         hvd.allreduce_(primary, average=True, name=str(i), priority=-i)
         # Broadcast the averaged values to the other local contexts.
         for replica in copies[1:]:
             primary.copyto(replica)
예제 #9
0
    def test_horovod_allreduce_inplace(self):
        """Test that the allreduce correctly sums 1D, 2D, 3D tensors.

        Each iteration re-seeds the RNG with the same value on every rank,
        sums the tensor in place across ranks with ``hvd.allreduce_`` and
        compares the result against ``tensor * size``.
        """
        hvd.init()
        size = hvd.size()
        dtypes = self.filter_supported_types(
            ['int32', 'int64', 'float32', 'float64'])
        dims = [1, 2, 3]
        ctx = self._current_context()
        count = 0
        # Indexed by dimensionality.
        # NOTE(review): ``(17)`` is the int 17, not a 1-tuple; ``(17,)`` was
        # presumably intended.
        shapes = [(), (17), (17, 17), (17, 17, 17)]
        for dtype, dim in itertools.product(dtypes, dims):
            mx.random.seed(1234, ctx=ctx)
            tensor = mx.nd.random.uniform(-100,
                                          100,
                                          shape=shapes[dim],
                                          ctx=ctx)
            tensor = tensor.astype(dtype)
            multiplied = tensor * size
            hvd.allreduce_(tensor, average=False, name=str(count))
            # NOTE(review): this is the maximum *signed* difference, so errors
            # where tensor < multiplied are never detected; an abs() may have
            # been intended.
            max_difference = mx.nd.max(mx.nd.subtract(tensor, multiplied))
            count += 1

            # Threshold for floating point equality depends on number of
            # ranks, since we're comparing against precise multiplication.
            if size <= 3 or dtype in ['int32', 'int64']:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                # No tolerance is defined for this many ranks; stop checking.
                break

            # Dump diagnostics before the assertion below fails.
            if max_difference > threshold:
                print("self", count, dtype, dim, max_difference, threshold)
                print("tensor", hvd.rank(), tensor)
                print("multiplied", hvd.rank(), multiplied)
            # NOTE(review): the message literal below continues past the end of
            # this snippet (truncated in this listing).
            assert max_difference <= threshold, 'hvd.allreduce produces \
예제 #10
0
    def train(ctx):
        """Train ``net`` with Horovod and validate after every epoch.

        ``ctx`` may be a single ``mx.Context`` or a list of contexts. Unless
        resuming from saved parameters, the network is initialized with
        ``MSRAPrelu``. Parameters are then broadcast from rank 0 and trained
        with the trainer selected by ``opt.trainer``. After each epoch:

        * the mean training loss and validation errors are averaged over all
          workers and logged on rank 0;
        * a checkpoint is written by local rank 0 every ``save_frequency``
          epochs.

        ``opt.test_speed`` controls how many optimizer steps run per batch:
        ``> 0`` runs that many, ``0`` runs one, and ``< 0`` runs none (the
        data is only iterated and batch counts are reported).
        """
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            # No weight decay on normalization parameters and biases.
            for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        hvd.broadcast_parameters(net.collect_params(), root_rank=0)

        # trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

        # trainer = hvd.DistributedTrainer(
        #     net.collect_params(),
        #     optimizer,
        #     optimizer_params)

        if opt.trainer == 'efsgd':
            trainer = EFSGDTrainerV1(
                net.collect_params(),
                'EFSGDV1', optimizer_params,
                input_sparse_ratio=1./opt.input_sparse_1,
                output_sparse_ratio=1./opt.output_sparse_1,
                layer_sparse_ratio=1./opt.layer_sparse_1)
        elif opt.trainer == 'qsparselocalsgd':
            trainer = QSparseLocalSGDTrainerV1(
                net.collect_params(),
                optimizer, optimizer_params,
                input_sparse_ratio=1./opt.input_sparse_1,
                output_sparse_ratio=1./opt.output_sparse_1,
                layer_sparse_ratio=1./opt.layer_sparse_1,
                local_sgd_interval=opt.local_sgd_interval)
        elif opt.trainer == 'ersgd':
            trainer = ERSGDTrainerV2(
                net.collect_params(),
                optimizer, optimizer_params,
                input_sparse_ratio=1./opt.input_sparse_1,
                output_sparse_ratio=1./opt.output_sparse_1,
                layer_sparse_ratio=1./opt.layer_sparse_1)
        elif opt.trainer == 'partiallocalsgd':
            trainer = PartialLocalSGDTrainerV1(
                net.collect_params(),
                optimizer, optimizer_params,
                input_sparse_ratio=1./opt.input_sparse_1,
                output_sparse_ratio=1./opt.output_sparse_1,
                layer_sparse_ratio=1./opt.layer_sparse_1,
                local_sgd_interval=opt.local_sgd_interval)
        elif opt.trainer == 'ersgd2':
            trainer = ERSGD2TrainerV2(
                net.collect_params(),
                optimizer, optimizer_params,
                input_sparse_ratio_1=1./opt.input_sparse_1,
                output_sparse_ratio_1=1./opt.output_sparse_1,
                layer_sparse_ratio_1=1./opt.layer_sparse_1,
                input_sparse_ratio_2=1./opt.input_sparse_2,
                output_sparse_ratio_2=1./opt.output_sparse_2,
                layer_sparse_ratio_2=1./opt.layer_sparse_2,
                local_sgd_interval=opt.local_sgd_interval)
        else:
            # 'sgd' and any unrecognized value use the plain SGD trainer.
            trainer = SGDTrainer(
                net.collect_params(),
                optimizer, optimizer_params)

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        # Mixup and label smoothing produce dense (soft) targets.
        sparse_label_loss = not (opt.label_smoothing or opt.mixup)
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                             hard_weight=opt.hard_weight,
                                                             sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 1

        # These trainers need pre_test() before validation and post_test()
        # afterwards (pre_test averages the parameters across workers).
        allreduce_for_val = opt.trainer in ('ersgd', 'qsparselocalsgd', 'ersgd2', 'partiallocalsgd')

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            # train_metric.reset()
            train_loss = 0
            btic = time.time()

            # test speed: number of optimizer steps run per batch
            if opt.test_speed > 0:
                n_repeats = opt.test_speed
            elif opt.test_speed == 0:
                n_repeats = 1
            else:
                n_repeats = 0

            num_batches = 0
            for i, batch in enumerate(train_data):
                num_batches = i + 1

                # test speed: data-only mode just reports progress
                if n_repeats == 0 and opt.log_interval and not (i + 1) % opt.log_interval:
                    print('[Epoch %d] # batch: %d'%(epoch, i))
                    continue

                raw_data, raw_label = batch_fn(batch, ctx)

                for j in range(n_repeats):
                    # Start every repeat from the untouched batch: mixup and
                    # label smoothing must not be re-applied to data/targets
                    # already transformed by a previous repeat.
                    data, label = raw_data, raw_label

                    if opt.mixup:
                        lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                        if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                            lam = 1
                        data = [lam*X + (1-lam)*X[::-1] for X in data]

                        eta = 0.1 if opt.label_smoothing else 0.0
                        label = mixup_transform(label, classes, lam, eta)

                    elif opt.label_smoothing:
                        hard_label = label
                        label = smooth(label, classes)

                    if distillation:
                        teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature)
                                        for X in data]

                    with ag.record():
                        outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                        if distillation:
                            loss = [L(yhat.astype('float32', copy=False),
                                      y.astype('float32', copy=False),
                                      p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                        else:
                            loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                    for per_ctx_loss in loss:
                        per_ctx_loss.backward()
                    trainer.step(batch_size)

                    # if opt.mixup:
                    #     output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                    #                     for out in outputs]
                    #     train_metric.update(label, output_softmax)
                    # else:
                    #     if opt.label_smoothing:
                    #         train_metric.update(hard_label, outputs)
                    #     else:
                    #         train_metric.update(label, outputs)

                    step_loss = sum(per_ctx_loss.sum().asscalar() for per_ctx_loss in loss)

                    train_loss += step_loss

                    if opt.log_interval and not (i+j+1)%opt.log_interval:
                        if hvd.rank() == 0:
                            print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                                        epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                                        'loss', step_loss/batch_size, trainer.learning_rate, trainer._comm_counter/1e6))
                        btic = time.time()

            mx.nd.waitall()
            toc = time.time()

            if n_repeats == 0:
                # The reduced value is never read; the collective presumably
                # serves as a cross-worker barrier.
                allreduce_array_nd = mx.nd.array([num_batches])
                hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
                mx.nd.waitall()
                print('[Epoch %d] # total batch: %d'%(epoch, num_batches))
                continue

            # Each batch ran n_repeats optimizer steps of batch_size samples.
            num_steps = num_batches * n_repeats
            throughput = int(batch_size * num_steps / (toc - tic) * hvd.size())

            # Guard against an empty epoch (no batches -> train_loss stays 0).
            train_loss /= max(batch_size * num_steps, 1)

            if allreduce_for_val:
                trainer.pre_test()
            # err_train_tic = time.time()
            # err_top1_train, err_top5_train = test(ctx, train_data, val=False)
            err_val_tic = time.time()
            err_top1_val, err_top5_val = test(ctx, val_data, val=True)
            err_val_toc = time.time()
            if allreduce_for_val:
                trainer.post_test()

            mx.nd.waitall()

            # allreduce the results
            allreduce_array_nd = mx.nd.array([train_loss, err_top1_val, err_top5_val])
            hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
            # np.asscalar was removed in NumPy 1.23; tolist() yields Python floats.
            train_loss, err_top1_val, err_top5_val = allreduce_array_nd.asnumpy().tolist()

            if hvd.rank() == 0:
                logger.info('[Epoch %d] training: loss=%f'%(epoch, train_loss))
                logger.info('[Epoch %d] speed: %d samples/sec training-time: %f comm: %f'%(epoch, throughput, toc-tic, trainer._comm_counter/1e6))
                logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f err-time=%f'%(epoch, err_top1_val, err_top5_val, err_val_toc - err_val_tic))
                trainer._comm_counter = 0

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                # if hvd.local_rank() == 0:
                #     net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                #     trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                if hvd.local_rank() == 0:
                    net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                    trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))
예제 #11
0
    def test_gluon_trainer(self):
        """Test using horovod allreduce in MXNet Gluon trainer.

        Each rank trains a small LayerNorm/Dense stack on random data with a
        rank-specific seed. Every 10 steps each parameter's first data copy is
        reduced across ranks with ``hvd.allreduce_`` (default reduction op).
        """
        from mxnet import gluon
        from mxnet.gluon import nn, HybridBlock

        hvd.init()
        rank = hvd.rank()
        np.random.seed(1000 + 10 * rank)
        mx.random.seed(1000 + 10 * rank)
        # GPU ordinals are per host, so the device must be selected by the
        # local rank; the global rank exceeds the GPU count on multi-node runs.
        ctx = mx.gpu(hvd.local_rank())

        def gen_random_dataset(batch_size=64,
                               dim=32,
                               min_len=20,
                               max_len=100,
                               size=1000):
            # Yields ``size`` (src, dst) pairs of shape (length, dim) with a
            # random length in [min_len, max_len]; ``batch_size`` is unused.
            for _ in range(size):
                length = np.random.randint(min_len, max_len + 1)
                rand_src = mx.nd.random.normal(0, 1, (length, dim))
                rand_dst = mx.nd.random.normal(0, 1, (length, dim))
                yield rand_src, rand_dst

        class SimpleNet(HybridBlock):
            def __init__(self, layer_num=6, **kwargs):
                super(SimpleNet, self).__init__(**kwargs)
                self._layer_num = layer_num
                with self.name_scope():
                    self.ln_l = nn.HybridSequential()
                    self.dense_l = nn.HybridSequential()
                    for i in range(layer_num):
                        self.dense_l.add(
                            nn.Dense(units=32 + layer_num - 1 - i,
                                     flatten=False))
                        self.ln_l.add(nn.LayerNorm())

            def hybrid_forward(self, F, data):
                """Apply LayerNorm followed by Dense for each layer in turn.

                Parameters
                ----------
                data :
                    Shape (batch_size, seq_len, fea_dim)

                Returns
                -------
                out :
                    Shape (batch_size, seq_len, fea_dim)
                """
                for i in range(self._layer_num):
                    data = self.ln_l[i](data)
                    data = self.dense_l[i](data)
                return data

        net = SimpleNet()
        net.initialize(ctx=ctx)
        net.hybridize(static_alloc=True)

        params = net.collect_params()
        cnt = 0
        lr = 1E-4
        trainer = gluon.Trainer(params,
                                'adam', {'learning_rate': lr},
                                update_on_kvstore=False)

        for src_data, dst_data in gen_random_dataset():
            src_data = src_data.as_in_context(ctx).astype(np.float32)
            dst_data = dst_data.as_in_context(ctx).astype(np.float32)
            with mx.autograd.record():
                pred = net(src_data)
                loss = mx.nd.abs(pred - dst_data).mean()
            loss.backward()
            # Begin to update the parameter
            trainer.step(1.0)
            cnt += 1
            # Blocks until the pending computation finishes; the value is unused.
            loss.asscalar()
            if cnt >= 10:
                for _, param in params.items():
                    hvd.allreduce_(param.list_data()[0])
                cnt = 0
    def pushpull(self, key, value, out=None, priority=0):
        """ Performs allreduce on a single tensor or a list of tensor objects

        This function sums the input tensor(s) over all processes (no averaging).

        The name `pushpull` is a generic term; in Horovod its action is implemented
        with Horovod's allreduce. Each operation is identified by ``str(key)``. When
        ``value`` is a list, every tensor in this call is submitted under that same
        name. The tensor type and shape must be the same on all processes for a
        given name. The reduction will not start until all processes are ready to
        send and receive the tensor.

        Parameters
        ----------
        key : str, int, or sequence of str or int
            Key used to tag the operation. It is converted with ``str()``; a
            sequence is stringified as a whole rather than split per tensor.

        value : NDArray or list of NDArray
            Tensor value(s) on one process to be summed. If `out` is not
            specified, each `value` will be modified in-place.

        out: NDArray or list of NDArray, optional
            Output tensor(s) after allreduce, paired element-wise with `value`.
            If not specified, the input tensor `value` will be modified in-place.

        priority : int, optional
            The priority of the operation.
            Higher priority operations are likely to be executed before other actions.

        Returns
        -------
        None
            Results are written into `value` (in-place) or into `out`.

        Examples
        --------
        >>> # perform in-place allreduce on tensor a
        >>> shape = (2, 3)
        >>> nworker = kv.num_workers # assume there are 8 processes
        >>> a = mx.nd.ones(shape)
        >>> kv.pushpull('1', a)
        >>> print(a.asnumpy())
        [[ 8.  8.  8.]
         [ 8.  8.  8.]]

        >>> # perform allreduce on tensor a and output to b
        >>> a = mx.nd.ones(shape)
        >>> b = mx.nd.zeros(shape)
        >>> kv.pushpull('2', a, out=b)
        >>> print(b.asnumpy())
        [[ 8.  8.  8.]
         [ 8.  8.  8.]]
        """
        import horovod.mxnet as hvd

        name = str(key)
        values = value if isinstance(value, list) else [value]
        if out is None:
            for v in values:
                hvd.allreduce_(v,
                               average=False,
                               name=name,
                               priority=priority)
            return

        outs = out if isinstance(out, list) else [out]
        for o, v in zip(outs, values):
            o[:] = hvd.allreduce(v,
                                 average=False,
                                 name=name,
                                 priority=priority)
예제 #13
0
    def backward_sample(self, total_feature, label):
        """Softmax-loss backward pass over a sampled subset of this rank's classes.

        The memory bank samples local class centres for ``label``. Labels whose
        class was sampled are remapped to ``rank * num_sample + sample_position``;
        all other labels become ``-1``. Logits come from ``self.fc7_model``
        against the sampled weights. The row-wise max, the softmax normaliser
        and the target-class probability are combined across ranks with
        Horovod. The sampled weights are then updated by
        ``self.memory_optimizer`` and written back into the memory bank.

        Returns:
            Tuple ``(_data.grad, global_loss)``: the gradient for
            ``total_feature`` and the mean negative log-likelihood.
        """
        this_rank_classes = int(self.memory_bank.num_sample)
        local_index, unique_sorted_global_label = self.memory_bank.sample(
            label)

        # Map every sampled absolute class id that occurs in the batch to its
        # column in the rank-concatenated sampled-class layout.
        rank_class_offset = self.rank * self.memory_bank.num_local
        rank_sample_offset = self.rank * self.memory_bank.num_sample
        wanted = set(unique_sorted_global_label)
        class_to_column = {
            cls: col + rank_sample_offset
            for col, cls in enumerate(local_index + rank_class_offset)
            if cls in wanted
        }
        mapping_label = nd.array(
            [class_to_column.get(lbl, -1) for lbl in label.asnumpy()],
            dtype=np.int32)

        # Fetch the sampled weights and their momentum from the memory bank.
        local_index = self.get_ndarray2(self.gpu, "local_index",
                                        nd.array(local_index))
        sample_weight, sample_weight_mom = self.memory_bank.get(local_index)

        # Stage features, weights and momentum into named per-rank buffers.
        # The same buffers are used whether or not the memory bank is on GPU.
        _data = self.get_ndarray2(self.gpu, "data_%d" % self.rank,
                                  total_feature)
        _weight = self.get_ndarray2(self.gpu, 'weight_%d' % self.rank,
                                    sample_weight)
        _weight_mom = self.get_ndarray2(self.gpu,
                                        'weight_mom_%d' % self.rank,
                                        sample_weight_mom)

        _data.attach_grad()
        _weight.attach_grad()

        # Shift labels into this rank's local column range.
        _label = self.get_ndarray2(self.gpu, 'mapping_label_%d' % self.rank,
                                   mapping_label)
        _label = _label - int(self.rank * self.memory_bank.num_sample)
        _fc7, _one_hot = self.fc7_model.forward(_data,
                                                _weight,
                                                mapping_label=_label,
                                                depth=this_rank_classes)

        # Global row-wise max of the logits: each rank fills its own column of
        # a zeroed (rows, size) buffer, a sum-allreduce assembles all columns,
        # and the max is taken across them.
        max_fc7 = nd.reshape(nd.max(_fc7, axis=1, keepdims=True), -1)
        total_max_fc7 = self.get_ndarray(context=self.gpu,
                                         name='total_max_fc7',
                                         shape=(max_fc7.shape[0], self.size),
                                         dtype='float32')
        total_max_fc7[:] = 0
        total_max_fc7[:, self.rank] = max_fc7
        hvd.allreduce_(total_max_fc7, average=False)
        global_max_fc7 = self.get_ndarray(context=self.gpu,
                                          name='global_max_fc7',
                                          shape=(max_fc7.shape[0], 1),
                                          dtype='float32')
        nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)

        # Max-shifted softmax, normalised by the sum over every rank's classes.
        _fc7_grad = nd.exp(nd.broadcast_sub(_fc7, global_max_fc7))
        global_sum_fc7 = hvd.allreduce(
            nd.sum(_fc7_grad, axis=1, keepdims=True), average=False)
        _fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)

        # Loss: target-class probability summed over ranks, then mean NLL.
        target_prob = nd.sum(_fc7_grad * _one_hot, axis=1, keepdims=True)
        target_prob = self.get_ndarray2(self.gpu, 'ctx_loss', target_prob)
        target_prob = hvd.allreduce(target_prob, average=False)
        global_loss = -nd.mean(nd.log(target_prob + 1e-30))

        # Gradient of cross-entropy w.r.t. the logits: softmax - one_hot.
        _fc7_grad = _fc7_grad - _one_hot
        _fc7.backward(out_grad=_fc7_grad)

        # Update the sampled centres and write them back to the memory bank.
        self.memory_optimizer.update(weight=_weight,
                                     grad=_weight.grad,
                                     state=_weight_mom,
                                     learning_rate=self.memory_lr)
        if self.memory_bank.gpu:
            self.memory_bank.set(index=local_index,
                                 updated_weight=_weight,
                                 updated_weight_mom=_weight_mom)
        else:
            # Host-resident bank: hand over CPU copies of the updated tensors.
            self.memory_bank.set(index=local_index,
                                 updated_weight=self.get_ndarray2(
                                     mx.cpu(), "cpu_weight_%d" % self.rank,
                                     _weight),
                                 updated_weight_mom=self.get_ndarray2(
                                     mx.cpu(), "cpu_weight_mom_%d" % self.rank,
                                     _weight_mom))
        return _data.grad, global_loss
예제 #14
0
    def backward(self, total_feature, label):
        """Softmax-loss backward pass using this rank's full shard of class weights.

        Requires ``memory_bank.num_local == memory_bank.num_sample`` (asserted).
        Logits come from ``self.fc7_model`` against ``memory_bank.weight``. The
        row-wise max, the softmax normaliser and the target-class probability
        are combined across ranks with Horovod. The logit gradient
        ``softmax - one_hot`` is back-propagated, and ``memory_bank.weight`` is
        updated by ``self.memory_optimizer`` using ``memory_bank.weight_mom``.

        Returns:
            Tuple ``(_data.grad, global_loss)``: the gradient for
            ``total_feature`` and the mean negative log-likelihood.
        """
        memory_bank = self.memory_bank
        # NOTE(review): the assertion message "pass" does not describe the failure.
        assert memory_bank.num_local == memory_bank.num_sample, "pass"

        _data = self.get_ndarray2(self.gpu, "data_%d" % self.rank,
                                  total_feature)
        # Attach grad
        _data.attach_grad()
        memory_bank.weight.attach_grad()

        # Convert label: shift into this rank's local class range.
        _label = self.get_ndarray2(self.gpu, 'label_%d' % self.rank, label)
        _label = _label - int(self.rank * memory_bank.num_local)
        _fc7, _one_hot = self.fc7_model.forward(_data,
                                                memory_bank.weight,
                                                mapping_label=_label,
                                                depth=memory_bank.num_local)

        # Sync max: each rank writes its row-wise max into its own column of a
        # zeroed (rows, size) buffer; a sum-allreduce assembles all columns.
        # NOTE: the collectives in this method are unnamed; Horovod presumably
        # auto-names them by call order, so every rank must call them in the
        # same order.
        max_fc7 = nd.max(_fc7, axis=1, keepdims=True)
        max_fc7 = nd.reshape(max_fc7, -1)

        total_max_fc7 = self.get_ndarray(context=self.gpu,
                                         name='total_max_fc7',
                                         shape=(max_fc7.shape[0], self.size),
                                         dtype='float32')
        total_max_fc7[:] = 0
        total_max_fc7[:, self.rank] = max_fc7
        hvd.allreduce_(total_max_fc7, average=False)

        # Global max over all ranks' columns.
        global_max_fc7 = self.get_ndarray(context=self.gpu,
                                          name='global_max_fc7',
                                          shape=(max_fc7.shape[0], 1),
                                          dtype='float32')
        nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)

        # Calculate exp(logits - global max) for numerical stability.
        _fc7_grad = nd.broadcast_sub(_fc7, global_max_fc7)
        _fc7_grad = nd.exp(_fc7_grad)

        # Calculate the softmax denominator summed over all ranks' classes.
        sum_fc7 = nd.sum(_fc7_grad, axis=1, keepdims=True)
        global_sum_fc7 = hvd.allreduce(sum_fc7, average=False)

        # Calculate prob
        _fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)

        # Calculate loss: target-class probability summed over ranks, then
        # mean negative log (1e-30 guards against log(0)).
        tmp = _fc7_grad * _one_hot
        tmp = nd.sum(tmp, axis=1, keepdims=True)
        tmp = self.get_ndarray2(self.gpu, 'ctx_loss', tmp)
        tmp = hvd.allreduce(tmp, average=False)
        global_loss = -nd.mean(nd.log(tmp + 1e-30))

        # Calculate fc7 grad: d(cross-entropy)/d(logits) = softmax - one_hot.
        _fc7_grad = _fc7_grad - _one_hot

        # Backward
        _fc7.backward(out_grad=_fc7_grad)

        # Update center
        _weight_grad = memory_bank.weight.grad
        self.memory_optimizer.update(weight=memory_bank.weight,
                                     grad=_weight_grad,
                                     state=memory_bank.weight_mom,
                                     learning_rate=self.memory_lr)

        return _data.grad, global_loss
예제 #15
0
def train(data_train, data_eval, model):
    """Training function for BERT pre-training.

    Relies on module-level configuration set up elsewhere in the script
    (``args``, ``backend``, ``ctxs``, ``num_workers``, ``rank``, ``local_rank``,
    ``batch_size``, ``early_stop``, ...).

    Parameters
    ----------
    data_train : iterable
        Training batches. It is re-iterated from the start until
        ``args.num_steps`` optimizer steps have been taken.
    data_eval
        Evaluation data handed to ``get_pretrain_data_npz``; evaluation is
        skipped when it is falsy.
    model
        Pre-training model; the parameters of ``model.bert`` are broadcast
        (horovod) and optimized.
    """
    import sys  # sys.exit(): the builtin exit() comes from `site` and is meant for the REPL

    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        # Start every worker from rank 0's parameters.
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers, 'init_scale': 1}
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer, optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        # Resume the optimizer state that was saved per local rank.
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        # Gradients are summed over `accumulate` batches and cleared with
        # zero_grad() at the start of every optimizer step (see below).
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info('Generating the first batch of data, which may take a few minutes ...')

    # create dummy data loader if needed
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

    if backend == 'byteps':
        # BytePS only: declare the mask-count tensor and push-pull it once, then
        # run a single dummy batch through the model and initialize the
        # trainer's parameters before real training starts.
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks, is_average=False, name="local_num_masks", priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(iter(get_dummy_dataloader(batch_size, args.max_seq_length, args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate: linear warmup to `lr`, then linear decay
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    sys.exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            with mx.autograd.record():
                num_data = len(data_list)
                for shard in data_list:
                    parallel.put(shard)
                for _ in range(num_data):
                    (next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length, num_masks) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_num_masks += num_masks
                    local_mlm_loss += ls1
                    # NOTE(review): ls2 is never accumulated, so running_nsp_loss
                    # stays 0 in the logs -- confirm this is intentional.
                    running_num_tks += valid_length.sum()
            # pre fetch next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                running_mlm_loss += local_mlm_loss / local_num_masks
                # Sum (not average) the masked-token count over all workers.
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks, average=False, name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks, is_average=False,
                                         name="local_num_masks", priority=0)
                # because byteps implicitly set scale /= num_workers
                fp16_trainer.step(local_num_masks * num_workers, max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0
            # update metrics
            if args.no_compute_acc:
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging
            if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss,
                              0, step_num, trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                # Checkpoint saving is disabled in this copy of the script:
                # if is_master_node:
                #     save_states(step_num, trainer, args.ckpt_dir, local_rank)
                #     if local_rank == 0:
                #         save_parameters(step_num, model.bert, args.ckpt_dir)
                if (step_num + 1) % args.eval_interval == 0 and data_eval:
                    # eval data is always based on a fixed npz file.
                    dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                                         1, False, 1, vocab)
                    evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype, rank, num_workers)

            batch_num += 1

    # Final checkpoint saving is disabled in this copy of the script:
    # if is_master_node:
    #     save_states(step_num, trainer, args.ckpt_dir, local_rank)
    #     if local_rank == 0:
    #         save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
    def train(epochs, ctx):
        """Train ``net`` on CIFAR-100 with Horovod and a QSparse local-SGD trainer.

        Relies on names from the enclosing scope (``net``, ``opt``, ``rank``,
        ``num_workers``, ``batch_size``, ``optimizer_params``, ``lr_decay``,
        ``lr_decay_epoch``, ``save_period``, ``save_dir``, ``model_name``,
        ``test``, ...).

        Parameters
        ----------
        epochs : int
            Number of passes over this worker's part of the training set.
        ctx : mx.Context or list of mx.Context
            Device(s) to train on; a single context is wrapped in a list.
        """
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        # Each worker reads the part of the training set that SplitSampler
        # selects for part_index=rank.
        train_dataset = gluon.data.vision.CIFAR100(train=True).transform_first(transform_train)
        train_data = gluon.data.DataLoader(
            train_dataset,
            sampler=SplitSampler(len(train_dataset), num_parts=num_workers, part_index=rank),
            batch_size=batch_size, last_batch='discard', num_workers=opt.num_workers)

        # Every worker evaluates on the full (unsharded) validation set; the
        # resulting accuracies are averaged across workers after each epoch.
        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR100(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=opt.num_workers)

        # Start all workers from rank 0's initial weights.
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)

        trainer = QSparseLocalSGDTrainerV1(
            net.collect_params(),
            'nag', optimizer_params,
            input_sparse_ratio=1./opt.input_sparse,
            output_sparse_ratio=1./opt.output_sparse,
            layer_sparse_ratio=1./opt.layer_sparse,
            local_sgd_interval=opt.local_sgd_interval)

        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        train_history = TrainingHistory(['training-error', 'validation-error'])

        lr_decay_count = 0
        best_val_score = 0
        lr = opt.lr

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            # Step decay. The bounds check prevents an IndexError once every
            # scheduled decay epoch has been consumed.
            if (lr_decay_count < len(lr_decay_epoch)
                    and epoch == lr_decay_epoch[lr_decay_count]):
                lr *= lr_decay
                trainer.set_learning_rate(lr)
                lr_decay_count += 1

            for batch in train_data:
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]

                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)

            mx.nd.waitall()
            toc = time.time()

            train_loss /= batch_size * num_batch
            _, acc = train_metric.get()

            # NOTE(review): pre_test/post_test presumably swap in synchronized
            # parameters for evaluation and restore the local ones afterwards --
            # confirm in QSparseLocalSGDTrainerV1.
            trainer.pre_test()
            _, val_acc = test(ctx, val_data)
            trainer.post_test()

            train_history.update([1-acc, 1-val_acc])

            # Average train loss, train accuracy and val accuracy over all workers.
            allreduce_array_nd = mx.nd.array([train_loss, acc, val_acc])
            hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
            # np.asscalar was removed in NumPy 1.23; tolist() yields plain Python floats.
            train_loss, acc, val_acc = allreduce_array_nd.asnumpy().tolist()

            if val_acc > best_val_score:
                best_val_score = val_acc
                # net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            if rank == 0:
                logging.info('[Epoch %d] train=%f val=%f loss=%f comm=%.2f time: %f' %
                    (epoch, acc, val_acc, train_loss, trainer._comm_counter/1e6, toc-tic))

                # NOTE(review): file name says "cifar10" although CIFAR-100 is trained.
                if save_period and save_dir and (epoch + 1) % save_period == 0:
                    net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

            # Reset the trainer's communication counter for the next epoch.
            trainer._comm_counter = 0.

        if rank == 0:
            if save_period and save_dir:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))
def train():
    """Training loop for language model.
    """
    # logging.info(model)
    from_epoch = 0
    model.initialize(mx.init.Xavier(factor_type='out'), ctx=ctx)
    trainer_params = {'learning_rate': args.lr, 'wd': 0, 'eps': args.eps}
    # trainer = gluon.Trainer(model.collect_params(), args.optimizer, trainer_params)
    # fully sync at the beginning
    trainer = DistributedHierLocalHVDTrainer(model.collect_params(),
                                             args.optimizer,
                                             trainer_params,
                                             local_sgd_interval=0)
    trainer._optimizer._full_sync = True

    if args.from_epoch:
        from_epoch = args.from_epoch
        checkpoint_name = '%s.%s' % (args.save, format(from_epoch - 1, '02d'))
        model.load_parameters(checkpoint_name)
        trainer.load_states('%s.state' % args.save)
        logging.info('Loaded parameters from checkpoint %s' %
                     (checkpoint_name))

    hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    model.hybridize(static_alloc=True, static_shape=True)
    encoder_params = model.encoder.collect_params().values()
    embedding_params = list(model.embedding.collect_params().values())

    step_num = 0
    lr = args.lr
    current_lr = lr

    epoch = from_epoch
    start_epoch_time = time.time()
    start_log_interval_time = time.time()
    nbatch = 0

    while epoch < args.epochs:
        sys.stdout.flush()
        total_L = 0.0
        hidden = model.begin_state(batch_size=args.batch_size,
                                   func=mx.nd.zeros,
                                   ctx=ctx)
        has_next = True
        train_data_iter = iter(train_data)
        data, target, mask, sample = next(train_data_iter)

        while has_next:
            nbatch += 1

            step_num += 1
            if step_num <= args.warmup_steps:
                new_lr = lr * step_num / args.warmup_steps
                trainer.set_learning_rate(new_lr)
                current_lr = new_lr

            if step_num == args.warmup_steps + 1:
                trainer._local_sgd_interval = args.local_sgd_interval
                trainer._optimizer._full_sync = False
                trainer.init_states()

            hidden = detach(hidden)

            with autograd.record():
                output, hidden, new_target = model(data, target, hidden,
                                                   sample)
                output = output.reshape((-3, -1))
                new_target = new_target.reshape((-1, ))
                ls = loss(output, new_target) * mask.reshape((-1, ))
                ls = ls / args.batch_size
                ls.backward()

            # prefetch the next batch of data
            try:
                data, target, mask, sample = next(train_data_iter)
            except StopIteration:
                has_next = False

            # rescale embedding grad
            x = embedding_params[0].grad(ctx)
            x[:] *= args.batch_size
            encoder_grad = [p.grad(ctx) for p in encoder_params]
            # perform gradient clipping per ctx
            gluon.utils.clip_global_norm(encoder_grad, args.clip)

            trainer.step(1)

            ls_sum = mx.nd.sum(ls)

            total_L += ls_sum / args.bptt

            # total_L += mx.nd.sum(ls).asscalar() / args.bptt

            if nbatch % args.log_interval == 0:

                hvd.allreduce_(total_L,
                               average=True,
                               name='ls',
                               priority=-9999)

                cur_L = total_L.asscalar() / args.log_interval
                ppl = math.exp(cur_L) if cur_L < 100 else float('inf')
                if rank == 0:
                    logging.info(
                        '[Epoch %d Batch %d] loss %.2f, ppl %.2f, '
                        'throughput %.2f samples/s, lr %.4f' %
                        (epoch, nbatch, cur_L, ppl,
                         train_batch_size * num_workers * args.log_interval /
                         (time.time() - start_log_interval_time), current_lr))
                total_L = 0.0
                start_log_interval_time = time.time()
                sys.stdout.flush()

            if nbatch == num_batches_per_epoch:
                end_epoch_time = time.time()
                logging.info('Epoch %d took %.2f seconds.' %
                             (epoch, end_epoch_time - start_epoch_time))
                mx.nd.waitall()
                checkpoint_name = '%s.%s' % (args.save, format(epoch, '02d'))
                if local_rank == 0:
                    model.save_parameters(checkpoint_name)
                if local_rank == 1:
                    trainer.save_states('%s.state' % args.save)
                nbatch = 0
                start_epoch_time = time.time()
                epoch += 1
                if epoch == args.epochs:
                    break