示例#1
0
    def test_horovod_allreduce_error(self):
        """Check that allreduce fails when ranks disagree on tensor shape.

        Two mismatches are exercised in turn: tensors with the same number
        of dimensions but a rank-dependent leading extent, and tensors with
        the same element count but a different number of dimensions.
        """
        hvd.init()
        rank = hvd.rank()
        size = hvd.size()

        # A shape mismatch needs at least two participants.
        if size == 1:
            self.skipTest("Only one worker available")

        ctx = self._current_context()

        mismatched_shapes = (
            # Same rank, different dimension
            (17 + rank, 3),
            # Same number of elements, different rank
            (17, 23 * 57) if rank == 0 else (17, 23, 57),
        )
        for shape in mismatched_shapes:
            tensor = mx.nd.ones(shape=shape, ctx=ctx)
            try:
                # Block on the result inside the try so that an error raised
                # while waiting is caught here as well.
                hvd.allreduce(tensor).wait_to_read()
            except (MXNetError, RuntimeError):
                continue
            assert False, 'hvd.allreduce did not throw error'
示例#2
0
    def __call__(self, param):
        """Callback that periodically checkpoints the model and ends training.

        A checkpoint is written when ``param.num_update`` equals
        ``self.max_step - 10`` or is a positive multiple of 10000. Two
        checkpoints are saved by rank 0: one with parameters averaged across
        workers (prefix suffixed with ``_average``) and one with this
        worker's own parameters. Once ``num_update`` exceeds a positive
        ``self.max_step`` the process exits with status 0.
        """
        step = param.num_update

        near_end = step == self.max_step - 10
        periodic = step > 0 and step % 10000 == 0

        if near_end or periodic:
            arg, aux = self.model.get_export_params()
            symbol = self.symbol

            # The reductions run on every rank (only the save below is
            # restricted to rank 0); aux tensors are reduced before arg ones.
            averaged_aux = {
                name: hvd.allreduce(value, average=True)
                for name, value in aux.items()
            }
            averaged_arg = {
                name: hvd.allreduce(value, average=True)
                for name, value in arg.items()
            }

            if self.rank == 0:
                checkpoints = (
                    (self.prefix + "_average", averaged_arg, averaged_aux),
                    (self.prefix, arg, aux),
                )
                for prefix, arg_params, aux_params in checkpoints:
                    mx.model.save_checkpoint(prefix=prefix,
                                             epoch=0,
                                             symbol=symbol,
                                             arg_params=arg_params,
                                             aux_params=aux_params)

        # training is over
        if 0 < self.max_step < step:
            logging.info('Training is over!')
            sys.exit(0)
示例#3
0
    def test_horovod_allreduce_average(self):
        """Check that an averaging allreduce reproduces the input tensor.

        For every supported dtype and for 1D, 2D and 3D shapes, each rank
        seeds the MXNet generator with the same value, draws a tensor and
        averages it across workers. The result is compared with the local
        tensor after it has been multiplied and then divided by the worker
        count.
        """
        hvd.init()
        size = hvd.size()
        dtypes = self.filter_supported_types(
            ['int32', 'int64', 'float32', 'float64'])
        dims = [1, 2, 3]
        ctx = self._current_context()
        # Indexed by the number of dimensions; index 0 is never used.
        shapes = [(), (17), (17, 17), (17, 17, 17)]
        count = 0
        for dtype, dim in itertools.product(dtypes, dims):
            mx.random.seed(1234, ctx=ctx)
            drawn = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx)
            tensor = drawn.astype(dtype)
            averaged = hvd.allreduce(tensor, average=True, name=str(count))
            count += 1
            tensor *= size
            tensor /= size

            # The tolerance grows with the number of ranks; with 15 or more
            # ranks (and a floating point dtype) the comparison is abandoned.
            loose_check = size <= 3 or dtype in ('int32', 'int64')
            if not loose_check and size >= 15:
                break
            if loose_check:
                threshold = 1
            elif size < 10:
                threshold = 1e-4
            else:
                threshold = 5e-4

            assert almost_equal(averaged.asnumpy(), tensor.asnumpy(), atol=threshold), \
                f'hvd.allreduce produces incorrect results for average: {hvd.rank()} {count} {dtype} {dim}'
示例#4
0
    def test_horovod_allreduce_postscale(self):
        """Check a summing allreduce combined with a ``postscale_factor``.

        For every supported dtype and for 1D, 2D and 3D shapes, a seeded
        random tensor is summed across workers with a seeded random scale
        factor. The reference value is ``tensor * size * factor`` computed
        locally in a dtype chosen to mirror where the scaling is expected to
        happen, then cast back to the tensor dtype.
        """
        hvd.init()
        size = hvd.size()
        dtypes = self.filter_supported_types(
            ['int32', 'int64', 'float16', 'float32', 'float64'])
        int_types = ['int32', 'int64']
        dims = [1, 2, 3]
        ctx = self._current_context()
        # Indexed by the number of dimensions; index 0 is never used.
        shapes = [(), (17), (17, 17), (17, 17, 17)]
        count = 1
        for dtype, dim in itertools.product(dtypes, dims):
            mx.random.seed(1234, ctx=ctx)
            np.random.seed(1234)
            drawn = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx)
            tensor = drawn.astype(dtype)
            factor = np.random.uniform()
            scaled = hvd.allreduce(tensor,
                                   average=False,
                                   name=str(count),
                                   postscale_factor=factor)
            count += 1

            scales_on_device = ctx != mx.cpu() and not int(
                os.environ.get('HOROVOD_MIXED_INSTALL', 0))
            if dtype in int_types:
                # For integer types, scaling done in FP64
                math_dtype = 'float64'
            elif dtype == 'float16' and not scales_on_device:
                # FP32 math for FP16 on CPU
                math_dtype = 'float32'
            else:
                math_dtype = dtype

            factor = mx.nd.array([factor], dtype='float64',
                                 ctx=ctx).astype(math_dtype)
            reference_input = tensor.astype(math_dtype)

            expected = reference_input * size
            expected *= factor
            expected = expected.astype(dtype)

            # The tolerance grows with the number of ranks; with 15 or more
            # ranks (and a floating point dtype) the comparison is abandoned.
            exact_check = size <= 3 or dtype in int_types
            if not exact_check and size >= 15:
                break
            if exact_check:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            else:
                threshold = 5e-4

            assert almost_equal(expected.asnumpy(), scaled.asnumpy(), atol=threshold), \
                f'hvd.allreduce produces incorrect results for pre/post scaling: {hvd.rank()} {count} {dtype} {dim}'
示例#5
0
    def test_horovod_allreduce_type_error(self):
        """Check that allreduce fails when ranks disagree on tensor dtype.

        Even ranks submit an ``int32`` tensor and odd ranks a ``float32``
        tensor of the same shape.
        """
        hvd.init()
        rank = hvd.rank()
        size = hvd.size()

        # A dtype mismatch needs at least two participants.
        if size == 1:
            self.skipTest("Only one worker available")

        ctx = self._current_context()
        dtype = 'int32' if rank % 2 == 0 else 'float32'
        tensor = mx.nd.ones(shape=(17, 3), ctx=ctx).astype(dtype)

        try:
            # Block on the result inside the try so that an error raised
            # while waiting is caught here as well.
            hvd.allreduce(tensor).wait_to_read()
        except (MXNetError, RuntimeError):
            return
        assert False, 'hvd.allreduce did not throw error'
示例#6
0
    def test_horovod_allreduce_cpu_gpu_error(self):
        """Check that allreduce fails when ranks mix CPU and GPU tensors.

        Even ranks place their tensor on a GPU context and odd ranks on a
        CPU context. The test is skipped when ``HOROVOD_MIXED_INSTALL`` is
        set in the environment or when only one worker is running.
        """
        if os.environ.get('HOROVOD_MIXED_INSTALL'):
            # Skip if compiled with CUDA but without HOROVOD_GPU_OPERATIONS.
            self.skipTest("Not compiled with HOROVOD_GPU_OPERATIONS")

        hvd.init()
        rank = hvd.rank()
        size = hvd.size()

        # A device mismatch needs at least two participants.
        if size == 1:
            self.skipTest("Only one worker available")

        make_context = mx.gpu if rank % 2 == 0 else mx.cpu
        tensor = mx.nd.ones(shape=(17, 17, 17), ctx=make_context(hvd.rank()))

        try:
            # Block on the result inside the try so that an error raised
            # while waiting is caught here as well.
            hvd.allreduce(tensor).wait_to_read()
        except (MXNetError, RuntimeError):
            return
        assert False, 'hvd.allreduce did not throw cpu-gpu error'
示例#7
0
    def backward_all(
        self,
        total_feature,
        total_label,
    ):
        """Run the classifier backward pass, then back-propagate the backbone.

        The feature gradient returned by ``self.backward`` (or
        ``self.backward_sample`` when ``config.sample_ratio`` is not 1) is
        summed across workers. The slice of ``self.batch_size`` rows that
        starts at ``self.batch_size * self.rank`` is divided by ``self.size``
        and passed to ``self.backbone_module.backward``.
        """
        # get memory bank learning rate
        self.memory_lr = self.memory_optimizer.lr_scheduler(self.num_update)

        # Fetch the cache buffers and reset them for this step.
        self.grad_cache = self.get_ndarray(self.gpu, 'grad_cache',
                                           total_feature.shape)
        self.loss_cache = self.get_ndarray(self.gpu, 'loss_cache', [1])
        self.grad_cache[:] = 0
        self.loss_cache[:] = 0

        # A sample ratio of exactly 1 selects the non-sampling path.
        if not (config.sample_ratio - 1):
            backward_fn = self.backward
        else:
            backward_fn = self.backward_sample
        grad, loss = backward_fn(total_feature, total_label)

        self.loss_cache[0] = loss

        summed_grad = hvd.allreduce(grad, average=False)

        # Keep only the rows at this worker's offset.
        first_row = self.batch_size * self.rank
        fc1_grad = summed_grad[first_row:first_row + self.batch_size]
        self.backbone_module.backward(out_grads=[fc1_grad / self.size])
示例#8
0
def reduce_metrics(args, metrics, kvstore):
    """Combine metric values across Horovod workers.

    ``metrics`` is indexed as ``(names, values)``. It is returned unchanged
    when ``kvstore`` does not contain ``'horovod'``, when ``metrics[0]`` is
    empty, or when only one worker is running. Otherwise the values are
    copied to the GPU ``args.gpus[0]``, passed through ``hvd.allreduce`` and
    returned as a Python list alongside the original names.
    """
    uses_horovod = 'horovod' in kvstore
    # The checks short-circuit in this order, so Horovod is only queried
    # when it is selected and there is something to reduce.
    if not (uses_horovod and metrics[0] and hvd.size() != 1):
        return metrics

    device_values = mx.ndarray.array(metrics[1], ctx=mx.gpu(args.gpus[0]))
    reduced = hvd.allreduce(device_values)
    host_values = reduced.as_in_context(mx.cpu()).asnumpy()
    return (metrics[0], host_values.tolist())
示例#9
0
 def allreduce_running(self):
     """Average the running-statistic parameters of ``self.net`` across workers.

     A parameter is selected when its name contains any entry of
     ``self.RUNNING_PARAMS``. Nothing happens with a single worker.
     """
     if hvd.size() <= 1:
         return

     for param_name, param in self.net.collect_params().items():
         is_running_stat = any(tag in param_name
                               for tag in self.RUNNING_PARAMS)
         if not is_running_stat:
             continue
         local_value = param.data(ctx=self.ctx)
         averaged = hvd.allreduce(local_value,
                                  average=True,
                                  name=None,
                                  priority=0)
         param.set_data(averaged)
示例#10
0
    def test_horovod_allreduce_ndarray_lifetime(self):
        """Test that the input NDArray remains valid during async allreduce.

        The input to each allreduce is a temporary expression result with no
        other reference, so the reduction must keep it alive by itself.
        """
        hvd.init()
        size = hvd.size()

        dims = [1, 2, 3]
        ctx = self._current_context()
        # Indexed by the number of dimensions; index 0 is never used.
        shapes = [(), (17,), (17, 17), (17, 17, 17)]
        for i, dim in enumerate(dims):
            tensor = mx.nd.ones(shape=shapes[dim], ctx=ctx)
            # tensor*(i+1) result will be destroyed immediately after this call
            # See https://github.com/horovod/horovod/issues/1533
            summed = hvd.allreduce(tensor * (i + 1), average=False)
            expected = tensor * (i + 1) * size
            assert same(summed.asnumpy(), expected.asnumpy()), \
                'hvd.allreduce produces incorrect results for a temporary input'
示例#11
0
    def test_horovod_allreduce(self):
        """Test that the allreduce correctly sums 1D, 2D, 3D tensors."""
        hvd.init()
        size = hvd.size()
        dtypes = self.filter_supported_types(
            ['int32', 'int64', 'float32', 'float64'])
        dims = [1, 2, 3]
        ctx = self._current_context()
        count = 0
        shapes = [(), (17), (17, 17), (17, 17, 17)]
        for dtype, dim in itertools.product(dtypes, dims):
            # MXNet uses gpu_id as part of the seed, so to get identical seeds
            # we must set a context.
            mx.random.seed(1234, ctx=ctx)
            tensor = mx.nd.random.uniform(-100,
                                          100,
                                          shape=shapes[dim],
                                          ctx=ctx)
            tensor = tensor.astype(dtype)
            summed = hvd.allreduce(tensor, average=False, name=str(count))
            multiplied = tensor * size
            max_difference = mx.nd.max(mx.nd.subtract(summed, multiplied))
            count += 1

            # Threshold for floating point equality depends on number of
            # ranks, since we're comparing against precise multiplication.
            if size <= 3 or dtype in ['int32', 'int64']:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                break

            if max_difference > threshold:
                print("allreduce", count, dtype, dim, max_difference,
                      threshold)
                print("tensor", hvd.rank(), tensor)
                print("summed", hvd.rank(), summed)
                print("multiplied", hvd.rank(), multiplied)
            assert max_difference <= threshold, 'hvd.allreduce produces \
示例#12
0
    def backward_all(self, total_feature, total_label, ):
        """Run the per-head classifier backward passes, then the backbone.

        ``total_feature`` is reshaped to rows of
        ``self.embedding_size * self.head_num`` columns and ``total_label``
        to rows of ``self.head_num`` columns. Each head processes its own
        column block; the per-head gradients are gathered in
        ``self.grad_cache``, summed across workers, and the slice of
        ``self.batch_size`` rows at this worker's offset is fed to
        ``self.backbone_module.backward``.
        """
        # get memory bank learning rate
        self.memory_lr = self.memory_optimizer.lr_scheduler(self.num_update)

        # One row per sample, with the heads laid out side by side.
        total_feature = total_feature.reshape(
            -1, self.embedding_size * self.head_num)
        total_label = total_label.reshape(-1, self.head_num)

        # Fetch the cache buffers and reset them for this step.
        self.grad_cache = self.get_ndarray(self.gpu, 'grad_cache',
                                           total_feature.shape)
        self.loss_cache = self.get_ndarray(self.gpu, 'loss_cache',
                                           [self.head_num])
        self.grad_cache[:] = 0
        self.loss_cache[:] = 0

        for head_id in range(self.head_num):
            col_begin = head_id * self.embedding_size
            col_end = col_begin + self.embedding_size

            head_feature = total_feature[:, col_begin:col_end]
            head_label = total_label[:, head_id]

            grad, loss = self.backward(head_id, head_feature, head_label)
            self.grad_cache[:, col_begin:col_end] = grad
            self.loss_cache[head_id] = loss

        flat_grad = self.grad_cache.reshape(-1, self.embedding_size)
        summed_grad = hvd.allreduce(flat_grad, average=False)

        # Keep only the rows at this worker's offset.
        row_begin = self.batch_size * self.rank
        fc1_grad = summed_grad[row_begin:row_begin + self.batch_size]
        self.backbone_module.backward(out_grads=[fc1_grad])
示例#13
0
    def test_horovod_allreduce_average(self):
        """Test that the allreduce correctly sums 1D, 2D, 3D tensors."""
        hvd.init()
        size = hvd.size()
        dtypes = ['int32', 'int64', 'float32', 'float64']
        dims = [1, 2, 3]
        ctx = self._current_context()
        count = 0
        shapes = [(), (17), (17, 17), (17, 17, 17)]
        for dtype, dim in itertools.product(dtypes, dims):
            mx.random.seed(1234, ctx=ctx)
            tensor = mx.nd.random.uniform(-100,
                                          100,
                                          shape=shapes[dim],
                                          ctx=ctx)
            tensor = tensor.astype(dtype)
            averaged = hvd.allreduce(tensor, average=True, name=str(count))
            tensor *= size
            tensor /= size
            max_difference = mx.nd.max(mx.nd.subtract(averaged, tensor))
            count += 1

            # Threshold for floating point equality depends on number of
            # ranks, since we're comparing against precise multiplication.
            if size <= 3 or dtype in ['int32', 'int64']:
                threshold = 1
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                break

            if max_difference > threshold:
                print("average", count, dtype, dim, max_difference, threshold)
                print("tensor", hvd.rank(), tensor)
                print("averaged", hvd.rank(), averaged)
            assert max_difference <= threshold, 'hvd.allreduce produces \
示例#14
0
    def test_horovod_allreduce_cpu_gpu_error(self):
        """Check that allreduce fails when ranks mix CPU and GPU tensors.

        Even ranks place their tensor on a GPU context and odd ranks on a
        CPU context. With a single worker the test returns without checking
        anything.
        """
        hvd.init()
        rank = hvd.rank()
        size = hvd.size()

        # A device mismatch needs at least two participants.
        if size == 1:
            return

        make_context = mx.gpu if rank % 2 == 0 else mx.cpu
        tensor = mx.nd.ones(shape=(17, 17, 17), ctx=make_context(hvd.rank()))

        try:
            # Block on the result inside the try so that an error raised
            # while waiting is caught here as well.
            hvd.allreduce(tensor).wait_to_read()
        except (MXNetError, RuntimeError):
            return
        assert False, 'hvd.allreduce did not throw cpu-gpu error'
示例#15
0
    def backward(self, total_feature, label):
        """Compute the classifier loss and gradients with the full local weight.

        The logits produced by ``self.fc7_model.forward`` are shifted by the
        per-row maximum taken across all workers, exponentiated and divided
        by the per-row sum taken across all workers. The loss is the negative
        mean log of the target-class probability. The memory bank weight is
        updated in place through ``self.memory_optimizer``.

        Returns
        -------
        tuple
            ``(gradient of the input features, global loss)``.
        """
        bank = self.memory_bank
        assert bank.num_local == bank.num_sample, "pass"

        features = self.get_ndarray2(self.gpu, "data_%d" % self.rank,
                                     total_feature)
        # Both the features and the class weights need gradients.
        features.attach_grad()
        bank.weight.attach_grad()

        # Shift the labels by this worker's class offset.
        local_label = self.get_ndarray2(self.gpu, 'label_%d' % self.rank,
                                        label)
        local_label = local_label - int(self.rank * bank.num_local)
        logits, one_hot = self.fc7_model.forward(features,
                                                 bank.weight,
                                                 mapping_label=local_label,
                                                 depth=bank.num_local)

        # Each worker writes its row maxima into its own column of a zeroed
        # buffer; the summing allreduce then fills in every column.
        row_max = nd.reshape(nd.max(logits, axis=1, keepdims=True), -1)
        all_row_max = self.get_ndarray(context=self.gpu,
                                       name='total_max_fc7',
                                       shape=(row_max.shape[0], self.size),
                                       dtype='float32')
        all_row_max[:] = 0
        all_row_max[:, self.rank] = row_max
        hvd.allreduce_(all_row_max, average=False)

        global_row_max = self.get_ndarray(context=self.gpu,
                                          name='global_max_fc7',
                                          shape=(row_max.shape[0], 1),
                                          dtype='float32')
        nd.max(all_row_max, axis=1, keepdims=True, out=global_row_max)

        # exp(logits - global max)
        exp_logits = nd.exp(nd.broadcast_sub(logits, global_row_max))

        # Per-row sum over the classes held by all workers.
        local_sum = nd.sum(exp_logits, axis=1, keepdims=True)
        global_sum = hvd.allreduce(local_sum, average=False)

        prob = nd.broadcast_div(exp_logits, global_sum)

        # Sum the per-worker target-class probabilities across workers.
        target_prob = nd.sum(prob * one_hot, axis=1, keepdims=True)
        target_prob = self.get_ndarray2(self.gpu, 'ctx_loss', target_prob)
        target_prob = hvd.allreduce(target_prob, average=False)
        global_loss = -nd.mean(nd.log(target_prob + 1e-30))

        # Gradient with respect to the logits, then back-propagate.
        logits.backward(out_grad=prob - one_hot)

        # Update the class weights.
        self.memory_optimizer.update(weight=bank.weight,
                                     grad=bank.weight.grad,
                                     state=bank.weight_mom,
                                     learning_rate=self.memory_lr)

        return features.grad, global_loss
示例#16
0
    def backward_sample(self, total_feature, label):
        """Compute the classifier loss and gradients on a sampled class subset.

        ``self.memory_bank.sample(label)`` selects the local classes to use.
        Labels that fall in the sampled set are remapped to their position in
        that set (offset by this worker's block of ``num_sample`` classes);
        all other labels become ``-1``. The logits are then normalised with
        the per-row maximum and sum taken across all workers, as in the
        non-sampling path, and the sampled weights are updated through
        ``self.memory_optimizer`` and written back to the memory bank.

        Returns
        -------
        tuple
            ``(gradient of the input features, global loss)``.
        """
        bank = self.memory_bank
        this_rank_classes = int(bank.num_sample)
        local_index, unique_sorted_global_label = bank.sample(label)

        # Map absolute class id -> index in the sampled set, restricted to
        # the classes that actually occur in this batch.
        local_sampled_class = local_index + self.rank * bank.num_local
        global_label_set = set(unique_sorted_global_label)
        _mapping_dict = {
            absolute_label: idx + self.rank * bank.num_sample
            for idx, absolute_label in enumerate(local_sampled_class)
            if absolute_label in global_label_set
        }

        # Labels without a sampled class on this worker are marked with -1.
        mapping_label = [
            _mapping_dict.get(absolute_label, -1)
            for absolute_label in label.asnumpy()
        ]
        mapping_label = nd.array(mapping_label, dtype=np.int32)

        # Get weight
        local_index = nd.array(local_index)
        local_index = self.get_ndarray2(self.gpu, "local_index", local_index)
        sample_weight, sample_weight_mom = bank.get(local_index)

        # Sync to gpu. The original code branched on ``memory_bank.gpu`` here
        # but both branches were identical, so a single copy is kept.
        _data = self.get_ndarray2(self.gpu, "data_%d" % self.rank,
                                  total_feature)
        _weight = self.get_ndarray2(self.gpu, 'weight_%d' % self.rank,
                                    sample_weight)
        _weight_mom = self.get_ndarray2(self.gpu,
                                        'weight_mom_%d' % self.rank,
                                        sample_weight_mom)

        # Attach grad
        _data.attach_grad()
        _weight.attach_grad()

        # Convert label: shift by this worker's offset into the sampled set.
        _label = self.get_ndarray2(self.gpu, 'mapping_label_%d' % self.rank,
                                   mapping_label)
        _label = _label - int(self.rank * bank.num_sample)
        _fc7, _one_hot = self.fc7_model.forward(_data,
                                                _weight,
                                                mapping_label=_label,
                                                depth=this_rank_classes)

        # Sync max: each worker writes its row maxima into its own column of
        # a zeroed buffer; the summing allreduce then fills in every column.
        max_fc7 = nd.max(_fc7, axis=1, keepdims=True)
        max_fc7 = nd.reshape(max_fc7, -1)

        total_max_fc7 = self.get_ndarray(context=self.gpu,
                                         name='total_max_fc7',
                                         shape=(max_fc7.shape[0], self.size),
                                         dtype='float32')
        total_max_fc7[:] = 0
        total_max_fc7[:, self.rank] = max_fc7
        hvd.allreduce_(total_max_fc7, average=False)

        global_max_fc7 = self.get_ndarray(context=self.gpu,
                                          name='global_max_fc7',
                                          shape=(max_fc7.shape[0], 1),
                                          dtype='float32')
        nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)

        # Calculate exp(logits)
        _fc7_grad = nd.broadcast_sub(_fc7, global_max_fc7)
        _fc7_grad = nd.exp(_fc7_grad)

        # Calculate sum over the classes held by all workers
        sum_fc7 = nd.sum(_fc7_grad, axis=1, keepdims=True)
        global_sum_fc7 = hvd.allreduce(sum_fc7, average=False)

        # Calculate prob
        _fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)

        # Calculate loss from the target-class probability summed over workers
        tmp = _fc7_grad * _one_hot
        tmp = nd.sum(tmp, axis=1, keepdims=True)
        tmp = self.get_ndarray2(self.gpu, 'ctx_loss', tmp)
        tmp = hvd.allreduce(tmp, average=False)
        global_loss = -nd.mean(nd.log(tmp + 1e-30))

        # Calculate fc7 grad
        _fc7_grad = _fc7_grad - _one_hot

        # Backward
        _fc7.backward(out_grad=_fc7_grad)

        # Update center
        _weight_grad = _weight.grad
        self.memory_optimizer.update(weight=_weight,
                                     grad=_weight_grad,
                                     state=_weight_mom,
                                     learning_rate=self.memory_lr)

        # Write the updated rows back; when the bank is not flagged as GPU
        # the tensors are first copied to CPU buffers.
        if bank.gpu:
            bank.set(index=local_index,
                     updated_weight=_weight,
                     updated_weight_mom=_weight_mom)
        else:
            bank.set(index=local_index,
                     updated_weight=self.get_ndarray2(
                         mx.cpu(), "cpu_weight_%d" % self.rank, _weight),
                     updated_weight_mom=self.get_ndarray2(
                         mx.cpu(), "cpu_weight_mom_%d" % self.rank,
                         _weight_mom))
        return _data.grad, global_loss
示例#17
0
def train():
    """Fine-tune ``net`` on the SQuAD training split.

    The function takes no arguments; it reads its configuration from names
    defined outside this block (``args``, ``version_2``, ``tokenizer``,
    ``net``, ``optimizer``, ``lr``, ``batch_size``, ``accumulate``,
    ``warmup_ratio``, ``size``, ``rank``, ``ctx``, ``log_interval``,
    ``output_dir`` and others).

    Steps performed:

    1. Load and preprocess the training data and shard it across workers.
    2. Build a Horovod or plain Gluon trainer, with AMP when
       ``args.dtype == 'float16'``.
    3. Train until ``num_train_steps`` optimizer steps have been taken,
       using a linear warm-up followed by a linear decay of the learning
       rate, optional gradient accumulation and global-norm clipping.
    4. Save the parameters from rank 0.
    """
    segment = 'train'  #if not args.debug else 'dev'
    log.info('Loading %s data...', segment)
    if version_2:
        train_data = SQuAD(segment, version='2.0')
    else:
        train_data = SQuAD(segment, version='1.1')
    if args.debug:
        # Debug runs use only the first 10000 records.
        sampled_data = [train_data[i] for i in range(0, 10000)]
        train_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in Train data:{}'.format(len(train_data)))
    train_data_transform = preprocess_dataset(
        tokenizer,
        train_data,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        input_features=True)

    log.info('The number of examples after preprocessing:{}'.format(
        len(train_data_transform)))

    # Each worker iterates over its own part of the dataset.
    sampler = nlp.data.SplitSampler(len(train_data_transform),
                                    num_parts=size,
                                    part_index=rank,
                                    even_size=True)
    num_train_examples = len(sampler)
    train_dataloader = mx.gluon.data.DataLoader(train_data_transform,
                                                batchify_fn=batchify_fn,
                                                batch_size=batch_size,
                                                num_workers=4,
                                                sampler=sampler)

    log.info('Start Training')

    optimizer_params = {'learning_rate': lr}
    param_dict = net.collect_params()
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.dtype == 'float16':
        amp.init_trainer(trainer)

    # Number of samples consumed per optimizer step.
    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_steps = int(num_train_examples / step_size * args.epochs)
    if args.training_steps:
        # An explicit step budget overrides the epoch-derived one.
        num_train_steps = args.training_steps

    num_warmup_steps = int(num_train_steps * warmup_ratio)

    def set_new_lr(step_num, batch_id):
        """Advance the step counter if due, set the learning rate, return the counter."""
        # With gradient accumulation the step counter only advances on
        # batches whose id is a multiple of ``accumulate``.
        if accumulate:
            if batch_id % accumulate == 0:
                step_num += 1
        else:
            step_num += 1
        # learning rate schedule
        # Notice that this learning rate scheduler is adapted from traditional linear learning
        # rate scheduler where step_num >= num_warmup_steps, new_lr = 1 - step_num/num_train_steps
        if step_num < num_warmup_steps:
            new_lr = lr * step_num / num_warmup_steps
        else:
            # NOTE(review): divides by zero if num_warmup_steps equals
            # num_train_steps — confirm warmup_ratio is always below 1.
            offset = (step_num - num_warmup_steps) * lr / \
                (num_train_steps - num_warmup_steps)
            new_lr = lr - offset
        trainer.set_learning_rate(new_lr)
        return step_num

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Set grad_req if gradient accumulation is required
    if accumulate:
        for p in params:
            p.grad_req = 'add'
    net.collect_params().zero_grad()

    epoch_tic = time.time()

    total_num = 0  # samples seen since training started
    log_num = 0  # samples seen since the last log line
    batch_id = 0
    step_loss = 0.0  # loss accumulated since the last log line
    tic = time.time()
    step_num = 0

    # NOTE(review): ``tic`` was already set a few lines above; this second
    # assignment looks redundant.
    tic = time.time()
    # NOTE(review): if the dataloader yields no batches, step_num never
    # advances and this loop does not terminate — confirm it is non-empty.
    while step_num < num_train_steps:
        for _, data in enumerate(train_dataloader):
            # set new lr
            step_num = set_new_lr(step_num, batch_id)
            # forward and backward
            _, inputs, token_types, valid_length, start_label, end_label = data
            num_labels = len(inputs)
            log_num += num_labels
            total_num += num_labels

            with mx.autograd.record():
                out = net(inputs.as_in_context(ctx),
                          token_types.as_in_context(ctx),
                          valid_length.as_in_context(ctx).astype('float32'))

                loss = loss_function(out, [
                    start_label.as_in_context(ctx).astype('float32'),
                    end_label.as_in_context(ctx).astype('float32')
                ]).sum() / num_labels

                if accumulate:
                    loss = loss / accumulate
                if args.dtype == 'float16':
                    with amp.scale_loss(loss, trainer) as l:
                        mx.autograd.backward(l)
                        # The clipping threshold is multiplied by the worker
                        # count and by the current AMP loss scale.
                        norm_clip = 1.0 * size * trainer._amp_loss_scaler.loss_scale
                else:
                    mx.autograd.backward(loss)
                    norm_clip = 1.0 * size

            # update: only on the last batch of each accumulation window
            if not accumulate or (batch_id + 1) % accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, norm_clip)
                trainer.update(1)
                if accumulate:
                    # Gradients are accumulated ('add'), so clear them by hand.
                    param_dict.zero_grad()

            # With Horovod the logged loss is the average over all workers.
            if args.comm_backend == 'horovod':
                step_loss += hvd.allreduce(loss, average=True).asscalar()
            else:
                step_loss += loss.asscalar()

            if (batch_id + 1) % log_interval == 0:
                toc = time.time()
                log.info('Batch: {}/{}, Loss={:.4f}, lr={:.7f} '
                         'Thoughput={:.2f} samples/s'.format(
                             batch_id % len(train_dataloader),
                             len(train_dataloader), step_loss / log_interval,
                             trainer.learning_rate, log_num / (toc - tic)))
                tic = time.time()
                step_loss = 0.0
                log_num = 0

            # Stop in the middle of an epoch once the step budget is used up.
            if step_num >= num_train_steps:
                break
            batch_id += 1

        log.info('Finish training step: %d', step_num)
        epoch_toc = time.time()
        log.info('Time cost={:.2f} s, Thoughput={:.2f} samples/s'.format(
            epoch_toc - epoch_tic, total_num / (epoch_toc - epoch_tic)))

    # Only one worker writes the final parameters.
    if rank == 0:
        net.save_parameters(os.path.join(output_dir, 'net.params'))
示例#18
0
        lr_step_epochs='30,60,80',
        dtype='float32')
    args = parser.parse_args()

    if 'horovod' in args.kv_store:
        # initialize Horovod with mpi4py comm
        hvd.init(mpiwrapper.get_comm())
        args.gpus = _get_gpu(args.gpus)
        kv = None
        local_rank = hvd.local_rank()

        # dummy Horovod ops to initialize resources
        ctx = mx.gpu(local_rank)
        tensor1 = mx.nd.zeros(shape=(1), dtype='float16', ctx=ctx)
        tensor2 = mx.nd.zeros(shape=(1), dtype='float32', ctx=ctx)
        summed1 = hvd.allreduce(tensor1, average=False)
        summed2 = hvd.allreduce(tensor2, average=False)

    framework = 'MxNet NGC {}'.format(os.environ["NVIDIA_MXNET_VERSION"])

    #mlperf_submission_log(
    #    benchmark=mlperf_constants.RESNET,
    #    framework=framework,
    #)

    # Load network
    from importlib import import_module
    net = import_module('symbols.' + args.network)

    # Initialize seed + random number generators
    if args.seed is None:
    def pushpull(self, key, value, out=None, priority=0):
        """Sum a tensor, or a list of tensors, over all Horovod processes.

        `pushpull` is a generic name; here the work is delegated to
        Horovod's allreduce with ``average=False``. The operation is named
        ``str(key)``. The tensor type and shape must be the same on all
        processes for a given name, and the reduction will not start until
        all processes are ready to send and receive the tensor.

        Parameters
        ----------
        key : str, int, or sequence of str or int
            Key used to tag the operation.

        value : NDArray or list of NDArray
            Tensor(s) on this process to be summed. When `out` is not given,
            `value` is modified in-place.

        out : NDArray or list of NDArray, optional
            Destination for the reduced result. When given, each result is
            copied into the matching entry of `out`.

        priority : int, optional
            The priority of the operation.
            Higher priority operations are likely to be executed before other actions.

        Examples
        --------
        >>> # perform in-place allreduce on tensor a
        >>> shape = (2, 3)
        >>> nworker = kv.num_workers # assume there are 8 processes
        >>> a = mx.nd.ones(shape)
        >>> kv.pushpull('1', a)
        >>> print(a.asnumpy())
        [[ 8.  8.  8.]
        [ 8.  8.  8.]]

        >>> # perform allreduce on tensor a and output to b
        >>> a = mx.nd.ones(shape)
        >>> b = mx.nd.zeros(shape)
        >>> kv.pushpull('2', a, out=b)
        >>> print(b.asnumpy())
        [[ 8.  8.  8.]
        [ 8.  8.  8.]]
        """
        import horovod.mxnet as hvd

        sources = value if isinstance(value, list) else [value]

        if out is None:
            # No destination given: reduce each input tensor in place.
            for tensor in sources:
                hvd.allreduce_(tensor,
                               average=False,
                               name=str(key),
                               priority=priority)
            return

        destinations = out if isinstance(out, list) else [out]
        for dst, src in zip(destinations, sources):
            dst[:] = hvd.allreduce(src,
                                   average=False,
                                   name=str(key),
                                   priority=priority)
示例#20
0
def train(args):
    """Train a Transformer translation model and periodically validate it.

    The function builds everything it needs from ``args``: tokenizers, the
    cached training / dev corpora, the model, the label-smoothing loss, the
    beam-search sampler, the optimizer / trainer and the data loaders.  It then
    runs the optimization loop with gradient accumulation.  Inside the loop it

    * logs running statistics every ``args.log_interval`` updates,
    * at every checkpoint boundary saves the parameters (on local rank 0),
      runs validation, and on local rank 0 writes the dev predictions and
      logs the validation loss and SacreBLEU scores,
    * optionally tracks an average of the parameters, which is copied back
      into the model and saved as ``average.params`` when the loop finishes.

    Parameters
    ----------
    args
        Object exposing the run configuration as attributes (presumably the
        parsed command-line namespace -- confirm against the caller).  Among
        the attributes read here are ``comm_backend``, ``gpus``, ``save_dir``,
        ``fp16``, the tokenizer / corpus paths, ``cfg``, ``lr``,
        ``warmup_steps``, ``optimizer``, ``optimizer_params`` (a JSON string),
        ``sampler``, ``num_accumulated``, ``epochs``, ``max_update``,
        ``save_interval_update``, ``num_averages``, ``max_grad_norm`` and
        ``log_interval``.

    Raises
    ------
    NotImplementedError
        If ``args.sampler`` or ``args.bucket_scheme`` is not one of the
        supported values, or if ``FixedBucketSampler`` is requested together
        with the horovod backend.

    Notes
    -----
    When ``args.epochs < 0`` and ``args.max_update <= 0`` the number of
    training iterations is unbounded and the loop never terminates on its own.
    """
    # Only needed to build an unbounded iteration counter below.
    import itertools

    _, num_parts, rank, local_rank, _, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    if args.comm_backend == 'horovod':
        logging_config(
            args.save_dir,
            name=f'train_transformer_rank{rank}_local{local_rank}_{num_parts}',
            console=(rank == 0))
        logging.info(args)
    else:
        logging_config(args.save_dir, name='train_transformer', console=True)
        logging.info(args)
    use_amp = args.fp16
    if use_amp:
        from mxnet import amp
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    base_tgt_tokenizer = MosesTokenizer(args.tgt_lang)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus,
        args.train_tgt_corpus,
        src_tokenizer,
        tgt_tokenizer,
        args.overwrite_cache,
        local_rank,
        max_src_length=args.max_src_length,
        max_tgt_length=args.max_tgt_length,
        pretokenized=not args.tokenize)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus,
        args.dev_tgt_corpus,
        src_tokenizer,
        tgt_tokenizer,
        args.overwrite_cache,
        local_rank,
        pretokenized=not args.tokenize)
    # Reference translations used for the two SacreBLEU scores at validation.
    tgt_detok_sentences = []
    tgt_raw_sentences = []
    with open(args.dev_tgt_corpus, 'r') as in_f:
        for line in in_f:
            tgt_detok_sentences.append(
                base_tgt_tokenizer.decode(
                    tgt_tokenizer.decode(line.split()).split()))
    with open(args.dev_tgt_raw_corpus, 'r') as in_f:
        for line in in_f:
            tgt_raw_sentences.append(line.strip())
    # Each sample is (src tokens, tgt tokens, src length, tgt length, index).
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    val_samples = [
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ]
    if args.comm_backend == 'horovod':
        # Give every rank a contiguous shard of the validation set.  The last
        # rank additionally takes the remainder: dropping it would leave fewer
        # predictions than reference sentences when BLEU is computed.
        shard_size = len(val_samples) // num_parts
        slice_begin = rank * shard_size
        if rank == num_parts - 1:
            slice_end = len(val_samples)
        else:
            slice_end = slice_begin + shard_size
        data_val = gluon.data.SimpleDataset(val_samples[slice_begin:slice_end])
    else:
        data_val = gluon.data.SimpleDataset(val_samples)
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    cfg.MODEL.layout = 'TN'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    # Gradients are accumulated over `args.num_accumulated` mini-batches, so
    # every trainable parameter has to add into its gradient buffer.
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    param_dict = deduplicate_param_dict(model.collect_params())

    inference_model = TransformerInference(model=model)
    inference_model.hybridize()
    if local_rank == 0:
        logging.info(model)
    with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
        cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()

    # Construct the beam search sampler
    scorer = BeamSearchScorer(alpha=args.lp_alpha,
                              K=args.lp_k,
                              from_logits=False)
    beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size,
                                            decoder=inference_model,
                                            vocab_size=len(tgt_vocab),
                                            eos_id=tgt_vocab.eos_id,
                                            scorer=scorer,
                                            stochastic=False,
                                            max_length_a=args.max_length_a,
                                            max_length_b=args.max_length_b)

    logging.info(beam_search_sampler)
    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # Construct the trainer
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    optimizer_params = {
        # Use the resolved rate: `args.lr` may be None, in which case the
        # value derived above is the one the scheduler was built with.
        'learning_rate': base_lr,
        'beta1': 0.9,
        'beta2': 0.997,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler,
        'wd': args.wd
    }
    # Options given on the command line override the defaults above.
    user_provided_optimizer_params = json.loads(args.optimizer_params)
    optimizer_params.update(user_provided_optimizer_params)

    if args.fp16:
        optimizer_params.update({'multi_precision': True})
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = gluon.Trainer(param_dict,
                                args.optimizer,
                                optimizer_params,
                                update_on_kvstore=False)
    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            shuffle=True,
            seed=args.seed)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')

        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            batch_size=args.batch_size,
            num_buckets=args.num_buckets,
            ratio=args.bucket_ratio,
            shuffle=True,
            use_average_length=True,
            bucket_scheme=bucket_scheme,
            seed=args.seed)
    else:
        raise NotImplementedError

    # One update consumes `num_accumulated` batches on each device of each
    # worker.
    num_updates_per_epoch = int(
        math.ceil(
            len(train_batch_sampler) /
            (num_parts * len(ctx_l) * args.num_accumulated)))
    # Convert the batch sampler to multiple shards
    if num_parts > 1:
        train_batch_sampler = ShardedIterator(train_batch_sampler,
                                              num_parts=num_parts,
                                              part_index=rank,
                                              even_size=True,
                                              seed=args.seed + 1000 * rank)

    logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    model_averager = AverageSGDTracker(param_dict)
    log_start_time = time.time()
    num_params, num_fixed_params = None, None

    # TODO(sxjscience) Add a log metric class
    log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    # Maintain the denominator of the loss.
    log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_avg_grad_norm = 0
    log_iter_num = 0

    if local_rank == 0:
        writer = SummaryWriter(
            logdir=os.path.join(args.save_dir, 'tensorboard'))
    if use_amp:
        amp.init_trainer(trainer)
    train_multi_data_loader = grouper(repeat(train_data_loader), len(ctx_l))
    # Decide how long to train and from which iteration on the parameters are
    # averaged.  `avg_start_iter == -1` means that averaging is disabled.
    # when args.epochs < 0, the model will keep training
    if args.epochs < 0:
        if args.max_update > 0:
            total_train_iters = args.max_update
            if args.num_averages > 0:
                num_checkpoints = total_train_iters // args.save_interval_update
                assert args.num_averages <= num_checkpoints
                avg_start_iter = (num_checkpoints - args.num_averages) \
                    * args.save_interval_update
            else:
                avg_start_iter = -1
        else:
            total_train_iters = np.inf
            avg_start_iter = -1
    else:
        total_train_iters = args.epochs * num_updates_per_epoch
        if args.num_averages > 0:
            assert args.num_averages <= args.epochs
            avg_start_iter = (args.epochs -
                              args.num_averages) * num_updates_per_epoch
        else:
            avg_start_iter = -1

    # Here, we are manually setting up the scale to 1.0 because
    # in horovod, the scale can be the number of workers:
    # See the code here: https://github.com/horovod/horovod/blob/125115583b7029196e2ec530decd4209459d5479/horovod/mxnet/__init__.py#L141
    # Since we will need to use the dynamic scaling in amp, we will manually call amp.unscale().
    # A scale that is larger than 1.0 can be problematic in this case.
    trainer._scale = 1.0
    # The summed loss and the token counts are divided by the same constant;
    # NOTE(review): presumably this keeps the accumulated values in a
    # numerically comfortable range -- confirm.
    if args.max_num_tokens > 0:
        const_scale = args.max_num_tokens
    else:
        const_scale = 100

    train_start_time = time.time()

    # `range` rejects the float infinity used for "train forever", so fall
    # back to an unbounded counter in that case.
    if math.isinf(total_train_iters):
        train_iter_range = itertools.count()
    else:
        train_iter_range = range(total_train_iters)

    for train_iter in train_iter_range:
        model.zero_grad()
        loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
        for _ in range(args.num_accumulated):
            loss_l = []
            sample_data_l = next(train_multi_data_loader)
            for j, (sample_data, ctx) in enumerate(zip(sample_data_l, ctx_l)):
                src_token_ids, tgt_token_ids, src_valid_length,\
                tgt_valid_length, sample_ids = sample_data
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                src_wc = src_valid_length.sum()
                tgt_wc = tgt_valid_length.sum()
                log_wc_l[j] += src_wc + tgt_wc
                log_tgt_wc_l[j] += tgt_wc
                # The target is shifted by one position for prediction, hence
                # one token fewer per sentence contributes to the loss.
                token_count = (tgt_valid_length - 1).sum()
                loss_denom_l[j] += token_count / const_scale
                log_avg_loss_denom_l[j] += token_count / const_scale
                with mx.autograd.record():
                    if model.layout == 'NT':
                        tgt_pred = model(src_token_ids, src_valid_length,
                                         tgt_token_ids[:, :-1],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids[:, 1:]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=1)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                    elif model.layout == 'TN':
                        tgt_pred = model(src_token_ids.T, src_valid_length,
                                         tgt_token_ids.T[:-1, :],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids.T[1:, :]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=0)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                log_avg_loss_l[j] += loss
            if use_amp:
                with mx.autograd.record():
                    with amp.scale_loss(loss_l, trainer) as amp_loss_l:
                        for loss in amp_loss_l:
                            loss.backward()
            else:
                with mx.autograd.record():
                    for loss in loss_l:
                        loss.backward()

        # Print the total number of parameters
        if local_rank == 0 and num_params is None:
            num_params, num_fixed_params = count_parameters(param_dict)
            logging.info(
                'Total Number of Parameters (not-fixed/fixed): {}/{}'.format(
                    num_params, num_fixed_params))
        # All-Reduce the gradient
        trainer.allreduce_grads()
        if args.comm_backend == 'horovod':
            # All-Reduce the loss denominator
            assert len(loss_denom_l) == 1
            loss_denom = hvd.allreduce(loss_denom_l[0],
                                       average=False).asnumpy()
        else:
            loss_denom = sum([ele.asnumpy() for ele in loss_denom_l])
        if use_amp:
            # The gradients presumably still carry the AMP loss scale at this
            # point, so it is folded into the normalization factor.
            grad_scale = trainer.amp_loss_scale * loss_denom
        else:
            grad_scale = loss_denom
        # The gradients are sums that have not been divided by the token
        # count yet, so the clipping threshold and the reported norm are
        # rescaled by `grad_scale`.
        if args.max_grad_norm is not None:
            total_norm, _, _\
                = clip_grad_global_norm(params, args.max_grad_norm * grad_scale)
            total_norm = total_norm / grad_scale
        else:
            total_norm = grad_global_norm(params)
            total_norm = total_norm / grad_scale
        log_avg_grad_norm += total_norm
        log_iter_num += 1

        trainer.update(loss_denom, ignore_stale_grad=True)

        # -1 disables averaging; 0 is a valid start (average over the whole
        # run), so the comparison has to include it.
        if avg_start_iter >= 0 and train_iter >= avg_start_iter:
            model_averager.step()

        if ((train_iter + 1) % args.log_interval == 0
                or train_iter + 1 == total_train_iters):
            if args.comm_backend == 'horovod':
                # Use allreduce to get the total number of tokens and loss
                log_wc = hvd.allreduce(log_wc_l[0], average=False).asnumpy()
                log_tgt_wc = hvd.allreduce(log_tgt_wc_l[0],
                                           average=False).asnumpy()
                log_avg_loss = hvd.allreduce(log_avg_loss_l[0] /
                                             log_avg_loss_denom_l[0],
                                             average=True)
                log_avg_loss = log_avg_loss.asnumpy()
            else:
                log_wc = sum([ele.asnumpy() for ele in log_wc_l])
                log_tgt_wc = sum([ele.asnumpy() for ele in log_tgt_wc_l])
                log_avg_loss =\
                    sum([log_avg_loss_l[i].asnumpy() / log_avg_loss_denom_l[i].asnumpy()
                         for i in range(len(log_avg_loss_l))]) / len(log_avg_loss_l)
            log_avg_grad_norm = log_avg_grad_norm / log_iter_num
            log_end_time = time.time()
            wps = log_wc / (log_end_time - log_start_time)
            epoch_id = train_iter // num_updates_per_epoch
            logging.info(
                '[Epoch {} Iter {}/{}, Overall {}/{}] loss={:.4f}, ppl={:.4f}, '
                'throughput={:.2f}K wps, total wc={:.2f}K, wpb={:.2f}K,'
                ' LR={}, gnorm={:.4f}, ETA={:.2f}h'.format(
                    epoch_id, train_iter % num_updates_per_epoch + 1,
                    num_updates_per_epoch,
                    train_iter + 1, total_train_iters, log_avg_loss,
                    np.exp(log_avg_loss), wps / 1000, log_wc / 1000,
                    log_tgt_wc / 1000 / log_iter_num, trainer.learning_rate,
                    log_avg_grad_norm,
                    (log_end_time - train_start_time) / (train_iter + 1) *
                    (total_train_iters - train_iter - 1) / 3600))
            if local_rank == 0:
                writer.add_scalar('throughput_wps', wps, train_iter)
                writer.add_scalar('train_loss', log_avg_loss, train_iter)
                writer.add_scalar('lr', trainer.learning_rate, train_iter)
                writer.add_scalar('grad_norm', log_avg_grad_norm, train_iter)
            # Reinitialize the log variables
            log_start_time = time.time()
            log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_grad_norm = 0
            log_iter_num = 0
            log_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]
            log_tgt_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]

        # Checkpoint + validation: every `save_interval_update` updates when
        # training by update count, at every epoch end, and at the very end.
        if (args.max_update > 0 and (train_iter + 1) % args.save_interval_update == 0) \
            or ((train_iter + 1) % num_updates_per_epoch == 0) \
            or train_iter + 1 == total_train_iters:
            epoch_id = (train_iter + 1) // num_updates_per_epoch
            if local_rank == 0:
                if args.max_update <= 0:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'epoch{}.params'.format(epoch_id)),
                                          deduplicate=True)
                else:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'iter{}.params'.format(train_iter + 1)),
                                          deduplicate=True)

            avg_val_loss, ntokens, pred_sentences, pred_lengths, sentence_ids\
                = validation(model, val_data_loader, inference_model, beam_search_sampler,
                             tgt_tokenizer, ctx_l)
            if args.comm_backend == 'horovod':
                # Every worker only validated its own shard: gather the
                # per-shard results and rebuild the full prediction list.
                flatten_pred_sentences = np.concatenate(pred_sentences, axis=0)
                all_val_loss = hvd.allgather(
                    mx.np.array([avg_val_loss * ntokens],
                                dtype=np.float32,
                                ctx=ctx_l[0]))
                all_ntokens = hvd.allgather(
                    mx.np.array([ntokens], dtype=np.int64, ctx=ctx_l[0]))
                flatten_pred_sentences = hvd.allgather(
                    mx.np.array(flatten_pred_sentences,
                                dtype=np.int32,
                                ctx=ctx_l[0]))
                pred_lengths = hvd.allgather(
                    mx.np.array(pred_lengths, dtype=np.int64, ctx=ctx_l[0]))
                sentence_ids = hvd.allgather(
                    mx.np.array(sentence_ids, dtype=np.int64, ctx=ctx_l[0]))
                # Token-weighted average of the per-shard validation losses.
                avg_val_loss = all_val_loss.asnumpy().sum(
                ) / all_ntokens.asnumpy().sum()
                flatten_pred_sentences = flatten_pred_sentences.asnumpy()
                pred_lengths = pred_lengths.asnumpy()
                sentence_ids = sentence_ids.asnumpy()
                pred_sentences = [None for _ in range(len(sentence_ids))]
                ptr = 0
                assert sentence_ids.min() == 0 and sentence_ids.max(
                ) == len(sentence_ids) - 1
                # Cut the flat token buffer back into sentences and store
                # each one at the position given by its sentence id.
                for sentence_id, length in zip(sentence_ids, pred_lengths):
                    pred_sentences[sentence_id] = flatten_pred_sentences[ptr:(
                        ptr + length)]
                    ptr += length
            if local_rank == 0:
                # Perform detokenization
                pred_sentences_bpe_decode = []
                pred_sentences_raw = []
                for sentence in pred_sentences:
                    bpe_decode_sentence = tgt_tokenizer.decode(
                        sentence.tolist())
                    raw_sentence = base_tgt_tokenizer.decode(
                        bpe_decode_sentence.split())
                    pred_sentences_bpe_decode.append(bpe_decode_sentence)
                    pred_sentences_raw.append(raw_sentence)
                detok_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_bpe_decode,
                    ref_streams=[tgt_detok_sentences])
                raw_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_raw,
                    ref_streams=[tgt_raw_sentences])
                with open(
                        os.path.join(args.save_dir,
                                     f'epoch{epoch_id}_dev_prediction.txt'),
                        'w') as of:
                    for line in pred_sentences_raw:
                        of.write(line + '\n')
                logging.info(
                    '[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, '
                    'SacreBlEU={}, Detok SacreBLUE={}'.format(
                        epoch_id, train_iter, total_train_iters, avg_val_loss,
                        np.exp(avg_val_loss), raw_sacrebleu_out.score,
                        detok_sacrebleu_out.score))
                writer.add_scalar('valid_loss', avg_val_loss, train_iter)
                writer.add_scalar('valid_bleu', raw_sacrebleu_out.score,
                                  train_iter)

    if args.num_averages > 0:
        model_averager.copy_back(
            param_dict)  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)