Example #1
def get_dataloader(data_set,
                   args,
                   dataset_type,
                   use_average_length=False,
                   num_shards=0,
                   num_workers=8):
    """Create data loaders for training/validation/test."""
    assert dataset_type in ['train', 'val', 'test']

    if args.bucket_scheme == 'constant':
        bucket_scheme = nlp.data.ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = nlp.data.LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError

    data_lengths = get_data_lengths(data_set)

    if dataset_type == 'train':
        train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                      btf.Stack(dtype='float32'),
                                      btf.Stack(dtype='float32'))

    else:
        data_lengths = list(map(lambda x: x[-1], data_lengths))
        test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                     btf.Stack(dtype='float32'),
                                     btf.Stack(dtype='float32'), btf.Stack())

    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
                                                batch_size=(args.batch_size
                                                            if dataset_type == 'train'
                                                            else args.test_batch_size),
                                                num_buckets=args.num_buckets,
                                                ratio=args.bucket_ratio,
                                                shuffle=(dataset_type == 'train'),
                                                use_average_length=use_average_length,
                                                num_shards=num_shards,
                                                bucket_scheme=bucket_scheme)

    if dataset_type == 'train':
        logging.info('Train Batch Sampler:\n%s', batch_sampler.stats())
        data_loader = nlp.data.ShardedDataLoader(data_set,
                                                 batch_sampler=batch_sampler,
                                                 batchify_fn=train_batchify_fn,
                                                 num_workers=num_workers)
    else:
        if dataset_type == 'val':
            logging.info('Valid Batch Sampler:\n%s', batch_sampler.stats())
        else:
            logging.info('Test Batch Sampler:\n%s', batch_sampler.stats())

        data_loader = gluon.data.DataLoader(data_set,
                                            batch_sampler=batch_sampler,
                                            batchify_fn=test_batchify_fn,
                                            num_workers=num_workers)

    return data_loader
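For reference, a minimal sketch of how the returned loader might be consumed for the validation split, assuming data_val and args come from the surrounding script (the field names in the loop are illustrative):

# Illustrative only: data_val and args are assumed to be prepared elsewhere.
val_loader = get_dataloader(data_val, args, dataset_type='val')
for src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids in val_loader:
    # Each batch follows the 5-field test_batchify_fn layout above: padded
    # source/target ids, their float32 valid lengths, and a stacked extra field.
    pass  # run evaluation / translation here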
Example #2
def make_dataloader(data_train, data_val, data_test, args,
                    use_average_length=False, num_shards=0, num_workers=8):
    """Create data loaders for training/validation/test."""
    data_train_lengths = get_data_lengths(data_train)
    data_val_lengths = get_data_lengths(data_val)
    data_test_lengths = get_data_lengths(data_test)
    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
                                 btf.Stack())
    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
    if args.bucket_scheme == 'constant':
        bucket_scheme = nlp.data.ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = nlp.data.LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError
    train_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_train_lengths,
                                                      batch_size=args.batch_size,
                                                      num_buckets=args.num_buckets,
                                                      ratio=args.bucket_ratio,
                                                      shuffle=True,
                                                      use_average_length=use_average_length,
                                                      num_shards=num_shards,
                                                      bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n%s', train_batch_sampler.stats())
    train_data_loader = nlp.data.ShardedDataLoader(data_train,
                                                   batch_sampler=train_batch_sampler,
                                                   batchify_fn=train_batchify_fn,
                                                   num_workers=num_workers)

    val_batch_sampler = nlp.data.FixedBucketSampler(lengths=target_val_lengths,
                                                    batch_size=args.test_batch_size,
                                                    num_buckets=args.num_buckets,
                                                    ratio=args.bucket_ratio,
                                                    shuffle=False,
                                                    use_average_length=use_average_length,
                                                    bucket_scheme=bucket_scheme)
    logging.info('Valid Batch Sampler:\n%s', val_batch_sampler.stats())
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_sampler=val_batch_sampler,
                                            batchify_fn=test_batchify_fn,
                                            num_workers=num_workers)
    test_batch_sampler = nlp.data.FixedBucketSampler(lengths=target_test_lengths,
                                                     batch_size=args.test_batch_size,
                                                     num_buckets=args.num_buckets,
                                                     ratio=args.bucket_ratio,
                                                     shuffle=False,
                                                     use_average_length=use_average_length,
                                                     bucket_scheme=bucket_scheme)
    logging.info('Test Batch Sampler:\n%s', test_batch_sampler.stats())
    test_data_loader = gluon.data.DataLoader(data_test,
                                             batch_sampler=test_batch_sampler,
                                             batchify_fn=test_batchify_fn,
                                             num_workers=num_workers)
    return train_data_loader, val_data_loader, test_data_loader
Example #3
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Turn Dataset into Dataloader
        # (GPT-2 does not have a padding token; the padding value should not matter here anyway)
        self._batchify_fn = btf.Tuple(
            btf.Stack(dtype='int32'),
            btf.Pad(pad_val=np.iinfo(np.int32).max, dtype='int32'),
            btf.Stack(dtype='int32'))
Example #4
    def __init__(self, *args, **kwargs):
        self._wwm = kwargs.pop('wwm') if 'wwm' in kwargs else False
        super().__init__(*args, **kwargs)
        # Turn Dataset into Dataloader
        self._batchify_fn = btf.Tuple(
            btf.Stack(dtype='int32'),
            btf.Pad(pad_val=np.iinfo(np.int32).max, dtype='int32'),
            btf.Stack(dtype='float32'))
Example #5
def test_named_tuple():
    a = ([1, 2, 3, 4], 0)
    b = ([5, 7], 1)
    c = ([1, 2, 3, 4, 5, 6, 7], 0)
    batchify_fn = batchify.NamedTuple([('data', batchify.Pad()),
                                       ('label', batchify.Stack())],
                                      name='SomeName')
    sample = batchify_fn([a, b, c])
    gt_data = batchify.Pad()([a[0], b[0], c[0]])
    gt_label = batchify.Stack()([a[1], b[1], c[1]])
    assert_allclose(sample.data.asnumpy(), gt_data.asnumpy())
    assert_allclose(sample.label.asnumpy(), gt_label.asnumpy())
    assert type(sample).__name__ == 'SomeName'
Example #6
    def __init__(self, *args, **kwargs):
        self._wwm = kwargs.pop('wwm') if 'wwm' in kwargs else False
        super().__init__(*args, **kwargs)
        # Turn Dataset into Dataloader
        self._batchify_fn = btf.Tuple(
            btf.Stack(dtype='int32'),
            btf.Pad(pad_val=np.iinfo(np.int32).max, dtype='int32'),
            btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
        self._trainer = mx.gluon.Trainer(
            self._model.collect_params(), 'adam',
            {'learning_rate': 1e-5, 'epsilon': 1e-9}, update_on_kvstore=False)
        self._loss = mx.gluon.loss.L2Loss()
        self._loss.hybridize(static_alloc=True)
        self._params = [p for p in self._model.collect_params().values() if p.grad_req != 'null']

        self._max_length = 384
Example #7
def test_stack_batchify(odtype, idtype, pass_dtype):
    dat = [np.random.randint(5, size=(10, )).astype(idtype) for _ in range(10)]
    batchify_fn = batchify.Stack(dtype=odtype if pass_dtype else None)
    batchify_out = batchify_fn(dat).asnumpy()
    npy_out = np.array(dat)
    assert_allclose(batchify_out, npy_out)
    # When a dtype is passed, the output must use it; otherwise it follows the input dtype.
    assert batchify_out.dtype == (odtype if pass_dtype else npy_out.dtype)
Example #8
def test_stack_batchify():
    for dtype in [
            np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64
    ]:
        dat = [
            np.random.randint(5, size=(10, )).astype(dtype) for _ in range(10)
        ]
        for _dtype in [None, dtype]:
            batchify_fn = batchify.Stack(dtype=_dtype)
            batchify_out = batchify_fn(dat).asnumpy()
            npy_out = np.array(dat)
            assert_allclose(batchify_out, npy_out)
            assert batchify_out.dtype == npy_out.dtype
        for _dtype in [np.float16, np.float32]:
            batchify_fn = batchify.Stack(dtype=_dtype)
            batchify_out = batchify_fn(dat).asnumpy()
            assert batchify_out.dtype == _dtype
Example #9
def test_dict():
    a = {'data': [1, 2, 3, 4], 'label': 0}
    b = {'data': [5, 7], 'label': 1}
    c = {'data': [1, 2, 3, 4, 5, 6, 7], 'label': 0}
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.Dict(
            [batchify.Pad(pad_val=0),
             batchify.Stack()])
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(MyNamedTuple, {'a': 1, 'b': 2})
    batchify_fn = batchify.Dict({
        'data': batchify.Pad(pad_val=0),
        'label': batchify.Stack()
    })
    sample = batchify_fn([a, b, c])
    gt_data = batchify.Pad(pad_val=0)([a['data'], b['data'], c['data']])
    gt_label = batchify.Stack()([a['label'], b['label'], c['label']])
    assert isinstance(sample, dict)
    assert_allclose(sample['data'].asnumpy(), gt_data.asnumpy())
    assert_allclose(sample['label'].asnumpy(), gt_label.asnumpy())
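Beyond the unit test, the same Dict batchify can be handed straight to a DataLoader as its batchify_fn; a minimal sketch, assuming gluonnlp's batchify module and a small in-memory dataset (everything below is illustrative):

import mxnet as mx
from gluonnlp.data import batchify

# Tiny illustrative dataset of dict samples, mirroring a/b/c above.
samples = mx.gluon.data.SimpleDataset([
    {'data': [1, 2, 3, 4], 'label': 0},
    {'data': [5, 7], 'label': 1},
    {'data': [1, 2, 3, 4, 5, 6, 7], 'label': 0},
])
batchify_fn = batchify.Dict({'data': batchify.Pad(pad_val=0),
                             'label': batchify.Stack()})
loader = mx.gluon.data.DataLoader(samples, batch_size=3, batchify_fn=batchify_fn)
for batch in loader:
    padded, labels = batch['data'], batch['label']  # shapes (3, 7) and (3,)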
Example #10
    def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
        """

        Parameters
        ----------
        tokenizer
            The tokenizer
        doc_stride
            The stride to chunk the document
        max_seq_length
            Maximum length of the merged data
        max_query_length
            Maximum query length
        """
        self._tokenizer = tokenizer
        self._doc_stride = doc_stride
        self._max_seq_length = max_seq_length
        self._max_query_length = max_query_length

        vocab = tokenizer.vocab
        self.pad_id = vocab.pad_id
        # For the RoBERTa model, take the special token <s> as [CLS] and </s> as [SEP]
        self.cls_id = vocab.bos_id if 'cls_token' not in vocab.special_token_keys else vocab.cls_id
        self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id

        # TODO(sxjscience) Consider combining the NamedTuple and batchify functionality.
        self.ChunkFeature = collections.namedtuple(
            'ChunkFeature',
            ['qas_id', 'data', 'valid_length', 'segment_ids', 'masks',
             'is_impossible', 'gt_start', 'gt_end', 'context_offset',
             'chunk_start', 'chunk_length'])
        self.BatchifyFunction = bf.NamedTuple(
            self.ChunkFeature,
            {'qas_id': bf.List(),
             'data': bf.Pad(val=self.pad_id),
             'valid_length': bf.Stack(),
             'segment_ids': bf.Pad(),
             'masks': bf.Pad(val=1),
             'is_impossible': bf.Stack(),
             'gt_start': bf.Stack(),
             'gt_end': bf.Stack(),
             'context_offset': bf.Stack(),
             'chunk_start': bf.Stack(),
             'chunk_length': bf.Stack()})
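The NamedTuple batchify defined above is typically plugged into a DataLoader; a hedged sketch of that wiring, assuming chunk_features is a list of ChunkFeature tuples produced by the corresponding chunking step (the helper below is hypothetical):

import mxnet as mx

def make_chunk_loader(chunker, chunk_features, batch_size=32, num_workers=4):
    # chunker is an instance of the class above; chunk_features must follow the
    # ChunkFeature field order so BatchifyFunction can pad/stack each field.
    dataset = mx.gluon.data.SimpleDataset(chunk_features)
    return mx.gluon.data.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    batchify_fn=chunker.BatchifyFunction,
                                    num_workers=num_workers)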
Example #11
def prepare_data_loader(args, dataset, vocab, test=False):
    """
    Read data and build data loader.
    """
    # Preprocess
    dataset = dataset.transform(lambda s1, s2, label:
                                (vocab(s1), vocab(s2), label),
                                lazy=False)

    # Batching
    batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(dtype='int32'))
    data_lengths = [max(len(d[0]), len(d[1])) for d in dataset]
    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
                                                batch_size=args.batch_size,
                                                shuffle=(not test))
    data_loader = gluon.data.DataLoader(dataset=dataset,
                                        batch_sampler=batch_sampler,
                                        batchify_fn=batchify_fn)
    return data_loader
Example #12
def test_named_tuple():
    a = MyNamedTuple([1, 2, 3, 4], 0)
    b = MyNamedTuple([5, 7], 1)
    c = MyNamedTuple([1, 2, 3, 4, 5, 6, 7], 0)
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(
            MyNamedTuple, {
                'data0': batchify.Pad(pad_val=0),
                'label': batchify.Stack()
            })
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(
            MyNamedTuple,
            [batchify.Pad(pad_val=0),
             batchify.Stack(),
             batchify.Stack()])
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(MyNamedTuple,
                                                (batchify.Pad(pad_val=0), ))
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(MyNamedTuple, [1, 2])
    for batchify_fn in [
            batchify.NamedTuple(MyNamedTuple, {
                'data': batchify.Pad(pad_val=0),
                'label': batchify.Stack()
            }),
            batchify.NamedTuple(MyNamedTuple,
                                [batchify.Pad(pad_val=0),
                                 batchify.Stack()]),
            batchify.NamedTuple(MyNamedTuple,
                                (batchify.Pad(pad_val=0), batchify.Stack()))
    ]:
        sample = batchify_fn([a, b, c])
        gt_data = batchify.Pad(pad_val=0)([a[0], b[0], c[0]])
        gt_label = batchify.Stack()([a[1], b[1], c[1]])
        assert isinstance(sample, MyNamedTuple)
        assert_allclose(sample.data.asnumpy(), gt_data.asnumpy())
        assert_allclose(sample.label.asnumpy(), gt_label.asnumpy())
        with pytest.raises(ValueError):
            batchify_fn([1, 2, 3])
Example #13
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {
        'learning_rate': args.lr,
        'beta2': 0.98,
        'epsilon': 1e-9
    })

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(),
                                  btf.Stack())
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(),
                                 btf.Stack(), btf.Stack())
    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             use_average_length=True)
    logging.info('Train Batch Sampler:\n{}'.format(
        train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False,
                                           use_average_length=True)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False,
                                            use_average_length=True)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    model.collect_params().setattr('grad_req', 'add')
    average_start = (len(train_data_loader) //
                     grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) \
                in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = tgt_valid_length.sum().asscalar()
            loss_denom += tgt_wc - tgt_valid_length.shape[0]
            if src_seq.shape[0] > len(ctx):
                src_seq_list, tgt_seq_list, src_valid_length_list, tgt_valid_length_list \
                    = [gluon.utils.split_and_load(seq, ctx, batch_axis=0, even_split=False)
                       for seq in [src_seq, tgt_seq, src_valid_length, tgt_valid_length]]
            else:
                src_seq_list = [src_seq.as_in_context(ctx[0])]
                tgt_seq_list = [tgt_seq.as_in_context(ctx[0])]
                src_valid_length_list = [
                    src_valid_length.as_in_context(ctx[0])
                ]
                tgt_valid_length_list = [
                    tgt_valid_length.as_in_context(ctx[0])
                ]

            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length \
                        in zip(src_seq_list, tgt_seq_list,
                               src_valid_length_list, tgt_valid_length_list):
                    out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                                   tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label,
                                       tgt_valid_length - 1).sum()
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size)
            for L in Ls:
                L.backward()
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                if average_param_dict is None:
                    average_param_dict = {
                        k: v.data(ctx[0]).copy()
                        for k, v in model.collect_params().items()
                    }
                trainer.step(float(loss_denom) / args.batch_size)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (
                            param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                log_avg_loss += step_loss / loss_denom * args.batch_size
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch_id, batch_id + 1,
                                 len(train_data_loader),
                                 log_avg_loss / args.log_interval,
                                 np.exp(log_avg_loss / args.log_interval),
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out,
                                                    bpe=True,
                                                    split_compound_word=True)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                   test_translation_out,
                                                   bpe=True,
                                                   split_compound_word=True)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
            format(epoch_id, test_loss, np.exp(test_loss),
                   test_bleu_score * 100))
        write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(
            test_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        save_path = os.path.join(args.save_dir,
                                 'epoch{:d}.params'.format(epoch_id))
        model.save_params(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        for j in range(args.num_averages):
            params = mx.nd.load(
                os.path.join(args.save_dir,
                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) -
                                             v.data(c))
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
    else:
        model.load_params(os.path.join(args.save_dir, 'valid_best.params'),
                          ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out,
                                                bpe=True,
                                                split_compound_word=True)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out,
                                               bpe=True,
                                               split_compound_word=True)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
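The learning-rate update at the top of the batch loop above follows the standard Transformer warmup schedule; a standalone sketch of the same rule, with lr, num_units and warmup_steps standing in for the corresponding args (illustrative only, not part of the original script):

import math

def transformer_lr(step_num, lr, num_units, warmup_steps):
    # Linear warmup for the first warmup_steps updates, then 1/sqrt(step) decay.
    return lr / math.sqrt(num_units) * min(1.0 / math.sqrt(step_num),
                                           step_num * warmup_steps ** -1.5)

Up to warmup_steps the second term is the smaller one, so the rate grows linearly with the step count; beyond that point the 1/sqrt(step_num) term takes over and the rate decays.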
Example #14
def train(args):
    _, num_parts, rank, local_rank, _, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    if args.comm_backend == 'horovod':
        logging_config(
            args.save_dir,
            name=f'train_transformer_rank{rank}_local{local_rank}_{num_parts}',
            console=(rank == 0))
        logging.info(args)
    else:
        logging_config(args.save_dir, name='train_transformer', console=True)
        logging.info(args)
    use_amp = args.fp16
    if use_amp:
        from mxnet import amp
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    base_tgt_tokenizer = MosesTokenizer(args.tgt_lang)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus,
        args.train_tgt_corpus,
        src_tokenizer,
        tgt_tokenizer,
        args.overwrite_cache,
        local_rank,
        max_src_length=args.max_src_length,
        max_tgt_length=args.max_tgt_length,
        pretokenized=not args.tokenize)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus,
        args.dev_tgt_corpus,
        src_tokenizer,
        tgt_tokenizer,
        args.overwrite_cache,
        local_rank,
        pretokenized=not args.tokenize)
    tgt_detok_sentences = []
    tgt_raw_sentences = []
    with open(args.dev_tgt_corpus, 'r') as in_f:
        for line in in_f:
            tgt_detok_sentences.append(
                base_tgt_tokenizer.decode(
                    tgt_tokenizer.decode(line.split()).split()))
    with open(args.dev_tgt_raw_corpus, 'r') as in_f:
        for line in in_f:
            tgt_raw_sentences.append(line.strip())
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    val_samples = [
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ]
    if args.comm_backend == 'horovod':
        slice_begin = rank * (len(val_samples) // num_parts)
        slice_end = min((rank + 1) * (len(val_samples) // num_parts),
                        len(val_samples))
        data_val = gluon.data.SimpleDataset(val_samples[slice_begin:slice_end])
    else:
        data_val = gluon.data.SimpleDataset(val_samples)
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    cfg.MODEL.layout = 'TN'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    param_dict = deduplicate_param_dict(model.collect_params())

    inference_model = TransformerInference(model=model)
    inference_model.hybridize()
    if local_rank == 0:
        logging.info(model)
    with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
        cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()

    # Construct the beam search sampler
    scorer = BeamSearchScorer(alpha=args.lp_alpha,
                              K=args.lp_k,
                              from_logits=False)
    beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size,
                                            decoder=inference_model,
                                            vocab_size=len(tgt_vocab),
                                            eos_id=tgt_vocab.eos_id,
                                            scorer=scorer,
                                            stochastic=False,
                                            max_length_a=args.max_length_a,
                                            max_length_b=args.max_length_b)

    logging.info(beam_search_sampler)
    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # Construct the trainer
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    optimizer_params = {
        'learning_rate': args.lr,
        'beta1': 0.9,
        'beta2': 0.997,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler,
        'wd': args.wd
    }
    user_provided_optimizer_params = json.loads(args.optimizer_params)
    optimizer_params.update(user_provided_optimizer_params)

    if args.fp16:
        optimizer_params.update({'multi_precision': True})
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = gluon.Trainer(param_dict,
                                args.optimizer,
                                optimizer_params,
                                update_on_kvstore=False)
    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            shuffle=True,
            seed=args.seed)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')

        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(lengths=[
            (ele[2], ele[3]) for ele in data_train
        ],
                                                 batch_size=args.batch_size,
                                                 num_buckets=args.num_buckets,
                                                 ratio=args.bucket_ratio,
                                                 shuffle=True,
                                                 use_average_length=True,
                                                 bucket_scheme=bucket_scheme,
                                                 seed=args.seed)
    else:
        raise NotImplementedError

    num_updates_per_epoch = int(
        math.ceil(
            len(train_batch_sampler) /
            (num_parts * len(ctx_l) * args.num_accumulated)))
    # Convert the batch sampler to multiple shards
    if num_parts > 1:
        train_batch_sampler = ShardedIterator(train_batch_sampler,
                                              num_parts=num_parts,
                                              part_index=rank,
                                              even_size=True,
                                              seed=args.seed + 1000 * rank)

    logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    model_averager = AverageSGDTracker(param_dict)
    log_start_time = time.time()
    num_params, num_fixed_params = None, None

    # TODO(sxjscience) Add a log metric class
    log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    # Maintain the denominator of the loss.
    log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_avg_grad_norm = 0
    log_iter_num = 0

    if local_rank == 0:
        writer = SummaryWriter(
            logdir=os.path.join(args.save_dir, 'tensorboard'))
    if use_amp:
        amp.init_trainer(trainer)
    train_multi_data_loader = grouper(repeat(train_data_loader), len(ctx_l))
    # when args.epochs < 0, the model will keep training
    if args.epochs < 0:
        if args.max_update > 0:
            total_train_iters = args.max_update
            if args.num_averages > 0:
                assert args.num_averages <= total_train_iters // args.save_interval_update
                avg_start_iter = (
                    total_train_iters // args.save_interval_update -
                    args.num_averages) * args.save_interval_update
            else:
                avg_start_iter = -1
        else:
            total_train_iters = np.inf
            avg_start_iter = -1
    else:
        total_train_iters = args.epochs * num_updates_per_epoch
        if args.num_averages > 0:
            assert args.num_averages <= args.epochs
            avg_start_iter = (args.epochs -
                              args.num_averages) * num_updates_per_epoch
        else:
            avg_start_iter = -1

    # Here, we are manually setting up the scale to 1.0 because
    # in horovod, the scale can be the number of workers:
    # See the code here: https://github.com/horovod/horovod/blob/125115583b7029196e2ec530decd4209459d5479/horovod/mxnet/__init__.py#L141
    # Since we will need to use the dynamic scaling in amp, we will manually call amp.unscale().
    # A scale that is larger than 1.0 can be problematic in this case.
    trainer._scale = 1.0
    if args.max_num_tokens > 0:
        const_scale = args.max_num_tokens
    else:
        const_scale = 100

    train_start_time = time.time()

    for train_iter in range(total_train_iters):
        model.zero_grad()
        loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
        for i in range(args.num_accumulated):
            loss_l = []
            sample_data_l = next(train_multi_data_loader)
            for j, (sample_data, ctx) in enumerate(zip(sample_data_l, ctx_l)):
                src_token_ids, tgt_token_ids, src_valid_length,\
                tgt_valid_length, sample_ids = sample_data
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                src_wc, tgt_wc, bs = src_valid_length.sum(), \
                                     tgt_valid_length.sum(), src_token_ids.shape[0]
                log_wc_l[j] += src_wc + tgt_wc
                log_tgt_wc_l[j] += tgt_wc
                token_count = (tgt_valid_length - 1).sum()
                loss_denom_l[j] += token_count / const_scale
                log_avg_loss_denom_l[j] += token_count / const_scale
                with mx.autograd.record():
                    if model.layout == 'NT':
                        tgt_pred = model(src_token_ids, src_valid_length,
                                         tgt_token_ids[:, :-1],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids[:, 1:]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=1)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                    elif model.layout == 'TN':
                        tgt_pred = model(src_token_ids.T, src_valid_length,
                                         tgt_token_ids.T[:-1, :],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids.T[1:, :]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=0)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                log_avg_loss_l[j] += loss
            if use_amp:
                with mx.autograd.record():
                    with amp.scale_loss(loss_l, trainer) as amp_loss_l:
                        for loss in amp_loss_l:
                            loss.backward()
            else:
                with mx.autograd.record():
                    for loss in loss_l:
                        loss.backward()

        # Print the total number of parameters
        if local_rank == 0 and num_params is None:
            num_params, num_fixed_params = count_parameters(param_dict)
            logging.info(
                'Total Number of Parameters (not-fixed/fixed): {}/{}'.format(
                    num_params, num_fixed_params))
        # All-Reduce the gradient
        trainer.allreduce_grads()
        if args.comm_backend == 'horovod':
            # All-Reduce the loss denominator
            assert len(loss_denom_l) == 1
            loss_denom = hvd.allreduce(loss_denom_l[0],
                                       average=False).asnumpy()
        else:
            loss_denom = sum([ele.asnumpy() for ele in loss_denom_l])
        if use_amp:
            # We need to first unscale the gradient and then perform allreduce.
            grad_scale = trainer.amp_loss_scale * loss_denom
        else:
            grad_scale = loss_denom
        if args.max_grad_norm is not None:
            total_norm, ratio, is_finite\
                = clip_grad_global_norm(params, args.max_grad_norm * grad_scale)
            total_norm = total_norm / grad_scale
        else:
            total_norm = grad_global_norm(params)
            total_norm = total_norm / grad_scale
        log_avg_grad_norm += total_norm
        log_iter_num += 1

        trainer.update(loss_denom, ignore_stale_grad=True)

        if avg_start_iter > 0 and train_iter >= avg_start_iter:
            model_averager.step()

        if ((train_iter + 1) % args.log_interval == 0
                or train_iter + 1 == total_train_iters):
            if args.comm_backend == 'horovod':
                # Use allreduce to get the total number of tokens and loss
                log_wc = hvd.allreduce(log_wc_l[0], average=False).asnumpy()
                log_tgt_wc = hvd.allreduce(log_tgt_wc_l[0],
                                           average=False).asnumpy()
                log_avg_loss = hvd.allreduce(log_avg_loss_l[0] /
                                             log_avg_loss_denom_l[0],
                                             average=True)
                log_avg_loss = log_avg_loss.asnumpy()
            else:
                log_wc = sum([ele.asnumpy() for ele in log_wc_l])
                log_tgt_wc = sum([ele.asnumpy() for ele in log_tgt_wc_l])
                log_avg_loss =\
                    sum([log_avg_loss_l[i].asnumpy() / log_avg_loss_denom_l[i].asnumpy()
                         for i in range(len(log_avg_loss_l))]) / len(log_avg_loss_l)
            log_avg_grad_norm = log_avg_grad_norm / log_iter_num
            log_end_time = time.time()
            wps = log_wc / (log_end_time - log_start_time)
            epoch_id = train_iter // num_updates_per_epoch
            logging.info(
                '[Epoch {} Iter {}/{}, Overall {}/{}] loss={:.4f}, ppl={:.4f}, '
                'throughput={:.2f}K wps, total wc={:.2f}K, wpb={:.2f}K,'
                ' LR={}, gnorm={:.4f}, ETA={:.2f}h'.format(
                    epoch_id, train_iter % num_updates_per_epoch + 1,
                    num_updates_per_epoch,
                    train_iter + 1, total_train_iters, log_avg_loss,
                    np.exp(log_avg_loss), wps / 1000, log_wc / 1000,
                    log_tgt_wc / 1000 / log_iter_num, trainer.learning_rate,
                    log_avg_grad_norm,
                    (log_end_time - train_start_time) / (train_iter + 1) *
                    (total_train_iters - train_iter - 1) / 3600))
            if local_rank == 0:
                writer.add_scalar('throughput_wps', wps, train_iter)
                writer.add_scalar('train_loss', log_avg_loss, train_iter)
                writer.add_scalar('lr', trainer.learning_rate, train_iter)
                writer.add_scalar('grad_norm', log_avg_grad_norm, train_iter)
            # Reinitialize the log variables
            log_start_time = time.time()
            log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_grad_norm = 0
            log_iter_num = 0
            log_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]
            log_tgt_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]

        if (args.max_update > 0 and (train_iter + 1) % args.save_interval_update == 0) \
            or ((train_iter + 1) % num_updates_per_epoch == 0) \
            or train_iter + 1 == total_train_iters:
            epoch_id = (train_iter + 1) // num_updates_per_epoch
            if local_rank == 0:
                if args.max_update <= 0:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'epoch{}.params'.format(epoch_id)),
                                          deduplicate=True)
                else:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'iter{}.params'.format(train_iter + 1)),
                                          deduplicate=True)

            avg_val_loss, ntokens, pred_sentences, pred_lengths, sentence_ids\
                = validation(model, val_data_loader, inference_model, beam_search_sampler,
                             tgt_tokenizer, ctx_l)
            if args.comm_backend == 'horovod':
                flatten_pred_sentences = np.concatenate(pred_sentences, axis=0)
                all_val_loss = hvd.allgather(
                    mx.np.array([avg_val_loss * ntokens],
                                dtype=np.float32,
                                ctx=ctx_l[0]))
                all_ntokens = hvd.allgather(
                    mx.np.array([ntokens], dtype=np.int64, ctx=ctx_l[0]))
                flatten_pred_sentences = hvd.allgather(
                    mx.np.array(flatten_pred_sentences,
                                dtype=np.int32,
                                ctx=ctx_l[0]))
                pred_lengths = hvd.allgather(
                    mx.np.array(pred_lengths, dtype=np.int64, ctx=ctx_l[0]))
                sentence_ids = hvd.allgather(
                    mx.np.array(sentence_ids, dtype=np.int64, ctx=ctx_l[0]))
                avg_val_loss = all_val_loss.asnumpy().sum(
                ) / all_ntokens.asnumpy().sum()
                flatten_pred_sentences = flatten_pred_sentences.asnumpy()
                pred_lengths = pred_lengths.asnumpy()
                sentence_ids = sentence_ids.asnumpy()
                pred_sentences = [None for _ in range(len(sentence_ids))]
                ptr = 0
                assert sentence_ids.min() == 0 and sentence_ids.max(
                ) == len(sentence_ids) - 1
                for sentence_id, length in zip(sentence_ids, pred_lengths):
                    pred_sentences[sentence_id] = flatten_pred_sentences[ptr:(
                        ptr + length)]
                    ptr += length
            if local_rank == 0:
                # Perform detokenization
                pred_sentences_bpe_decode = []
                pred_sentences_raw = []
                for sentence in pred_sentences:
                    bpe_decode_sentence = tgt_tokenizer.decode(
                        sentence.tolist())
                    raw_sentence = base_tgt_tokenizer.decode(
                        bpe_decode_sentence.split())
                    pred_sentences_bpe_decode.append(bpe_decode_sentence)
                    pred_sentences_raw.append(raw_sentence)
                detok_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_bpe_decode,
                    ref_streams=[tgt_detok_sentences])
                raw_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_raw,
                    ref_streams=[tgt_raw_sentences])
                with open(
                        os.path.join(args.save_dir,
                                     f'epoch{epoch_id}_dev_prediction.txt'),
                        'w') as of:
                    for line in pred_sentences_raw:
                        of.write(line + '\n')
                logging.info(
                    '[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, '
                    'SacreBLEU={}, Detok SacreBLEU={}'.format(
                        epoch_id, train_iter, total_train_iters, avg_val_loss,
                        np.exp(avg_val_loss), raw_sacrebleu_out.score,
                        detok_sacrebleu_out.score))
                writer.add_scalar('valid_loss', avg_val_loss, train_iter)
                writer.add_scalar('valid_bleu', raw_sacrebleu_out.score,
                                  train_iter)

    if args.num_averages > 0:
        model_averager.copy_back(
            param_dict)  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
    def bin(self,
            corpus: Corpus,
            temp: float = 1.0,
            split_size: int = 2000,
            ratio: float = 0,
            num_workers: int = 10) -> Tuple[np.ndarray, np.ndarray]:

        ctx_cpu = mx.Context('cpu')

        # Turn corpus into a BERT-ready Dataset
        dataset = self.corpus_to_dataset(corpus)

        # Turn Dataset into Dataloader
        batchify_fn = btf.Tuple(
            btf.Stack(dtype='int32'),
            btf.Pad(
                pad_val=self._vocab.token_to_idx[self._vocab.padding_token],
                dtype='int32'), btf.Stack(dtype='float32'),
            btf.Stack(dtype='int32'), btf.Stack(dtype='int32'),
            btf.Stack(dtype='float32'))

        # TODO: There is a 'by-design' bug in FixedBucketSampler with num_shards > 0, where it silently reuses the last utterances:
        # https://github.com/dmlc/gluon-nlp/blame/b1b61d3f90cf795c7b48b6d109db7b7b96fa21ff/src/gluonnlp/data/sampler.py#L398
        # batch_sampler = nlp.data.sampler.FixedBucketSampler([sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=len(self._ctxs), shuffle=False)
        # Hence, we use num_shards = 0 and do gluon's split_data
        batch_sampler = nlp.data.sampler.FixedBucketSampler(
            [sent_tuple[2] for sent_tuple in dataset],
            batch_size=split_size,
            ratio=ratio,
            num_shards=0,
            shuffle=False)

        logging.info(batch_sampler.stats())
        dataloader = nlp.data.ShardedDataLoader(dataset,
                                                pin_memory=True,
                                                batch_sampler=batch_sampler,
                                                batchify_fn=batchify_fn,
                                                num_workers=num_workers,
                                                thread_pool=True)

        ### <DIFFERENT>
        max_length = 256
        # Compute bins
        # First axis is sentence length
        bin_counts = np.zeros((max_length, max_length))
        bin_counts_per_ctx = [
            mx.nd.zeros((max_length, max_length), ctx=ctx)
            for ctx in self._ctxs
        ]
        bin_sums = np.zeros((max_length, max_length))
        bin_sums_per_ctx = [
            mx.nd.zeros((max_length, max_length), ctx=ctx)
            for ctx in self._ctxs
        ]
        ### </DIFFERENT>

        # Compute sum (assumes dataset is in order)
        prev_sent_idx = None
        true_tok_lens = []
        for (curr_sent_idx, _, valid_length, _, _, _) in dataset:
            if curr_sent_idx != prev_sent_idx:
                prev_sent_idx = curr_sent_idx
                true_tok_lens.append(valid_length - 2)

        sent_count = 0
        batch_log_interval = 20

        # For now just predicts the first non-cls token
        for batch_id, batch in enumerate(dataloader):

            batch_size = 0

            # TODO: Write tests about batching over multiple GPUs and getting the same scores
            # TODO: SEE COMMENT ABOVE REGARDING FIXEDBUCKETSAMPLER
            batch = zip(*[
                mx.gluon.utils.split_data(batch_compo,
                                          len(self._ctxs),
                                          batch_axis=0,
                                          even_split=False)
                for batch_compo in batch
            ])

            for ctx_idx, (sent_idxs, token_ids, valid_length, masked_positions,
                          token_masked_ids, normalization) in enumerate(batch):

                ctx = self._ctxs[ctx_idx]
                batch_size += sent_idxs.shape[0]
                token_ids = token_ids.as_in_context(ctx)
                valid_length = valid_length.as_in_context(ctx)
                segment_ids = mx.nd.zeros(shape=token_ids.shape, ctx=ctx)
                masked_positions = masked_positions.as_in_context(ctx).reshape(
                    -1, 1)
                out = self._model(token_ids, segment_ids, valid_length,
                                  masked_positions)

                # Get the probability computed for the correct token
                split_size = token_ids.shape[0]
                # out[0] contains the representations
                # out[1] is what contains the distribution for the masked
                out = out[1].log_softmax(temperature=temp)

                ### <DIFFERENT>
                token_masked_ids = token_masked_ids.as_in_context(ctx).reshape(
                    -1)
                for i in range(out.shape[0]):
                    num_bins = int(valid_length[i].asscalar()) - 2
                    bin_counts_per_ctx[ctx_idx][num_bins, masked_positions[i] - 1] += 1
                    bin_sums_per_ctx[ctx_idx][num_bins, masked_positions[i] - 1] += \
                        out[i, 0, token_masked_ids[i]]
                    if token_masked_ids[i].asscalar() == 1012:
                        import pdb
                        pdb.set_trace()
                ### </DIFFERENT>

            # Progress
            sent_count += batch_size
            if (batch_id + 1) % batch_log_interval == 0:
                logging.info("{} sents of {}, batch {} of {}".format(
                    sent_count, len(dataset), batch_id + 1,
                    len(batch_sampler)))

        # Accumulate the counts
        for ctx_idx in range(len(self._ctxs)):
            bin_counts += bin_counts_per_ctx[ctx_idx].asnumpy()
            bin_sums += bin_sums_per_ctx[ctx_idx].asnumpy()

        return bin_counts, bin_sums
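The two arrays returned above can be reduced to an average log-probability per (sentence length, position) bin; a small post-processing sketch under that layout assumption (a hypothetical helper, not part of the original class):

import numpy as np

def average_bins(bin_counts, bin_sums):
    # Mean masked-token log-probability per (sentence length, position) bin;
    # empty bins become NaN so they can be filtered out downstream.
    with np.errstate(invalid='ignore', divide='ignore'):
        return np.where(bin_counts > 0, bin_sums / bin_counts, np.nan)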
    def score(self,
              corpus: Corpus,
              temp: float = 1.0,
              split_size: int = 2000,
              ratio: float = 0,
              num_workers: int = 10,
              per_token: bool = False) -> List[float]:

        ctx_cpu = mx.Context('cpu')

        # Turn corpus into a BERT-ready Dataset
        dataset = self.corpus_to_dataset(corpus)

        # Turn Dataset into Dataloader
        batchify_fn = btf.Tuple(
            btf.Stack(dtype='int32'),
            btf.Pad(
                pad_val=self._vocab.token_to_idx[self._vocab.padding_token],
                dtype='int32'), btf.Stack(dtype='float32'),
            btf.Stack(dtype='float32'), btf.Stack(dtype='int32'),
            btf.Stack(dtype='float32'))

        # TODO: There is a 'by-design' bug in FixedBucketSampler with num_shards > 0, where it silently reuses the last utterances:
        # https://github.com/dmlc/gluon-nlp/blame/b1b61d3f90cf795c7b48b6d109db7b7b96fa21ff/src/gluonnlp/data/sampler.py#L398
        # batch_sampler = nlp.data.sampler.FixedBucketSampler([sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=len(self._ctxs), shuffle=False)
        # Hence, we use num_shards = 0 and do gluon's split_data
        batch_sampler = nlp.data.sampler.FixedBucketSampler(
            [sent_tuple[2] for sent_tuple in dataset],
            batch_size=split_size,
            ratio=ratio,
            num_shards=0,
            shuffle=False)

        logging.info(batch_sampler.stats())
        dataloader = nlp.data.ShardedDataLoader(dataset,
                                                pin_memory=True,
                                                batch_sampler=batch_sampler,
                                                batchify_fn=batchify_fn,
                                                num_workers=num_workers,
                                                thread_pool=True)

        # Get lengths in tokens (assumes dataset is in order)
        prev_sent_idx = None
        true_tok_lens = []
        for (curr_sent_idx, _, valid_length, _, _, _) in dataset:
            if curr_sent_idx != prev_sent_idx:
                prev_sent_idx = curr_sent_idx
                true_tok_lens.append(valid_length - 2)

        # Compute scores (total or per-position)
        if per_token:
            scores_per_token = [[None] * (true_tok_len + 2)
                                for true_tok_len in true_tok_lens]
        else:
            scores = np.zeros((len(corpus), ))

        sent_count = 0
        batch_log_interval = 20

        batch_score_accumulation = 1
        batch_sent_idxs_per_ctx = [[] for ctx in self._ctxs]
        batch_scores_per_ctx = [[] for ctx in self._ctxs]
        batch_masked_positions_per_ctx = [[] for ctx in self._ctxs]

        def sum_accumulated_scores():
            for ctx_idx in range(len(self._ctxs)):
                for batch_sent_idxs, batch_scores, batch_masked_positions in zip(
                        batch_sent_idxs_per_ctx[ctx_idx],
                        batch_scores_per_ctx[ctx_idx],
                        batch_masked_positions_per_ctx[ctx_idx]):
                    if per_token:
                        # Slow; only use when necessary
                        for batch_sent_idx, batch_score, batch_masked_position in zip(
                                batch_sent_idxs, batch_scores,
                                batch_masked_positions):
                            scores_per_token[batch_sent_idx.asscalar()][int(
                                batch_masked_position.asscalar(
                                ))] = batch_score.asscalar().item()
                    else:
                        np.add.at(scores, batch_sent_idxs.asnumpy(),
                                  batch_scores.asnumpy())
                batch_sent_idxs_per_ctx[ctx_idx] = []
                batch_scores_per_ctx[ctx_idx] = []
                batch_masked_positions_per_ctx[ctx_idx] = []

        # For now just predicts the first non-cls token
        for batch_id, batch in enumerate(dataloader):

            batch_size = 0

            # TODO: Write tests about batching over multiple GPUs and getting the same scores
            # TODO: SEE COMMENT ABOVE REGARDING FIXEDBUCKETSAMPLER
            batch = zip(*[
                mx.gluon.utils.split_data(batch_compo,
                                          len(self._ctxs),
                                          batch_axis=0,
                                          even_split=False)
                for batch_compo in batch
            ])

            for ctx_idx, (sent_idxs, token_ids, valid_length, masked_positions,
                          token_masked_ids, normalization) in enumerate(batch):

                ctx = self._ctxs[ctx_idx]
                batch_size += sent_idxs.shape[0]
                token_ids = token_ids.as_in_context(ctx)
                valid_length = valid_length.as_in_context(ctx)
                masked_positions = masked_positions.as_in_context(ctx).reshape(
                    -1, 1)

                if isinstance(self._model, RoBERTaModel):
                    out = self._model(token_ids, valid_length,
                                      masked_positions)
                else:
                    segment_ids = mx.nd.zeros(shape=token_ids.shape, ctx=ctx)
                    out = self._model(token_ids, segment_ids, valid_length,
                                      masked_positions)

                # Get the probability computed for the correct token
                split_size = token_ids.shape[0]
                # out[0] contains the representations
                # out[1] contains the distribution for the masked token

                # TODO: Manual numerically-stable softmax
                # https://stackoverflow.com/questions/42599498/numercially-stable-softmax
                # Because we only need one scalar
                out = out[1].log_softmax(temperature=temp)

                # Save the scores at the masked indices
                batch_sent_idxs_per_ctx[ctx_idx].append(sent_idxs)
                out = out[list(range(split_size)), [0] * split_size,
                          token_masked_ids.as_in_context(ctx).reshape(-1)]
                batch_scores_per_ctx[ctx_idx].append(out)
                batch_masked_positions_per_ctx[ctx_idx].append(
                    masked_positions)

            # Ideally we'd accumulate the scores when possible, but something like the below won't work
            # > scores[sent_idxs] += out
            # See In[21] in https://jakevdp.github.io/PythonDataScienceHandbook/02.07-fancy-indexing.html.
            # Hence, aggregation is done synchronously, every so often
            # (though batch_score_accumulation = 1 seems best, since bucketing is effective in reducing GPU disparity)
            if len(batch_sent_idxs_per_ctx[0]) == batch_score_accumulation:
                sum_accumulated_scores()

            # Progress
            sent_count += batch_size
            if (batch_id + 1) % batch_log_interval == 0:
                logging.info("{} sents of {}, batch {} of {}".format(
                    sent_count, len(dataset), batch_id + 1,
                    len(batch_sampler)))

        # TODO: Test score accumulation
        # In case there are leftovers
        sum_accumulated_scores()

        if per_token:
            return scores_per_token, true_tok_lens
        else:
            return scores.tolist(), true_tok_lens
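The comment above about fancy indexing is the crux of the accumulation logic: a plain `scores[sent_idxs] += out` silently drops repeated sentence indices, so the code falls back to `np.add.at`. A minimal NumPy-only sketch of the difference, on toy values that are independent of the scorer above:

import numpy as np

scores = np.zeros(3)
sent_idxs = np.array([0, 0, 2])            # sentence 0 appears twice (two masked positions)
batch_scores = np.array([-1.0, -2.0, -0.5])

scores[sent_idxs] += batch_scores          # buffered: only the last write to index 0 survives
print(scores)                              # [-2.   0.  -0.5]

scores = np.zeros(3)
np.add.at(scores, sent_idxs, batch_scores) # unbuffered: every occurrence is accumulated
print(scores)                              # [-3.   0.  -0.5]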
Example #17
def get_pretrain_data_text(data,
                           batch_size,
                           shuffle,
                           num_buckets,
                           tokenizer,
                           vocab,
                           max_seq_length,
                           short_seq_prob=0.05,
                           num_parts=1,
                           part_idx=0,
                           num_dataset_workers=1,
                           num_batch_workers=1,
                           circle_length=1,
                           repeat=1,
                           cached_file_path=None):
    """Get a data iterator from raw text documents.

    Parameters
    ----------
    data : str
        The glob pattern matching the raw text files.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    tokenizer : HuggingFaceWordPieceTokenizer or SentencepieceTokenizer
        The tokenizer.
    max_seq_length : int
        The hard limit of maximum sequence length of sentence pairs.
    short_seq_prob : float
        The probability of sampling sequences shorter than the max_seq_length.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    cached_file_path : str, default is None
        The directory for saving preprocessed features.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    split_sampler = SplitSampler(num_files,
                                 num_parts=num_parts,
                                 part_index=part_idx,
                                 repeat=repeat)
    dataset_fn = prepare_pretrain_text_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    dataset_params = {
        'tokenizer': tokenizer,
        'max_seq_length': max_seq_length,
        'short_seq_prob': short_seq_prob,
        'cached_file_path': cached_file_path
    }
    sampler_params = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_buckets': num_buckets
    }
    batchify_fn = bf.Tuple(
        bf.Pad(val=vocab.pad_id),  # input_ids
        bf.Pad(val=0),  # segment_ids
        bf.Stack(),  # valid_lengths
    )

    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length)
    return dataloader
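A hedged usage sketch of the loader defined above; the glob pattern, batch size, and the tokenizer/vocab objects are illustrative placeholders, not values taken from the surrounding script:

# Illustrative only: driving get_pretrain_data_text with placeholder inputs.
dataloader = get_pretrain_data_text('data/*.txt',          # hypothetical glob pattern
                                    batch_size=8,
                                    shuffle=True,
                                    num_buckets=1,
                                    tokenizer=tokenizer,    # assumed to exist
                                    vocab=vocab,            # assumed to exist
                                    max_seq_length=128)
for input_ids, segment_ids, valid_lengths in dataloader:
    # per the batchify_fn above: padded ids, padded segment ids, valid lengths
    break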
Example #18
def get_pretrain_data_npz(data,
                          batch_size,
                          shuffle,
                          num_buckets,
                          vocab,
                          num_parts=1,
                          part_idx=0,
                          num_dataset_workers=1,
                          num_batch_workers=1,
                          circle_length=1,
                          repeat=1,
                          dataset_cached=False,
                          num_max_dataset_cached=0):
    """Get a data iterator from pre-processed npz files.

    Parameters
    ----------
    data : str
        The path to the dataset directory.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    dataset_cached : bool, default is False
        Whether or not to cache last processed dataset.
        Each processed dataset can only be cached for once.
        When there is no new available processed dataset to be fetched,
        we pop a cached processed dataset.
    num_max_dataset_cached : int, default is 0
        Maximum number of cached datasets. It is only used when dataset_cached is True.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    split_sampler = SplitSampler(num_files,
                                 num_parts=num_parts,
                                 part_index=part_idx,
                                 repeat=repeat)
    dataset_fn = prepare_pretrain_npz_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    dataset_params = {'allow_pickle': True}
    sampler_params = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_buckets': num_buckets
    }
    batchify_fn = bf.Tuple(
        bf.Pad(val=vocab.pad_id),  # input_ids
        bf.Pad(val=0),  # segment_ids
        bf.Stack(),  # valid_lengths
    )
    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length)
    return dataloader
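To make the collation above concrete, here is a minimal sketch of what the `bf.Tuple(Pad, Pad, Stack)` batchify_fn produces on toy samples; it assumes the same `bf` alias for the batchify module that this script already imports:

# Toy sketch of the batchify_fn above; `bf` is the batchify alias used in this script.
samples = [
    ([2, 7, 9, 3], [0, 0, 0, 0], 4),   # (input_ids, segment_ids, valid_length)
    ([2, 5, 3],    [0, 0, 0],    3),
]
toy_batchify_fn = bf.Tuple(bf.Pad(val=0), bf.Pad(val=0), bf.Stack())
input_ids, segment_ids, valid_lengths = toy_batchify_fn(samples)
# input_ids has shape (2, 4); the second row is right-padded with 0
# valid_lengths is [4, 3]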
Example #19
def get_pretrain_data_text(data, batch_size, shuffle, num_buckets, vocab, tokenizer,
                           max_seq_length, short_seq_prob, masked_lm_prob,
                           max_predictions_per_seq, whole_word_mask, random_next_sentence,
                           num_parts=1, part_idx=0, num_dataset_workers=1, num_batch_workers=1,
                           circle_length=1, repeat=1,
                           dataset_cached=False, num_max_dataset_cached=0):
    """Get a data iterator from raw text documents.

    Parameters
    ----------
    data : str
        The glob pattern matching the raw text files.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    tokenizer : BaseTokenizer
        The tokenizer.
    max_seq_length : int
        The hard limit of maximum sequence length of sentence pairs.
    short_seq_prob : float
        The probability of sampling sequences shorter than the max_seq_length.
    masked_lm_prob : float
        The probability of replacing texts with masks/random words/original words.
    max_predictions_per_seq : int
        The hard limit of the number of predictions for masked words.
    whole_word_mask : bool
        Whether to use whole word masking.
    random_next_sentence : bool
        Whether to randomly sample the next sentence when constructing sentence pairs.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    dataset_cached : bool, default is False
        Whether or not to cache last processed dataset.
        Each processed dataset can only be cached for once.
        When there is no new available processed dataset to be fetched,
        we pop a cached processed dataset.
    num_max_dataset_cached : int, default is 0
        Maximum number of cached datasets. It is only used if dataset_cached is True.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    dataset_params = {'tokenizer': tokenizer, 'max_seq_length': max_seq_length,
                      'short_seq_prob': short_seq_prob, 'masked_lm_prob': masked_lm_prob,
                      'max_predictions_per_seq': max_predictions_per_seq, 'vocab':vocab,
                      'whole_word_mask': whole_word_mask, 'random_next_sentence': random_next_sentence}
    sampler_params = {'batch_size': batch_size, 'shuffle': shuffle, 'num_buckets': num_buckets}
    dataset_fn = prepare_pretrain_text_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    pad_val = vocab.pad_id
    batchify_fn = bf.Tuple(
        bf.Pad(val=pad_val, round_to=8),  # input_id
        bf.Pad(val=pad_val),  # masked_id
        bf.Pad(val=0),  # masked_position
        bf.Pad(val=0),  # masked_weight
        bf.Stack(),  # next_sentence_label
        bf.Pad(val=0, round_to=8),  # segment_id
        bf.Stack())  # valid_lengths
    split_sampler = SplitSampler(num_files, num_parts=num_parts,
                                 part_index=part_idx, repeat=repeat)
    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length,
                               dataset_cached=dataset_cached,
                               num_max_dataset_cached=num_max_dataset_cached)
    return dataloader
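The `round_to=8` arguments above appear to round the padded length up to a multiple of 8 (a common trick for tensor-core-friendly shapes); a small sketch under that assumption, reusing the same `bf` alias:

# Sketch of Pad(round_to=8): the padded axis is extended up to a multiple of 8.
toy_pad = bf.Pad(val=0, round_to=8)
batch = toy_pad([[1, 2, 3], [1, 2, 3, 4, 5]])
# the longest sample has length 5, so the batch is padded out to length 8
# batch.shape -> (2, 8)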
Example #20
def test_pad_wrap_batchify():
    def _verify_padded_arr(padded_arr, original_arr, pad_axis, pad_val,
                           pad_length, dtype):
        ndim = original_arr.ndim
        slices_data = [slice(None) for _ in range(ndim)]
        slices_data[pad_axis] = slice(original_arr.shape[pad_axis])
        assert_allclose(padded_arr[tuple(slices_data)], original_arr)
        if original_arr.shape[pad_axis] < pad_length:
            slices_pad_val = [slice(None) for _ in range(ndim)]
            slices_pad_val[pad_axis] = slice(original_arr.shape[pad_axis], None)
            pad_val_in_arr = padded_arr[tuple(slices_pad_val)]
            assert_allclose(pad_val_in_arr, (np.ones_like(pad_val_in_arr) *
                                             pad_val).astype(dtype))

    batch_size = 6
    for ndim in range(1, 3):
        for axis in range(-ndim, ndim):
            for length_min, length_max in [(3, 4), (3, 7)]:
                for pad_val in [-1, 0]:
                    for dtype in [
                            np.uint8, np.int32, np.int64, np.float16,
                            np.float32, np.float64
                    ]:
                        # Each instance contains a single array
                        for _dtype in [None, dtype]:
                            shapes = [[2 for _ in range(ndim)]
                                      for _ in range(batch_size)]
                            for i in range(len(shapes)):
                                shapes[i][axis] = np.random.randint(
                                    length_min, length_max)
                            random_data_npy = [
                                np.random.normal(0, 1, shape).astype(dtype)
                                for shape in shapes
                            ]
                            batchify_fn = batchify.Pad(axis=axis,
                                                       pad_val=pad_val,
                                                       ret_length=True,
                                                       dtype=_dtype)
                            batch_data, valid_length = batchify_fn(
                                random_data_npy)
                            batch_data_use_mx, valid_length_use_mx = batchify_fn(
                                [
                                    mx.nd.array(ele, dtype=dtype)
                                    for ele in random_data_npy
                                ])
                            assert_allclose(batch_data_use_mx.asnumpy(),
                                            batch_data.asnumpy())
                            assert_allclose(valid_length_use_mx.asnumpy(),
                                            valid_length.asnumpy())
                            assert batch_data.dtype == batch_data_use_mx.dtype == dtype
                            assert valid_length.dtype == valid_length_use_mx.dtype == np.int32
                            valid_length = valid_length.asnumpy()
                            batch_data = batch_data.asnumpy()
                            for i in range(batch_size):
                                assert (valid_length[i] == shapes[i][axis])
                                pad_length = max(shape[axis]
                                                 for shape in shapes)
                                _verify_padded_arr(batch_data[i],
                                                   random_data_npy[i], axis,
                                                   pad_val, pad_length, dtype)
                            # Each instance contains 3 arrays, we pad part of them according to index
                            TOTAL_ELE_NUM = 3
                            for pad_index in [[0], [1], [2], [0, 1], [1, 2],
                                              [0, 1, 2]]:
                                shapes = [[[2 for _ in range(ndim)]
                                           for _ in range(batch_size)]
                                          for _ in range(TOTAL_ELE_NUM)]
                                for j in pad_index:
                                    for i in range(batch_size):
                                        shapes[j][i][axis] = np.random.randint(
                                            length_min, length_max)
                                random_data_npy = [
                                    tuple(
                                        np.random.normal(0, 1, shapes[j]
                                                         [i]).astype(dtype)
                                        for j in range(TOTAL_ELE_NUM))
                                    for i in range(batch_size)
                                ]
                                batchify_fn = []
                                for j in range(TOTAL_ELE_NUM):
                                    if j in pad_index:
                                        batchify_fn.append(
                                            batchify.Pad(axis=axis,
                                                         pad_val=pad_val,
                                                         ret_length=True,
                                                         dtype=_dtype))
                                    else:
                                        batchify_fn.append(
                                            batchify.Stack(dtype=_dtype))
                                batchify_fn = batchify.Tuple(batchify_fn)
                                ret_use_npy = batchify_fn(random_data_npy)
                                ret_use_mx = batchify_fn([
                                    tuple(
                                        mx.nd.array(ele[i], dtype=dtype)
                                        for i in range(TOTAL_ELE_NUM))
                                    for ele in random_data_npy
                                ])
                                for i in range(TOTAL_ELE_NUM):
                                    if i in pad_index:
                                        assert ret_use_npy[i][
                                            0].dtype == ret_use_mx[i][
                                                0].dtype == dtype
                                        assert ret_use_npy[i][
                                            1].dtype == ret_use_mx[i][
                                                1].dtype == np.int32
                                        assert_allclose(
                                            ret_use_npy[i][0].asnumpy(),
                                            ret_use_mx[i][0].asnumpy())
                                        assert_allclose(
                                            ret_use_npy[i][1].asnumpy(),
                                            ret_use_mx[i][1].asnumpy())
                                        assert (ret_use_npy[i][1].shape == (
                                            batch_size, ))
                                    else:
                                        assert ret_use_npy[
                                            i].dtype == ret_use_mx[
                                                i].dtype == dtype
                                        assert_allclose(
                                            ret_use_npy[i].asnumpy(),
                                            ret_use_mx[i].asnumpy())
                                for i in range(batch_size):
                                    for j in range(TOTAL_ELE_NUM):
                                        if j in pad_index:
                                            batch_data, valid_length = ret_use_npy[j][0].asnumpy(), \
                                                                       ret_use_npy[j][1].asnumpy()
                                            assert (valid_length[i] ==
                                                    shapes[j][i][axis])
                                        else:
                                            batch_data = ret_use_npy[
                                                j].asnumpy()
                                        pad_length = max(
                                            ele[j].shape[axis]
                                            for ele in random_data_npy)
                                        _verify_padded_arr(
                                            batch_data[i],
                                            random_data_npy[i][j], axis,
                                            pad_val, pad_length, dtype)
                        for _dtype in [np.float16, np.float32]:
                            shapes = [[2 for _ in range(ndim)]
                                      for _ in range(batch_size)]
                            for i in range(len(shapes)):
                                shapes[i][axis] = np.random.randint(
                                    length_min, length_max)
                            random_data_npy = [
                                np.random.normal(0, 1, shape).astype(dtype)
                                for shape in shapes
                            ]
                            batchify_fn = batchify.Pad(axis=axis,
                                                       pad_val=pad_val,
                                                       ret_length=True,
                                                       dtype=_dtype)
                            batch_data, valid_length = batchify_fn(
                                random_data_npy)
                            batch_data_use_mx, valid_length_use_mx = batchify_fn(
                                [
                                    mx.nd.array(ele, dtype=dtype)
                                    for ele in random_data_npy
                                ])
                            assert_allclose(valid_length_use_mx.asnumpy(),
                                            valid_length.asnumpy())
                            assert batch_data.dtype == batch_data_use_mx.dtype == _dtype
                            assert valid_length.dtype == valid_length_use_mx.dtype == np.int32
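Condensed into a few lines, the core behaviour this test exercises (reusing the `batchify` and `np` imports the test already relies on):

# Pad with ret_length=True returns both the padded batch and int32 valid lengths.
toy_fn = batchify.Pad(axis=0, pad_val=-1, ret_length=True)
batch, valid_length = toy_fn([np.arange(3), np.arange(5)])
# batch.shape  -> (2, 5); the first row ends with two -1 pad values
# valid_length -> [3, 5], dtype int32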
Example #21
def test_stack_batchify():
    batchify_fn = batchify.Stack()
    dat = [np.random.randint(5) for _ in range(10)]
    assert_allclose(batchify_fn(dat).asnumpy(), np.array(dat))
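The test above only covers Python scalars; Stack batches any equal-shape inputs along a new leading axis, as this small sketch (same `batchify`/`np` imports) shows:

# Stack also batches equal-shape arrays along a new leading (batch) axis.
arrs = [np.full((2, 3), i) for i in range(4)]
stacked = batchify.Stack()(arrs)
# stacked.shape -> (4, 2, 3); stacked[i] equals arrs[i]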
Example #22
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr})

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'),
                                  btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'),
                                 btf.Stack(dtype='float32'), btf.Stack())
    if args.bucket_scheme == 'constant':
        bucket_scheme = ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n{}'.format(
        train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=data_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=data_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
                in enumerate(train_data_loader):
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                               tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:],
                                     tgt_valid_length - 1).mean()
                loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length -
                                                        1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            step_loss = loss.asscalar()
            log_avg_loss += step_loss
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info(
                    '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                    'throughput={:.2f}K wps, wc={:.2f}K'.format(
                        epoch_id, batch_id + 1, len(train_data_loader),
                        log_avg_loss / args.log_interval,
                        np.exp(log_avg_loss / args.log_interval),
                        log_avg_gnorm / args.log_interval, wps / 1000,
                        log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_avg_gnorm = 0
                log_wc = 0
        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader)
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                   test_translation_out)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
            format(epoch_id, test_loss, np.exp(test_loss),
                   test_bleu_score * 100))
        write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(
            test_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        if epoch_id + 1 >= (args.epochs * 2) // 3:
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)
    model.load_params(os.path.join(args.save_dir, 'valid_best.params'))
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
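The schedule near the bottom of the epoch loop above (multiply the learning rate by lr_update_factor once per epoch over the final third of training) is easier to see in isolation; the constants below are illustrative assumptions, not values of args:

# Sketch of the step decay used above; base_lr/epochs/factor are illustrative.
def decayed_lrs(base_lr=1e-3, epochs=12, lr_update_factor=0.5):
    lr, per_epoch_lr = base_lr, []
    for epoch_id in range(epochs):
        per_epoch_lr.append(lr)               # lr in effect for this epoch
        if epoch_id + 1 >= (epochs * 2) // 3:
            lr *= lr_update_factor            # decay applies from the next epoch on
    return per_epoch_lr

# decayed_lrs() -> 1e-3 for epochs 0-7, then 5e-4, 2.5e-4, 1.25e-4, 6.25e-5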
Example #23
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {
        'learning_rate': args.lr,
        'beta2': 0.98,
        'epsilon': 1e-9
    })

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'),
                                  btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'),
                                 btf.Stack(dtype='float32'), btf.Stack())
    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
    if args.bucket_scheme == 'constant':
        bucket_scheme = ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             use_average_length=True,
                                             num_shards=len(ctx),
                                             bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n{}'.format(
        train_batch_sampler.stats()))
    train_data_loader = ShardedDataLoader(data_train,
                                          batch_sampler=train_batch_sampler,
                                          batchify_fn=train_batchify_fn,
                                          num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False,
                                           use_average_length=True,
                                           bucket_scheme=bucket_scheme)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False,
                                            use_average_length=True,
                                            bucket_scheme=bucket_scheme)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)

    if args.bleu == 'tweaked':
        bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError

    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    model.collect_params().setattr('grad_req', 'add')
    average_start = (len(train_data_loader) //
                     grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, seqs \
                in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            src_wc, tgt_wc, bs = np.sum(
                [(shard[2].sum(), shard[3].sum(), shard[0].shape[0])
                 for shard in seqs],
                axis=0)
            src_wc = src_wc.asscalar()
            tgt_wc = tgt_wc.asscalar()
            loss_denom += tgt_wc - bs
            seqs = [[seq.as_in_context(context) for seq in shard]
                    for context, shard in zip(ctx, seqs)]
            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length in seqs:
                    out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                                   tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label,
                                       tgt_valid_length - 1).sum()
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size /
                              100.0)
            for L in Ls:
                L.backward()
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                if average_param_dict is None:
                    average_param_dict = {
                        k: v.data(ctx[0]).copy()
                        for k, v in model.collect_params().items()
                    }
                trainer.step(float(loss_denom) / args.batch_size / 100.0)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (
                            param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch_id, batch_id + 1,
                                 len(train_data_loader),
                                 log_avg_loss / args.log_interval,
                                 np.exp(log_avg_loss / args.log_interval),
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu(
            [val_tgt_sentences],
            valid_translation_out,
            tokenized=tokenized,
            tokenizer=args.bleu,
            split_compound_word=split_compound_word,
            bpe=bpe)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu(
            [test_tgt_sentences],
            test_translation_out,
            tokenized=tokenized,
            tokenizer=args.bleu,
            split_compound_word=split_compound_word,
            bpe=bpe)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
            format(epoch_id, test_loss, np.exp(test_loss),
                   test_bleu_score * 100))
        write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(
            test_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        save_path = os.path.join(args.save_dir,
                                 'epoch{:d}.params'.format(epoch_id))
        model.save_parameters(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        for j in range(args.num_averages):
            params = mx.nd.load(
                os.path.join(args.save_dir,
                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) -
                                             v.data(c))
        save_path = os.path.join(
            args.save_dir,
            'average_checkpoint_{}.params'.format(args.num_averages))
        model.save_parameters(save_path)
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
        save_path = os.path.join(args.save_dir, 'average.params')
        model.save_parameters(save_path)
    else:
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'),
                              ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu(
        [val_tgt_sentences],
        valid_translation_out,
        tokenized=tokenized,
        tokenizer=args.bleu,
        bpe=bpe,
        split_compound_word=split_compound_word)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu(
        [test_tgt_sentences],
        test_translation_out,
        tokenized=tokenized,
        tokenizer=args.bleu,
        bpe=bpe,
        split_compound_word=split_compound_word)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
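The inline learning-rate update near the top of the training loop above is the standard Transformer warmup-then-inverse-square-root schedule. A standalone sketch of that formula; the num_units and warmup_steps values are assumptions for illustration:

import math

def transformer_lr(step, base_lr=2.0, num_units=512, warmup_steps=4000):
    # lr(step) = base_lr / sqrt(num_units) * min(1/sqrt(step), step * warmup_steps**-1.5)
    return base_lr / math.sqrt(num_units) * min(1.0 / math.sqrt(step),
                                                step * warmup_steps ** -1.5)

# transformer_lr(1)     -> tiny lr at the start of warmup
# transformer_lr(4000)  -> the peak: the warmup and decay branches meet here
# transformer_lr(16000) -> half the peak (inverse-square-root decay)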
Example #24
def train(args):
    store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus, args.train_tgt_corpus, src_tokenizer,
        tgt_tokenizer, args.overwrite_cache)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer,
        args.overwrite_cache)
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    data_val = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ])
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    if args.fp16:
        raise NotImplementedError
        # cfg.MODEL.dtype = 'float16'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    if local_rank == 0:
        logging.info(model)
    with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
        cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()
    rescale_loss = 100.0

    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    # Construct the trainer
    # TODO(sxjscience) Support AMP
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    trainer_settings = (model.collect_params(), 'adam', {
        'learning_rate': args.lr,
        'beta1': 0.9,
        'beta2': 0.98,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler
    })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(*trainer_settings)
    else:
        trainer = gluon.Trainer(*trainer_settings)
    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            seed=args.seed,
            num_parts=num_parts,
            part_index=rank)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')

        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(lengths=[
            (ele[2], ele[3]) for ele in data_train
        ],
                                                 batch_size=args.batch_size,
                                                 num_buckets=args.num_buckets,
                                                 ratio=args.bucket_ratio,
                                                 shuffle=True,
                                                 use_average_length=True,
                                                 bucket_scheme=bucket_scheme,
                                                 seed=args.seed)
    else:
        raise NotImplementedError

    if local_rank == 0:
        logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)

    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    model.zero_grad()
    model_averager = AverageSGDTracker(model.collect_params())
    log_start_time = time.time()
    num_params, num_fixed_params = None, None
    # TODO(sxjscience) Add a log metric class
    accum_count = 0
    loss_denom = 0
    n_train_iters = 0
    log_wc = 0
    log_avg_loss = 0.0
    log_loss_denom = 0
    epoch_id = 0
    # When args.epochs < 0, the model keeps training indefinitely.
    while args.epochs < 0 or epoch_id < args.epochs:
        n_epoch_train_iters = 0
        processed_batch_num = 0
        train_multi_data_loader = grouper(train_data_loader, len(ctx_l))
        is_last_batch = False
        sample_data_l = next(train_multi_data_loader)
        while not is_last_batch:
            processed_batch_num += len(sample_data_l)
            loss_l = []
            for sample_data, ctx in zip(sample_data_l, ctx_l):
                if sample_data is None:
                    continue
                src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data
                src_wc, tgt_wc, bs = src_valid_length.sum(
                ), tgt_valid_length.sum(), src_token_ids.shape[0]
                loss_denom += tgt_wc - bs
                log_loss_denom += tgt_wc - bs
                log_wc += src_wc + tgt_wc
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                with mx.autograd.record():
                    tgt_pred = model(src_token_ids, src_valid_length,
                                     tgt_token_ids[:, :-1],
                                     tgt_valid_length - 1)
                    tgt_labels = tgt_token_ids[:, 1:]
                    loss = label_smooth_loss(tgt_pred, tgt_labels)
                    loss = mx.npx.sequence_mask(
                        loss,
                        sequence_length=tgt_valid_length - 1,
                        use_sequence_length=True,
                        axis=1)
                    loss_l.append(loss.sum() / rescale_loss)
            for l in loss_l:
                l.backward()
            accum_count += 1
            try:
                sample_data_l = next(train_multi_data_loader)
            except StopIteration:
                is_last_batch = True
            if local_rank == 0 and num_params is None:
                num_params, num_fixed_params = count_parameters(
                    model.collect_params())
                logging.info(
                    'Total Number of Parameters (not-fixed/fixed): {}/{}'.
                    format(num_params, num_fixed_params))
            sum_loss = sum([l.as_in_ctx(mx.cpu())
                            for l in loss_l]) * rescale_loss
            log_avg_loss += sum_loss
            mx.npx.waitall()
            if accum_count == args.num_accumulated or is_last_batch:
                # Update the parameters
                n_train_iters += 1
                n_epoch_train_iters += 1
                trainer.step(loss_denom.asnumpy() / rescale_loss)
                accum_count = 0
                loss_denom = 0
                model.zero_grad()
                if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \
                   (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update):
                    model_averager.step()
                if local_rank == 0 and \
                   (n_epoch_train_iters % args.log_interval == 0 or is_last_batch):
                    log_end_time = time.time()
                    log_wc = log_wc.asnumpy()
                    wps = log_wc / (log_end_time - log_start_time)
                    log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
                    logging.info(
                        '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                        'throughput={:.2f}K wps, wc={:.2f}K, LR={}'.format(
                            epoch_id, processed_batch_num * num_parts,
                            len(train_data_loader), log_avg_loss,
                            np.exp(log_avg_loss), wps / 1000, log_wc / 1000,
                            trainer.learning_rate))
                    log_start_time = time.time()
                    log_avg_loss = 0
                    log_loss_denom = 0
                    log_wc = 0
                if local_rank == 0 and \
                   (args.max_update > 0 and n_train_iters % args.save_interval_update == 0):
                    model.save_parameters(os.path.join(
                        args.save_dir, 'update{:d}.params'.format(
                            n_train_iters // args.save_interval_update)),
                                          deduplicate=True)
                if args.max_update > 0 and n_train_iters >= args.max_update:
                    break
        if local_rank == 0 and args.epochs > 0:
            model.save_parameters(os.path.join(
                args.save_dir, 'epoch{:d}.params'.format(epoch_id)),
                                  deduplicate=True)
        avg_valid_loss = validation(model, val_data_loader, ctx_l)
        logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'.format(
            epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))

        if args.max_update > 0 and n_train_iters >= args.max_update:
            break
        epoch_id += 1

    if args.num_averages > 0:
        model_averager.copy_back(
            model.collect_params())  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
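The `grouper` helper used above to spread batches over `ctx_l` is not shown in this example; a minimal sketch, assuming it follows the standard itertools recipe of yielding fixed-size groups padded with None, which matches the `if sample_data is None: continue` guard in the device loop:

from itertools import zip_longest

def grouper_sketch(iterable, n, fillvalue=None):
    # Collect items into fixed-length groups of n, padding the last group
    # with fillvalue: grouper_sketch('ABCDE', 2) -> ('A','B'), ('C','D'), ('E', None)
    iterators = [iter(iterable)] * n
    return zip_longest(*iterators, fillvalue=fillvalue)

# list(grouper_sketch(range(5), 2)) -> [(0, 1), (2, 3), (4, None)]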