Example #1
 def data_loader(self,
                 docs: Sequence[Document],
                 batch_size,
                 shuffle=False,
                 label=True,
                 **kwargs) -> DataLoader:
     if label is True and self.label_map is None:
         raise ValueError('Please specify label_map')
     batchify_fn = kwargs.get('batchify_fn', sequence_batchify_fn)
     bucket = kwargs.get('bucket', False)
     num_buckets = kwargs.get('num_buckets', 10)
     ratio = kwargs.get('ratio', 0)
     dataset = SequencesDataset(docs=docs,
                                embs=self.embs,
                                key=self.key,
                                label_map=self.label_map,
                                label=label)
     if bucket is True:
         dataset_lengths = list(map(lambda x: float(len(x[3])), dataset))
         batch_sampler = FixedBucketSampler(dataset_lengths,
                                            batch_size=batch_size,
                                            num_buckets=num_buckets,
                                            ratio=ratio,
                                            shuffle=shuffle)
         return DataLoader(dataset=dataset,
                           batch_sampler=batch_sampler,
                           batchify_fn=batchify_fn)
     else:
         return DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           shuffle=shuffle,
                           batchify_fn=batchify_fn)
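
The method above wires FixedBucketSampler into a DataLoader. As a point of reference, here is a minimal, self-contained sketch of the same pattern, assuming gluonnlp and mxnet are installed; the toy dataset, batch size, and bucket count are made up for illustration.

import random
import mxnet as mx
from mxnet import gluon
from gluonnlp.data import FixedBucketSampler

# Toy dataset: 200 variable-length sequences with a dummy label.
data = [(mx.nd.arange(n), n % 2) for n in (random.randint(5, 50) for _ in range(200))]
dataset = gluon.data.SimpleDataset(data)

# Bucket by sequence length so sequences in a batch have similar lengths.
sampler = FixedBucketSampler(lengths=[len(seq) for seq, _ in dataset],
                             batch_size=8,
                             num_buckets=5,
                             ratio=0,       # ratio > 0 scales up batches of short sequences
                             shuffle=True)
print(sampler.stats())                      # per-bucket summary

loader = gluon.data.DataLoader(dataset,
                               batch_sampler=sampler,
                               batchify_fn=lambda samples: samples)  # keep samples as-is
for batch in loader:
    print(len(batch), [len(seq) for seq, _ in batch])   # similar lengths within a batch
    break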
Example #2
def test_sharded_data_loader_record_file():
    # test record file
    url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/{}'
    filename = 'val.rec'
    idx_filename = 'val.idx'
    download(url_format.format(filename),
             path=os.path.join('tests', 'data', filename))
    download(url_format.format(idx_filename),
             path=os.path.join('tests', 'data', idx_filename))
    rec_dataset = gluon.data.vision.ImageRecordDataset(
        os.path.join('tests', 'data', filename))

    num_workers = 2
    num_shards = 4
    X = np.random.uniform(size=(100, 20))
    Y = np.random.uniform(size=(100, ))
    batch_sampler = FixedBucketSampler(lengths=[X.shape[1]] * X.shape[0],
                                       batch_size=2,
                                       num_buckets=1,
                                       shuffle=False,
                                       num_shards=num_shards)
    loader = ShardedDataLoader(rec_dataset,
                               batch_sampler=batch_sampler,
                               num_workers=num_workers)
    for i, seqs in enumerate(loader):
        assert len(seqs) == num_shards
Example #3
def test_dataloader():
    batch_size = 128
    dataset = NLPCCDataset('train', data_root_dir)
    transformer = DataTransformer(['train'])

    word_vocab = transformer._word_vocab
    transformed_dataset = dataset.transform(transformer, lazy=False)

    batchify_fn = Tuple(
        Stack(),
        Pad(axis=0,
            pad_val=word_vocab[word_vocab.padding_token],
            ret_length=True), Stack())
    sampler = FixedBucketSampler(
        lengths=[len(item[1]) for item in transformed_dataset],
        batch_size=batch_size,
        shuffle=True,
        num_buckets=30)

    data_loader = DataLoader(transformed_dataset,
                             batchify_fn=batchify_fn,
                             batch_sampler=sampler)

    for i, (rec_id, (data, original_length), label) in enumerate(data_loader):
        print(data.shape)
        assert data.shape[0] <= 128
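
The batchify function above combines three field-wise batchifiers. Below is a small sketch of how they compose, assuming gluonnlp's batchify module; the toy samples only mimic the (record id, token ids, label) shape of the transformed dataset.

import gluonnlp as nlp

# Toy samples shaped like the transformed dataset above: (record id, token ids, label).
samples = [(0, [1, 2, 3], 1),
           (1, [4, 5], 0)]

batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Stack(),                                  # stack record ids
    nlp.data.batchify.Pad(axis=0, pad_val=0, ret_length=True),  # pad token ids, keep lengths
    nlp.data.batchify.Stack())                                  # stack labels

rec_id, (data, original_length), label = batchify_fn(samples)
print(data.shape, original_length.asnumpy())   # (2, 3) [3 2]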
Example #4
def transform(raw_data, params):
    # Define the data transformation interface
    # raw_data --> batch_data

    for data in raw_data:
        pass

    num_buckets = params.num_buckets
    batch_size = params.batch_size

    responses = raw_data

    batch_idxes = FixedBucketSampler([len(rs) for rs in responses],
                                     batch_size,
                                     num_buckets=num_buckets)
    batch = []

    def index(r):
        correct = 0 if r[1] <= 0 else 1
        return r[0] * 2 + correct

    for batch_idx in tqdm(batch_idxes, "batchify"):
        batch_rs = []
        batch_pick_index = []
        batch_labels = []
        for idx in batch_idx:
            batch_rs.append([index(r) for r in responses[idx]])
            if len(responses[idx]) <= 1:
                pick_index, labels = [], []
            else:
                pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1)
                                           for r in responses[idx][1:]])
            batch_pick_index.append(list(pick_index))
            batch_labels.append(list(labels))

        max_len = max([len(rs) for rs in batch_rs])
        padder = PadSequence(max_len, pad_val=0)
        batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs])

        max_len = max([len(rs) for rs in batch_labels])
        padder = PadSequence(max_len, pad_val=0)
        batch_labels, label_mask = zip(*[(padder(labels), len(labels))
                                         for labels in batch_labels])
        batch_pick_index = [
            padder(pick_index) for pick_index in batch_pick_index
        ]
        batch.append([
            mx.nd.array(batch_rs),
            mx.nd.array(data_mask),
            mx.nd.array(batch_labels),
            mx.nd.array(batch_pick_index),
            mx.nd.array(label_mask)
        ])

    return batch
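
The transform above relies on PadSequence for per-bucket padding. A minimal sketch of its behaviour, assuming gluonnlp is available; the lengths are illustrative.

from gluonnlp.data import PadSequence

padder = PadSequence(length=5, pad_val=0)
print(padder([1, 2, 3]))              # [1, 2, 3, 0, 0]  padded up to length 5
print(padder([1, 2, 3, 4, 5, 6]))     # [1, 2, 3, 4, 5]  clipped (clip=True by default)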
Example #5
def transform(raw_data, params):
    # Define the data transformation interface
    # raw_data --> batch_data

    num_buckets = params.num_buckets
    batch_size = params.batch_size

    responses = raw_data

    # Samples of different lengths are always assigned to a fixed bucket
    batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets)
    batch = []

    def index(r):
        correct = 0 if r[1] <= 0 else 1
        return r[0] * 2 + correct

    for batch_idx in tqdm(batch_idxes, "batchify"):
        batch_rs = []
        batch_pick_index = []
        batch_labels = []
        for idx in batch_idx:
            # 1. Correct and incorrect answers are stored under different ids
            # 2. batch_pick_index[i] and batch_labels[i] are one element shorter than batch_rs[i]
            batch_rs.append([index(r) for r in responses[idx]])
            if len(responses[idx]) <= 1:
                pick_index, labels = [], []
            else:
                pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]])
            batch_pick_index.append(list(pick_index))
            batch_labels.append(list(labels))
        # max_len differs from bucket to bucket
        max_len = max([len(rs) for rs in batch_rs])
        padder = PadSequence(max_len, pad_val=0)
        # batch_rs: sequences padded to max_len
        # data_mask: the valid (unpadded) length of each sequence
        batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs])

        max_len = max([len(rs) for rs in batch_labels])
        padder = PadSequence(max_len, pad_val=0)
        batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels])
        batch_pick_index = [padder(pick_index) for pick_index in batch_pick_index]
        # Load
        # All sequences are padded to a fixed length
        # data_mask and label_mask record the valid lengths
        # len(batch_labels) + 1 = len(batch_rs)
        batch.append(
            [torch.tensor(batch_rs), torch.tensor(data_mask), torch.tensor(batch_labels),
             torch.tensor(batch_pick_index),
             torch.tensor(label_mask)])

    return batch
Example #6
 def get_dataloader(dataset):
     """create data loader based on the dataset chunk"""
     t0 = time.time()
     lengths = dataset.get_field('valid_lengths')
     logging.debug('Num samples = %d', len(lengths))
     # A batch includes: input_id, masked_id, masked_position, masked_weight,
     #                   next_sentence_label, segment_id, valid_length
     batchify_fn = Tuple(Pad(), Pad(), Pad(), Pad(), Stack(), Pad(),
                         Stack())
     if args.by_token:
         # sharded data loader
         sampler = nlp.data.FixedBucketSampler(
             lengths=lengths,
             # batch_size per shard
             batch_size=batch_size,
             num_buckets=args.num_buckets,
             shuffle=is_train,
             use_average_length=True,
             num_shards=num_ctxes)
         dataloader = nlp.data.ShardedDataLoader(dataset,
                                                 batch_sampler=sampler,
                                                 batchify_fn=batchify_fn,
                                                 num_workers=num_ctxes)
         logging.debug('Batch Sampler:\n%s', sampler.stats())
     else:
         sampler = FixedBucketSampler(lengths,
                                      batch_size=batch_size * num_ctxes,
                                      num_buckets=args.num_buckets,
                                      ratio=0,
                                      shuffle=is_train)
         dataloader = DataLoader(dataset=dataset,
                                 batch_sampler=sampler,
                                 batchify_fn=batchify_fn,
                                 num_workers=1)
         logging.debug('Batch Sampler:\n%s', sampler.stats())
     t1 = time.time()
     logging.debug('Dataloader creation cost = %.2f s', t1 - t0)
     return dataloader
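
When args.by_token is set, batch_size above acts as an approximate token budget per batch rather than a sample count (the use_average_length semantics). A sketch of that behaviour, assuming gluonnlp and illustrative lengths:

from gluonnlp.data import FixedBucketSampler

lengths = list(range(10, 110))                 # toy sequence lengths 10..109
sampler = FixedBucketSampler(lengths,
                             batch_size=500,   # ~500 tokens per batch
                             num_buckets=5,
                             use_average_length=True,
                             shuffle=False)
for batch in sampler:
    # Fewer samples per batch as the bucket's sequences get longer.
    print(len(batch), sum(lengths[i] for i in batch))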
Example #7
File: etl.py  Project: tswsxk/XKT
def transform(raw_data, params):
    # Define the data transformation interface
    # raw_data --> batch_data

    num_buckets = params.num_buckets
    batch_size = params.batch_size

    responses = raw_data

    batch_idxes = FixedBucketSampler([len(rs) for rs in responses],
                                     batch_size,
                                     num_buckets=num_buckets)
    batch = []

    def response_index(r):
        correct = 0 if r[1] <= 0 else 1
        return r[0] * 2 + correct

    def question_index(r):
        return r[0]

    for batch_idx in tqdm(batch_idxes, "batchify"):
        batch_qs = []
        batch_rs = []
        batch_labels = []
        for idx in batch_idx:
            batch_qs.append([question_index(r) for r in responses[idx]])
            batch_rs.append([response_index(r) for r in responses[idx]])
            labels = [0 if r[1] <= 0 else 1 for r in responses[idx][:]]
            batch_labels.append(list(labels))

        max_len = max([len(rs) for rs in batch_rs])
        padder = PadSequence(max_len, pad_val=0)
        batch_qs, _ = zip(*[(padder(qs), len(qs)) for qs in batch_qs])
        batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs])

        max_len = max([len(rs) for rs in batch_labels])
        padder = PadSequence(max_len, pad_val=0)
        batch_labels, label_mask = zip(*[(padder(labels), len(labels))
                                         for labels in batch_labels])
        batch.append([
            mx.nd.array(batch_qs, dtype="float32"),
            mx.nd.array(batch_rs, dtype="float32"),
            mx.nd.array(data_mask),
            mx.nd.array(batch_labels),
            mx.nd.array(label_mask)
        ])

    return batch
Example #8
def test_sharded_data_loader():
    X = np.random.uniform(size=(100, 20))
    Y = np.random.uniform(size=(100, ))
    dataset = gluon.data.ArrayDataset(X, Y)
    loader = ShardedDataLoader(dataset, 2)
    for i, (x, y) in enumerate(loader):
        assert mx.test_utils.almost_equal(x.asnumpy(), X[i * 2:(i + 1) * 2])
        assert mx.test_utils.almost_equal(y.asnumpy(), Y[i * 2:(i + 1) * 2])
    num_shards = 4
    batch_sampler = FixedBucketSampler(lengths=[X.shape[1]] * X.shape[0],
                                       batch_size=2,
                                       num_buckets=1,
                                       shuffle=False,
                                       num_shards=num_shards)
    for thread_pool in [True, False]:
        for num_workers in [0, 1, 2, 3, 4]:
            loader = ShardedDataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       thread_pool=thread_pool)
            for i, seqs in enumerate(loader):
                assert len(seqs) == num_shards
                for j in range(num_shards):
                    if i != len(loader) - 1:
                        assert mx.test_utils.almost_equal(
                            seqs[j][0].asnumpy(),
                            X[(i * num_shards + j) *
                              2:(i * num_shards + j + 1) * 2])
                        assert mx.test_utils.almost_equal(
                            seqs[j][1].asnumpy(),
                            Y[(i * num_shards + j) *
                              2:(i * num_shards + j + 1) * 2])
                    else:
                        assert mx.test_utils.almost_equal(
                            seqs[j][0].asnumpy(),
                            X[(i * num_shards + j) * 2 -
                              num_shards:(i * num_shards + j + 1) * 2 -
                              num_shards])
                        assert mx.test_utils.almost_equal(
                            seqs[j][1].asnumpy(),
                            Y[(i * num_shards + j) * 2 -
                              num_shards:(i * num_shards + j + 1) * 2 -
                              num_shards])
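
Mirroring the test above, here is a minimal sketch of how ShardedDataLoader yields one batch per shard when the sampler is built with num_shards (assuming gluonnlp and mxnet are installed; the array shapes are toy values).

import numpy as np
from mxnet import gluon
from gluonnlp.data import FixedBucketSampler, ShardedDataLoader

X = np.random.uniform(size=(16, 8))
Y = np.random.uniform(size=(16,))
dataset = gluon.data.ArrayDataset(X, Y)

num_shards = 4
sampler = FixedBucketSampler(lengths=[X.shape[1]] * X.shape[0],
                             batch_size=2,
                             num_buckets=1,
                             shuffle=False,
                             num_shards=num_shards)
loader = ShardedDataLoader(dataset, batch_sampler=sampler)
for shards in loader:
    assert len(shards) == num_shards
    for x, y in shards:        # one (x, y) batch per shard, e.g. one per device
        print(x.shape, y.shape)
    break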
Example #9
def transform_segment(transformer, segment, options):
    dataset = NLPCCDataset(segment, './data/')
    transformed_dataset = dataset.transform(transformer, lazy=False)

    word_vocab = transformer.get_word_vocab()

    batchify_fn = Tuple(
        Stack(),
        Pad(axis=0,
            pad_val=word_vocab[word_vocab.padding_token],
            ret_length=True), Stack())

    sampler = FixedBucketSampler(
        lengths=[len(item[1]) for item in transformed_dataset],
        batch_size=options.batch_size,
        shuffle=True,
        num_buckets=options.num_buckets)
    return DataLoader(transformed_dataset,
                      batchify_fn=batchify_fn,
                      batch_sampler=sampler)
Example #10
File: etl.py  Project: tswsxk/CangJie
def transform(raw_data, params):
    # Define the data transformation interface
    # raw_data --> batch_data

    batch_size = params.batch_size
    padding = params.padding
    num_buckets = params.num_buckets
    fixed_length = params.fixed_length

    features, labels = raw_data
    word_feature, word_radical_feature, char_feature, char_radical_feature = features
    batch_idxes = FixedBucketSampler([len(word_f) for word_f in word_feature],
                                     batch_size,
                                     num_buckets=num_buckets)
    batch = []
    for batch_idx in batch_idxes:
        batch_features = [[] for _ in range(len(features))]
        batch_labels = []
        for idx in batch_idx:
            for i, feature in enumerate(batch_features):
                batch_features[i].append(features[i][idx])
            batch_labels.append(labels[idx])
        batch_data = []
        word_mask = []
        char_mask = []
        for i, feature in enumerate(batch_features):
            max_len = max([len(fea) for fea in feature
                           ]) if not fixed_length else fixed_length
            padder = PadSequence(max_len, pad_val=padding)
            feature, mask = zip(*[(padder(fea), len(fea)) for fea in feature])
            if i == 0:
                word_mask = mask
            elif i == 2:
                char_mask = mask
            batch_data.append(mx.nd.array(feature))
        batch_data.append(mx.nd.array(word_mask))
        batch_data.append(mx.nd.array(char_mask))
        batch_data.append(mx.nd.array(batch_labels, dtype=np.int64))
        batch.append(batch_data)
    return batch[::-1]
Example #11
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {
        'learning_rate': args.lr,
        'beta2': 0.98,
        'epsilon': 1e-9
    })

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'),
                                  btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'),
                                 btf.Stack(dtype='float32'), btf.Stack())
    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
    if args.bucket_scheme == 'constant':
        bucket_scheme = ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             use_average_length=True,
                                             num_shards=len(ctx),
                                             bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n{}'.format(
        train_batch_sampler.stats()))
    train_data_loader = ShardedDataLoader(data_train,
                                          batch_sampler=train_batch_sampler,
                                          batchify_fn=train_batchify_fn,
                                          num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False,
                                           use_average_length=True,
                                           bucket_scheme=bucket_scheme)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False,
                                            use_average_length=True,
                                            bucket_scheme=bucket_scheme)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)

    if args.bleu == 'tweaked':
        bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError

    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    model.collect_params().setattr('grad_req', 'add')
    average_start = (len(train_data_loader) //
                     grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, seqs \
                in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            src_wc, tgt_wc, bs = np.sum(
                [(shard[2].sum(), shard[3].sum(), shard[0].shape[0])
                 for shard in seqs],
                axis=0)
            src_wc = src_wc.asscalar()
            tgt_wc = tgt_wc.asscalar()
            loss_denom += tgt_wc - bs
            seqs = [[seq.as_in_context(context) for seq in shard]
                    for context, shard in zip(ctx, seqs)]
            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length in seqs:
                    out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                                   tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label,
                                       tgt_valid_length - 1).sum()
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size /
                              100.0)
            for L in Ls:
                L.backward()
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                if average_param_dict is None:
                    average_param_dict = {
                        k: v.data(ctx[0]).copy()
                        for k, v in model.collect_params().items()
                    }
                trainer.step(float(loss_denom) / args.batch_size / 100.0)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (
                            param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch_id, batch_id + 1,
                                 len(train_data_loader),
                                 log_avg_loss / args.log_interval,
                                 np.exp(log_avg_loss / args.log_interval),
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu(
            [val_tgt_sentences],
            valid_translation_out,
            tokenized=tokenized,
            tokenizer=args.bleu,
            split_compound_word=split_compound_word,
            bpe=bpe)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu(
            [test_tgt_sentences],
            test_translation_out,
            tokenized=tokenized,
            tokenizer=args.bleu,
            split_compound_word=split_compound_word,
            bpe=bpe)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
            format(epoch_id, test_loss, np.exp(test_loss),
                   test_bleu_score * 100))
        write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(
            test_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        save_path = os.path.join(args.save_dir,
                                 'epoch{:d}.params'.format(epoch_id))
        model.save_parameters(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        for j in range(args.num_averages):
            params = mx.nd.load(
                os.path.join(args.save_dir,
                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) -
                                             v.data(c))
        save_path = os.path.join(
            args.save_dir,
            'average_checkpoint_{}.params'.format(args.num_averages))
        model.save_parameters(save_path)
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
        save_path = os.path.join(args.save_dir, 'average.params')
        model.save_parameters(save_path)
    else:
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'),
                              ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu(
        [val_tgt_sentences],
        valid_translation_out,
        tokenized=tokenized,
        tokenizer=args.bleu,
        bpe=bpe,
        split_compound_word=split_compound_word)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu(
        [test_tgt_sentences],
        test_translation_out,
        tokenized=tokenized,
        tokenizer=args.bleu,
        bpe=bpe,
        split_compound_word=split_compound_word)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
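
The learning-rate expression computed inside the training loop is the standard Transformer warmup schedule. Factored out as a standalone sketch, with num_units and warmup_steps standing in for args.num_units and args.warmup_steps:

import math

def transformer_lr(base_lr, num_units, step_num, warmup_steps):
    # base_lr / sqrt(d_model) * min(1/sqrt(step), step * warmup_steps**-1.5)
    return base_lr / math.sqrt(num_units) * min(
        1.0 / math.sqrt(step_num), step_num * warmup_steps ** -1.5)

print(transformer_lr(2.0, 512, 1000, 4000))   # linear warmup phase
print(transformer_lr(2.0, 512, 16000, 4000))  # inverse-sqrt decay phase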
Example #12
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {
        'learning_rate': args.lr,
        'beta2': 0.98,
        'epsilon': 1e-9
    })

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(),
                                  btf.Stack())
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(),
                                 btf.Stack(), btf.Stack())
    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             use_average_length=True)
    logging.info('Train Batch Sampler:\n{}'.format(
        train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False,
                                           use_average_length=True)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False,
                                            use_average_length=True)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    model.collect_params().setattr('grad_req', 'add')
    average_start = (len(train_data_loader) //
                     grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) \
                in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = tgt_valid_length.sum().asscalar()
            loss_denom += tgt_wc - tgt_valid_length.shape[0]
            if src_seq.shape[0] > len(ctx):
                src_seq_list, tgt_seq_list, src_valid_length_list, tgt_valid_length_list \
                    = [gluon.utils.split_and_load(seq, ctx, batch_axis=0, even_split=False)
                       for seq in [src_seq, tgt_seq, src_valid_length, tgt_valid_length]]
            else:
                src_seq_list = [src_seq.as_in_context(ctx[0])]
                tgt_seq_list = [tgt_seq.as_in_context(ctx[0])]
                src_valid_length_list = [
                    src_valid_length.as_in_context(ctx[0])
                ]
                tgt_valid_length_list = [
                    tgt_valid_length.as_in_context(ctx[0])
                ]

            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length \
                        in zip(src_seq_list, tgt_seq_list,
                               src_valid_length_list, tgt_valid_length_list):
                    out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                                   tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label,
                                       tgt_valid_length - 1).sum()
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size)
            for L in Ls:
                L.backward()
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                if average_param_dict is None:
                    average_param_dict = {
                        k: v.data(ctx[0]).copy()
                        for k, v in model.collect_params().items()
                    }
                trainer.step(float(loss_denom) / args.batch_size)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (
                            param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if batch_id % grad_interval == grad_interval - 1 or\
                    batch_id == len(train_data_loader) - 1:
                log_avg_loss += step_loss / loss_denom * args.batch_size
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch_id, batch_id + 1,
                                 len(train_data_loader),
                                 log_avg_loss / args.log_interval,
                                 np.exp(log_avg_loss / args.log_interval),
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out,
                                                    bpe=True,
                                                    split_compound_word=True)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                   test_translation_out,
                                                   bpe=True,
                                                   split_compound_word=True)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
            format(epoch_id, test_loss, np.exp(test_loss),
                   test_bleu_score * 100))
        write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(
            test_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        save_path = os.path.join(args.save_dir,
                                 'epoch{:d}.params'.format(epoch_id))
        model.save_params(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        for j in range(args.num_averages):
            params = mx.nd.load(
                os.path.join(args.save_dir,
                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) -
                                             v.data(c))
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
    else:
        model.load_params(os.path.join(args.save_dir, 'valid_best.params'),
                          ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out,
                                                bpe=True,
                                                split_compound_word=True)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out,
                                               bpe=True,
                                               split_compound_word=True)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
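
Both training loops above accumulate gradients over num_accumulated batches before calling trainer.step. A stripped-down sketch of that pattern in Gluon; the toy network, data, and grad_interval are illustrative only.

import mxnet as mx
from mxnet import gluon, autograd

net = gluon.nn.Dense(1)
net.initialize()
net.collect_params().setattr('grad_req', 'add')          # accumulate instead of overwrite
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
loss_fn = gluon.loss.L2Loss()

grad_interval = 4                                        # like args.num_accumulated
batch_size = 2
for batch_id in range(8):
    x = mx.nd.random.uniform(shape=(batch_size, 3))
    y = mx.nd.random.uniform(shape=(batch_size, 1))
    with autograd.record():
        loss_fn(net(x), y).backward()
    if batch_id % grad_interval == grad_interval - 1:
        trainer.step(batch_size * grad_interval)         # normalize by accumulated samples
        net.collect_params().zero_grad()                 # reset before the next accumulation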
Example #13
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr})

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'),
                                  btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'),
                                 btf.Stack(dtype='float32'), btf.Stack())
    if args.bucket_scheme == 'constant':
        bucket_scheme = ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n{}'.format(
        train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=data_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=data_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
                in enumerate(train_data_loader):
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length,
                               tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:],
                                     tgt_valid_length - 1).mean()
                loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length -
                                                        1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            step_loss = loss.asscalar()
            log_avg_loss += step_loss
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info(
                    '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                    'throughput={:.2f}K wps, wc={:.2f}K'.format(
                        epoch_id, batch_id + 1, len(train_data_loader),
                        log_avg_loss / args.log_interval,
                        np.exp(log_avg_loss / args.log_interval),
                        log_avg_gnorm / args.log_interval, wps / 1000,
                        log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_avg_gnorm = 0
                log_wc = 0
        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss),
                    valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader)
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                   test_translation_out)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
            format(epoch_id, test_loss, np.exp(test_loss),
                   test_bleu_score * 100))
        write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(
            test_translation_out,
            os.path.join(args.save_dir,
                         'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        if epoch_id + 1 >= (args.epochs * 2) // 3:
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)
    model.load_params(os.path.join(args.save_dir, 'valid_best.params'))
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'.
        format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'.
        format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
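
This variant clips the global gradient norm before each update. A small sketch of gluon.utils.clip_global_norm in isolation, with toy arrays standing in for the real gradients:

import mxnet as mx
from mxnet import gluon

grads = [mx.nd.array([3.0, 4.0]), mx.nd.array([12.0])]   # joint L2 norm = 13
total_norm = gluon.utils.clip_global_norm(grads, max_norm=1.0)
print(total_norm)                 # norm before clipping
print(grads[0], grads[1])         # rescaled in place so the joint norm is max_norm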
Example #14
def train():
    """Training function."""
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {'learning_rate': args.lr})

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(), btf.Stack())
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(), btf.Stack(), btf.Stack())
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True)
    logging.info('Train Batch Sampler:\n{}'.format(train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=data_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=data_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
                in enumerate(train_data_loader):
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
                loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            step_loss = loss.asscalar()
            log_avg_loss += step_loss
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, len(train_data_loader),
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     log_avg_gnorm / args.log_interval,
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_avg_gnorm = 0
                log_wc = 0
        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader)
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        write_sentences(valid_translation_out,
                        os.path.join(args.save_dir, 'epoch{:d}_valid_out.txt').format(epoch_id))
        write_sentences(test_translation_out,
                        os.path.join(args.save_dir, 'epoch{:d}_test_out.txt').format(epoch_id))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        else:
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)
    model.load_params(os.path.join(args.save_dir, 'valid_best.params'))
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))