Exemplo n.º 1
0
def test_transformer_param_number(cfg_name, gt_num_params, gt_num_fixed_params):
    """Check the parameter counts of a named Transformer configuration.

    The model is built from ``cfg_name`` with both vocabulary sizes set to
    32768. The counts reported by ``count_parameters`` must match the expected
    values both on the raw parameter dict and on its deduplicated version.

    Parameters
    ----------
    cfg_name
        Key passed to ``TransformerModel.get_cfg``.
    gt_num_params
        Expected first value returned by ``count_parameters``.
    gt_num_fixed_params
        Expected second value returned by ``count_parameters``.
    """
    vocab_size = 32768

    def _assert_counts(param_dict):
        # Both returned counts must agree with the ground truth.
        n_params, n_fixed = count_parameters(param_dict)
        assert n_params == gt_num_params
        assert n_fixed == gt_num_fixed_params

    config = TransformerModel.get_cfg(cfg_name)
    config.defrost()
    config.MODEL.src_vocab_size = vocab_size
    config.MODEL.tgt_vocab_size = vocab_size
    config.freeze()

    net = TransformerModel.from_cfg(config)
    net.initialize()

    # First on the parameters as collected, then after deduplication.
    _assert_counts(net.collect_params())
    _assert_counts(deduplicate_param_dict(net.collect_params()))
Exemplo n.º 2
0
def test_transformer_cfg(cfg_key):
    """Smoke-test building a Transformer from a named config in two layouts.

    A model is created from ``cfg_key`` (with both vocabulary sizes set to
    32), initialized and hybridized. The same config is then switched to the
    'TN' layout and used to build a second model that shares the first
    model's parameters.

    Parameters
    ----------
    cfg_key
        Key passed to ``TransformerModel.get_cfg``.
    """
    small_vocab = 32
    config = TransformerModel.get_cfg(cfg_key)

    # Model built with the layout that comes with the named config.
    config.defrost()
    config.MODEL.src_vocab_size = small_vocab
    config.MODEL.tgt_vocab_size = small_vocab
    config.freeze()
    base_model = TransformerModel.from_cfg(config)
    base_model.initialize()
    base_model.hybridize()

    # Second model: same config, 'TN' layout, parameters shared with the first.
    config.defrost()
    config.MODEL.layout = 'TN'
    config.freeze()
    tn_model = TransformerModel.from_cfg(config)
    tn_model.share_parameters(base_model.collect_params())
    tn_model.hybridize()

    mx.npx.waitall()
Exemplo n.º 3
0
def test_transformer_fp16_amp(enc_pre_norm, dec_pre_norm,
                              enc_units, dec_units,
                              enc_num_layers, dec_num_layers,
                              enc_recurrent, dec_recurrent, tie_weights,
                              layout, ctx):
    """Run ``verify_backbone_fp16`` on a small Transformer configuration.

    The test is skipped unless ``ctx`` is a GPU device. Random int32 token ids
    and valid lengths are generated under ``ctx`` in the requested layout and
    handed to ``verify_backbone_fp16`` together with the config.

    Parameters
    ----------
    enc_pre_norm, enc_units, enc_num_layers, enc_recurrent
        Values written to ``cfg.MODEL.ENCODER``.
    dec_pre_norm, dec_units, dec_num_layers, dec_recurrent
        Values written to ``cfg.MODEL.DECODER``.
    tie_weights
        Value written to ``cfg.MODEL.tie_weights``.
    layout
        Either 'NT' or 'TN'; any other value raises ``NotImplementedError``.
    ctx
        Device context used for data creation and for the verification call.
    """
    if ctx.device_type != 'gpu':
        pytest.skip('Only test amp when running on GPU.')

    # Assemble the configuration under test.
    cfg = TransformerModel.get_cfg()
    cfg.defrost()
    model_cfg = cfg.MODEL
    model_cfg.src_vocab_size = 32
    model_cfg.tgt_vocab_size = 32
    model_cfg.max_src_length = 20
    model_cfg.max_tgt_length = 15
    model_cfg.tie_weights = tie_weights
    model_cfg.layout = layout

    encoder_cfg = model_cfg.ENCODER
    encoder_cfg.pre_norm = enc_pre_norm
    encoder_cfg.units = enc_units
    encoder_cfg.num_layers = enc_num_layers
    encoder_cfg.recurrent = enc_recurrent

    decoder_cfg = model_cfg.DECODER
    decoder_cfg.pre_norm = dec_pre_norm
    decoder_cfg.units = dec_units
    decoder_cfg.num_layers = dec_num_layers
    decoder_cfg.recurrent = dec_recurrent
    cfg.freeze()

    batch_size = 4
    seq_length = 16
    with ctx:
        # Only the token-id shape depends on the layout.
        if layout == 'NT':
            token_shape = (batch_size, seq_length)
        elif layout == 'TN':
            token_shape = (seq_length, batch_size)
        else:
            raise NotImplementedError

        def _random_tokens(vocab_size):
            return mx.np.random.randint(0, vocab_size, token_shape,
                                        dtype=np.int32)

        def _random_valid_lengths():
            return mx.np.random.randint(seq_length // 2, seq_length,
                                        (batch_size,), dtype=np.int32)

        # Draw order is kept as source tokens, source lengths, target tokens,
        # target lengths.
        src_data = _random_tokens(cfg.MODEL.src_vocab_size)
        src_valid_length = _random_valid_lengths()
        tgt_data = _random_tokens(cfg.MODEL.tgt_vocab_size)
        tgt_valid_length = _random_valid_lengths()

        verify_backbone_fp16(TransformerModel, cfg, ctx,
                             inputs=[src_data, src_valid_length, tgt_data, tgt_valid_length])
Exemplo n.º 4
0
def test_transformer_nmt_model(train_hybridize, inference_hybridize,
                               enc_pre_norm, dec_pre_norm,
                               enc_units, dec_units,
                               enc_num_layers, dec_num_layers,
                               enc_recurrent, dec_recurrent, tie_weights,
                               layout):
    """Verify a small Transformer in both training and inference form.

    The test is skipped whenever ``inference_hybridize`` is truthy. Otherwise
    a ``TransformerModel`` is built directly from keyword arguments, wrapped
    in ``TransformerNMTInference``, and checked with ``verify_nmt_model`` and
    ``verify_nmt_inference``.

    Parameters
    ----------
    train_hybridize
        Whether to hybridize the training model before verification.
    inference_hybridize
        Whether to hybridize the inference model (currently causes a skip).
    enc_pre_norm, enc_units, enc_num_layers, enc_recurrent
        Encoder settings forwarded to ``TransformerModel``.
    dec_pre_norm, dec_units, dec_num_layers, dec_recurrent
        Decoder settings forwarded to ``TransformerModel``.
    tie_weights, layout
        Forwarded to ``TransformerModel`` unchanged.
    """
    if inference_hybridize:
        pytest.skip('inference model hybridization is not working')

    model_kwargs = dict(
        src_vocab_size=32,
        tgt_vocab_size=32,
        max_src_length=20,
        max_tgt_length=15,
        enc_units=enc_units,
        enc_hidden_size=64,
        enc_num_heads=4,
        enc_num_layers=enc_num_layers,
        enc_pre_norm=enc_pre_norm,
        enc_recurrent=enc_recurrent,
        dec_units=dec_units,
        dec_hidden_size=64,
        dec_num_heads=4,
        dec_num_layers=dec_num_layers,
        dec_pre_norm=dec_pre_norm,
        dec_recurrent=dec_recurrent,
        # The embedding is only shared when encoder and decoder widths match.
        shared_embed=not (enc_units != dec_units),
        tie_weights=tie_weights,
        dropout=0.0,
        layout=layout,
    )
    train_net = TransformerModel(**model_kwargs)
    infer_net = TransformerNMTInference(model=train_net)
    train_net.initialize()

    if train_hybridize:
        train_net.hybridize()
    verify_nmt_model(train_net)

    if inference_hybridize:
        infer_net.hybridize()
    verify_nmt_inference(train_model=train_net, inference_model=infer_net)
Exemplo n.º 5
0
def train(args):
    """Train a Transformer translation model with periodic validation.

    The function sets up communication and logging, loads the tokenized
    training/dev corpora, builds the model (forced to the 'TN' layout), the
    label-smoothing loss, the beam-search sampler and the trainer, and then
    runs the update loop with gradient accumulation. At every save point the
    parameters are written to ``args.save_dir`` (on ``local_rank == 0``), the
    dev set is decoded with beam search and SacreBLEU scores are logged.
    When ``args.num_averages > 0`` the averaged parameters are copied back and
    saved as ``average.params`` after the loop.

    Parameters
    ----------
    args
        Parsed command-line namespace. Attributes read here include the
        communication settings (``comm_backend``, ``gpus``), tokenizer and
        corpus paths, model config (``cfg``), optimizer settings
        (``optimizer``, ``optimizer_params``, ``lr``, ``wd``, ``warmup_steps``,
        ``warmup_init_lr``, ``max_grad_norm``), sampler settings, schedule
        settings (``epochs``, ``max_update``, ``save_interval_update``,
        ``num_averages``, ``num_accumulated``, ``log_interval``) and
        ``save_dir``.

    Raises
    ------
    NotImplementedError
        If ``args.sampler`` or ``args.bucket_scheme`` is not one of the
        supported values, or if ``FixedBucketSampler`` is requested together
        with the horovod backend.
    """
    # Local import: only needed to iterate without an upper bound.
    import itertools

    _, num_parts, rank, local_rank, _, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    if args.comm_backend == 'horovod':
        # One log file per worker; only rank 0 echoes to the console.
        logging_config(
            args.save_dir,
            name=f'train_transformer_rank{rank}_local{local_rank}_{num_parts}',
            console=(rank == 0))
        logging.info(args)
    else:
        logging_config(args.save_dir, name='train_transformer', console=True)
        logging.info(args)
    use_amp = args.fp16
    if use_amp:
        from mxnet import amp
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    base_tgt_tokenizer = MosesTokenizer(args.tgt_lang)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus,
        args.train_tgt_corpus,
        src_tokenizer,
        tgt_tokenizer,
        args.overwrite_cache,
        local_rank,
        max_src_length=args.max_src_length,
        max_tgt_length=args.max_tgt_length,
        pretokenized=not args.tokenize)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus,
        args.dev_tgt_corpus,
        src_tokenizer,
        tgt_tokenizer,
        args.overwrite_cache,
        local_rank,
        pretokenized=not args.tokenize)
    # Reference sentences used for the two SacreBLEU scores computed during
    # validation.
    tgt_detok_sentences = []
    tgt_raw_sentences = []
    with open(args.dev_tgt_corpus, 'r') as in_f:
        for line in in_f:
            tgt_detok_sentences.append(
                base_tgt_tokenizer.decode(
                    tgt_tokenizer.decode(line.split()).split()))
    with open(args.dev_tgt_raw_corpus, 'r') as in_f:
        for line in in_f:
            tgt_raw_sentences.append(line.strip())
    # Each sample is (src_tokens, tgt_tokens, src_length, tgt_length, index).
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    val_samples = [
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ]
    if args.comm_backend == 'horovod':
        # Each worker validates a contiguous slice of the dev set.
        # NOTE(review): with integer division the trailing
        # len(val_samples) % num_parts samples are assigned to no worker;
        # confirm this is intended.
        slice_begin = rank * (len(val_samples) // num_parts)
        slice_end = min((rank + 1) * (len(val_samples) // num_parts),
                        len(val_samples))
        data_val = gluon.data.SimpleDataset(val_samples[slice_begin:slice_end])
    else:
        data_val = gluon.data.SimpleDataset(val_samples)
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    cfg.MODEL.layout = 'TN'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    # Gradients are accumulated over args.num_accumulated micro-batches, so
    # every trainable parameter uses the 'add' gradient request.
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    param_dict = deduplicate_param_dict(model.collect_params())

    inference_model = TransformerInference(model=model)
    inference_model.hybridize()
    if local_rank == 0:
        logging.info(model)
    with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
        cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()

    # Construct the beam search sampler
    scorer = BeamSearchScorer(alpha=args.lp_alpha,
                              K=args.lp_k,
                              from_logits=False)
    beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size,
                                            decoder=inference_model,
                                            vocab_size=len(tgt_vocab),
                                            eos_id=tgt_vocab.eos_id,
                                            scorer=scorer,
                                            stochastic=False,
                                            max_length_a=args.max_length_a,
                                            max_length_b=args.max_length_b)

    logging.info(beam_search_sampler)
    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # Construct the trainer
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    optimizer_params = {
        'learning_rate': args.lr,
        'beta1': 0.9,
        'beta2': 0.997,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler,
        'wd': args.wd
    }
    # User-supplied JSON overrides the defaults above.
    user_provided_optimizer_params = json.loads(args.optimizer_params)
    optimizer_params.update(user_provided_optimizer_params)

    if args.fp16:
        optimizer_params.update({'multi_precision': True})
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = gluon.Trainer(param_dict,
                                args.optimizer,
                                optimizer_params,
                                update_on_kvstore=False)
    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            shuffle=True,
            seed=args.seed)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')

        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            batch_size=args.batch_size,
            num_buckets=args.num_buckets,
            ratio=args.bucket_ratio,
            shuffle=True,
            use_average_length=True,
            bucket_scheme=bucket_scheme,
            seed=args.seed)
    else:
        raise NotImplementedError

    # One update consumes one batch per device per worker per accumulation
    # step.
    num_updates_per_epoch = int(
        math.ceil(
            len(train_batch_sampler) /
            (num_parts * len(ctx_l) * args.num_accumulated)))
    # Convert the batch sampler to multiple shards
    if num_parts > 1:
        train_batch_sampler = ShardedIterator(train_batch_sampler,
                                              num_parts=num_parts,
                                              part_index=rank,
                                              even_size=True,
                                              seed=args.seed + 1000 * rank)

    logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    model_averager = AverageSGDTracker(param_dict)
    log_start_time = time.time()
    num_params, num_fixed_params = None, None

    # TODO(sxjscience) Add a log metric class
    log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    # Maintain the denominator of the loss.
    log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_avg_grad_norm = 0
    log_iter_num = 0

    if local_rank == 0:
        writer = SummaryWriter(
            logdir=os.path.join(args.save_dir, 'tensorboard'))
    if use_amp:
        amp.init_trainer(trainer)
    train_multi_data_loader = grouper(repeat(train_data_loader), len(ctx_l))
    # when args.epochs < 0, the model will keep training
    # avg_start_iter is the first iteration that contributes to the parameter
    # average; -1 means averaging is disabled.
    if args.epochs < 0:
        if args.max_update > 0:
            total_train_iters = args.max_update
            if args.num_averages > 0:
                assert args.num_averages <= total_train_iters // args.save_interval_update
                avg_start_iter = (
                    total_train_iters // args.save_interval_update -
                    args.num_averages) * args.save_interval_update
            else:
                avg_start_iter = -1
        else:
            total_train_iters = np.inf
            avg_start_iter = -1
    else:
        total_train_iters = args.epochs * num_updates_per_epoch
        if args.num_averages > 0:
            assert args.num_averages <= args.epochs
            avg_start_iter = (args.epochs -
                              args.num_averages) * num_updates_per_epoch
        else:
            avg_start_iter = -1

    # Here, we are manually setting up the scale to 1.0 because
    # in horovod, the scale can be the number of workers:
    # See the code here: https://github.com/horovod/horovod/blob/125115583b7029196e2ec530decd4209459d5479/horovod/mxnet/__init__.py#L141
    # Since we will need to use the dynamic scaling in amp, we will manually call amp.unscale().
    # A scale that is larger than 1.0 can be problematic in this case.
    trainer._scale = 1.0
    # Constant divisor applied to both the loss and its denominator.
    if args.max_num_tokens > 0:
        const_scale = args.max_num_tokens
    else:
        const_scale = 100

    train_start_time = time.time()

    # range() rejects a float bound, so the unbounded case (np.inf) iterates
    # over an endless counter instead.
    if total_train_iters == np.inf:
        train_iter_range = itertools.count()
    else:
        train_iter_range = range(total_train_iters)

    for train_iter in train_iter_range:
        model.zero_grad()
        loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
        for _ in range(args.num_accumulated):
            loss_l = []
            sample_data_l = next(train_multi_data_loader)
            for j, (sample_data, ctx) in enumerate(zip(sample_data_l, ctx_l)):
                src_token_ids, tgt_token_ids, src_valid_length,\
                tgt_valid_length, sample_ids = sample_data
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                src_wc = src_valid_length.sum()
                tgt_wc = tgt_valid_length.sum()
                log_wc_l[j] += src_wc + tgt_wc
                log_tgt_wc_l[j] += tgt_wc
                # The target is shifted by one position, so each sequence
                # contributes (valid_length - 1) predicted tokens.
                token_count = (tgt_valid_length - 1).sum()
                loss_denom_l[j] += token_count / const_scale
                log_avg_loss_denom_l[j] += token_count / const_scale
                with mx.autograd.record():
                    if model.layout == 'NT':
                        tgt_pred = model(src_token_ids, src_valid_length,
                                         tgt_token_ids[:, :-1],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids[:, 1:]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=1)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                    elif model.layout == 'TN':
                        # Batches are produced batch-major; transpose them for
                        # the time-major model.
                        tgt_pred = model(src_token_ids.T, src_valid_length,
                                         tgt_token_ids.T[:-1, :],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids.T[1:, :]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=0)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                log_avg_loss_l[j] += loss
            if use_amp:
                with mx.autograd.record():
                    with amp.scale_loss(loss_l, trainer) as amp_loss_l:
                        for loss in amp_loss_l:
                            loss.backward()
            else:
                with mx.autograd.record():
                    for loss in loss_l:
                        loss.backward()

        # Print the total number of parameters
        if local_rank == 0 and num_params is None:
            num_params, num_fixed_params = count_parameters(param_dict)
            logging.info(
                'Total Number of Parameters (not-fixed/fixed): {}/{}'.format(
                    num_params, num_fixed_params))
        # All-Reduce the gradient
        trainer.allreduce_grads()
        if args.comm_backend == 'horovod':
            # All-Reduce the loss denominator
            assert len(loss_denom_l) == 1
            loss_denom = hvd.allreduce(loss_denom_l[0],
                                       average=False).asnumpy()
        else:
            loss_denom = sum([ele.asnumpy() for ele in loss_denom_l])
        if use_amp:
            # We need to first unscale the gradient and then perform allreduce.
            grad_scale = trainer.amp_loss_scale * loss_denom
        else:
            grad_scale = loss_denom
        # Gradients are still multiplied by grad_scale at this point, so the
        # clipping threshold is scaled up and the reported norm scaled down.
        if args.max_grad_norm is not None:
            total_norm, ratio, is_finite\
                = clip_grad_global_norm(params, args.max_grad_norm * grad_scale)
            total_norm = total_norm / grad_scale
        else:
            total_norm = grad_global_norm(params)
            total_norm = total_norm / grad_scale
        log_avg_grad_norm += total_norm
        log_iter_num += 1

        trainer.update(loss_denom, ignore_stale_grad=True)

        # -1 disables averaging; 0 is a valid start (average over the whole
        # run), so the comparison must include it.
        if avg_start_iter >= 0 and train_iter >= avg_start_iter:
            model_averager.step()

        if ((train_iter + 1) % args.log_interval == 0
                or train_iter + 1 == total_train_iters):
            if args.comm_backend == 'horovod':
                # Use allreduce to get the total number of tokens and loss
                log_wc = hvd.allreduce(log_wc_l[0], average=False).asnumpy()
                log_tgt_wc = hvd.allreduce(log_tgt_wc_l[0],
                                           average=False).asnumpy()
                log_avg_loss = hvd.allreduce(log_avg_loss_l[0] /
                                             log_avg_loss_denom_l[0],
                                             average=True)
                log_avg_loss = log_avg_loss.asnumpy()
            else:
                log_wc = sum([ele.asnumpy() for ele in log_wc_l])
                log_tgt_wc = sum([ele.asnumpy() for ele in log_tgt_wc_l])
                log_avg_loss =\
                    sum([log_avg_loss_l[i].asnumpy() / log_avg_loss_denom_l[i].asnumpy()
                         for i in range(len(log_avg_loss_l))]) / len(log_avg_loss_l)
            log_avg_grad_norm = log_avg_grad_norm / log_iter_num
            log_end_time = time.time()
            wps = log_wc / (log_end_time - log_start_time)
            epoch_id = train_iter // num_updates_per_epoch
            logging.info(
                '[Epoch {} Iter {}/{}, Overall {}/{}] loss={:.4f}, ppl={:.4f}, '
                'throughput={:.2f}K wps, total wc={:.2f}K, wpb={:.2f}K,'
                ' LR={}, gnorm={:.4f}, ETA={:.2f}h'.format(
                    epoch_id, train_iter % num_updates_per_epoch + 1,
                    num_updates_per_epoch,
                    train_iter + 1, total_train_iters, log_avg_loss,
                    np.exp(log_avg_loss), wps / 1000, log_wc / 1000,
                    log_tgt_wc / 1000 / log_iter_num, trainer.learning_rate,
                    log_avg_grad_norm,
                    (log_end_time - train_start_time) / (train_iter + 1) *
                    (total_train_iters - train_iter - 1) / 3600))
            if local_rank == 0:
                writer.add_scalar('throughput_wps', wps, train_iter)
                writer.add_scalar('train_loss', log_avg_loss, train_iter)
                writer.add_scalar('lr', trainer.learning_rate, train_iter)
                writer.add_scalar('grad_norm', log_avg_grad_norm, train_iter)
            # Reinitialize the log variables
            log_start_time = time.time()
            log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_grad_norm = 0
            log_iter_num = 0
            log_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]
            log_tgt_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]

        # Save + validate at every update-based save point, at every epoch
        # boundary and at the very last iteration.
        if (args.max_update > 0 and (train_iter + 1) % args.save_interval_update == 0) \
            or ((train_iter + 1) % num_updates_per_epoch == 0) \
            or train_iter + 1 == total_train_iters:
            epoch_id = (train_iter + 1) // num_updates_per_epoch
            if local_rank == 0:
                if args.max_update <= 0:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'epoch{}.params'.format(epoch_id)),
                                          deduplicate=True)
                else:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'iter{}.params'.format(train_iter + 1)),
                                          deduplicate=True)

            avg_val_loss, ntokens, pred_sentences, pred_lengths, sentence_ids\
                = validation(model, val_data_loader, inference_model, beam_search_sampler,
                             tgt_tokenizer, ctx_l)
            if args.comm_backend == 'horovod':
                # Gather every worker's predictions: sentences are flattened
                # into one id array and later split again by their lengths.
                flatten_pred_sentences = np.concatenate(pred_sentences, axis=0)
                all_val_loss = hvd.allgather(
                    mx.np.array([avg_val_loss * ntokens],
                                dtype=np.float32,
                                ctx=ctx_l[0]))
                all_ntokens = hvd.allgather(
                    mx.np.array([ntokens], dtype=np.int64, ctx=ctx_l[0]))
                flatten_pred_sentences = hvd.allgather(
                    mx.np.array(flatten_pred_sentences,
                                dtype=np.int32,
                                ctx=ctx_l[0]))
                pred_lengths = hvd.allgather(
                    mx.np.array(pred_lengths, dtype=np.int64, ctx=ctx_l[0]))
                sentence_ids = hvd.allgather(
                    mx.np.array(sentence_ids, dtype=np.int64, ctx=ctx_l[0]))
                avg_val_loss = all_val_loss.asnumpy().sum(
                ) / all_ntokens.asnumpy().sum()
                flatten_pred_sentences = flatten_pred_sentences.asnumpy()
                pred_lengths = pred_lengths.asnumpy()
                sentence_ids = sentence_ids.asnumpy()
                pred_sentences = [None for _ in range(len(sentence_ids))]
                ptr = 0
                assert sentence_ids.min() == 0 and sentence_ids.max(
                ) == len(sentence_ids) - 1
                # Put each prediction back at the position of its sample id.
                for sentence_id, length in zip(sentence_ids, pred_lengths):
                    pred_sentences[sentence_id] = flatten_pred_sentences[ptr:(
                        ptr + length)]
                    ptr += length
            if local_rank == 0:
                # Perform detokenization
                pred_sentences_bpe_decode = []
                pred_sentences_raw = []
                for sentence in pred_sentences:
                    bpe_decode_sentence = tgt_tokenizer.decode(
                        sentence.tolist())
                    raw_sentence = base_tgt_tokenizer.decode(
                        bpe_decode_sentence.split())
                    pred_sentences_bpe_decode.append(bpe_decode_sentence)
                    pred_sentences_raw.append(raw_sentence)
                detok_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_bpe_decode,
                    ref_streams=[tgt_detok_sentences])
                raw_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_raw,
                    ref_streams=[tgt_raw_sentences])
                with open(
                        os.path.join(args.save_dir,
                                     f'epoch{epoch_id}_dev_prediction.txt'),
                        'w') as of:
                    for line in pred_sentences_raw:
                        of.write(line + '\n')
                logging.info(
                    '[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, '
                    'SacreBlEU={}, Detok SacreBLUE={}'.format(
                        epoch_id, train_iter, total_train_iters, avg_val_loss,
                        np.exp(avg_val_loss), raw_sacrebleu_out.score,
                        detok_sacrebleu_out.score))
                writer.add_scalar('valid_loss', avg_val_loss, train_iter)
                writer.add_scalar('valid_bleu', raw_sacrebleu_out.score,
                                  train_iter)

    if args.num_averages > 0:
        model_averager.copy_back(
            param_dict)  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
Exemplo n.º 6
0
def train(args):
    """Train a Transformer translation model and save checkpoints.

    The function builds the tokenizers, the cached train/dev datasets, the
    model, the label-smoothed loss, the trainer and the data loaders from
    ``args``, then runs the training loop with gradient accumulation,
    periodic logging, checkpointing and a validation pass after every epoch.
    When ``args.num_averages > 0`` the parameters tracked by
    ``AverageSGDTracker`` are copied back into the model at the end and saved
    as ``average.params``.

    Parameters
    ----------
    args
        Parsed command-line namespace. Attributes read here include the
        communication settings (``comm_backend``, ``gpus``), tokenizer and
        corpus paths, ``cfg``, ``fp16``, optimizer settings (``lr``,
        ``warmup_steps``, ``warmup_init_lr``), sampler settings, ``epochs``,
        ``max_update``, ``num_accumulated``, ``num_averages``,
        ``save_interval_update``, ``log_interval`` and ``save_dir``.

    Raises
    ------
    NotImplementedError
        If ``args.fp16`` is set, if ``args.sampler`` or
        ``args.bucket_scheme`` is not one of the handled values, or if
        ``FixedBucketSampler`` is requested together with horovod.
    """
    store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus, args.train_tgt_corpus, src_tokenizer,
        tgt_tokenizer, args.overwrite_cache)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer,
        args.overwrite_cache)
    # Each sample is (src_tokens, tgt_tokens, src_length, tgt_length, index).
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    data_val = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ])
    # Construct the model + loss function
    # ``args.cfg`` is either a path to a .yml file that is merged into the
    # default config, or a key passed directly to ``get_cfg``.
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    if args.fp16:
        raise NotImplementedError


    # Disabled until fp16 training is supported: cfg.MODEL.dtype = 'float16'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    if local_rank == 0:
        logging.info(model)
    # NOTE(review): the config is written by every process, not only by
    # local_rank 0 -- confirm this is intended for multi-process runs.
    with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
        cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()
    # Every summed loss is divided by this factor before backward(), and the
    # normalizer given to trainer.step() is divided by the same factor, so the
    # two cancel out in the final update.
    rescale_loss = 100.0

    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    # Construct the trainer
    # TODO(sxjscience) Support AMP
    # NOTE(review): ``args.num_units`` is only read when ``args.lr`` is None;
    # confirm the argument parser actually defines it.
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    # NOTE(review): 'learning_rate' receives ``args.lr`` (possibly None), not
    # the ``base_lr`` computed above; the scheduler carries ``base_lr``.
    # Confirm the optimizer tolerates a None learning rate here.
    trainer_settings = (model.collect_params(), 'adam', {
        'learning_rate': args.lr,
        'beta1': 0.9,
        'beta2': 0.98,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler
    })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(*trainer_settings)
    else:
        trainer = gluon.Trainer(*trainer_settings)
    # Load Data
    # The samplers receive (src_length, tgt_length) pairs taken from the
    # dataset tuples built above.
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            seed=args.seed,
            num_parts=num_parts,
            part_index=rank)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')

        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(lengths=[
            (ele[2], ele[3]) for ele in data_train
        ],
                                                 batch_size=args.batch_size,
                                                 num_buckets=args.num_buckets,
                                                 ratio=args.bucket_ratio,
                                                 shuffle=True,
                                                 use_average_length=True,
                                                 bucket_scheme=bucket_scheme,
                                                 seed=args.seed)
    else:
        raise NotImplementedError

    if local_rank == 0:
        logging.info(train_batch_sampler)

    # Pad the two token sequences and stack the two lengths and the index.
    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)

    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    # Gradients are accumulated across backward() calls ('add') and cleared
    # explicitly with model.zero_grad() after each trainer.step().
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    model.zero_grad()
    model_averager = AverageSGDTracker(model.collect_params())
    log_start_time = time.time()
    num_params, num_fixed_params = None, None
    # TODO(sxjscience) Add a log metric class
    accum_count = 0  # groups of batches accumulated since the last update
    loss_denom = 0  # predicted target tokens since the last update
    n_train_iters = 0  # parameter updates performed so far
    log_wc = 0  # source + target tokens since the last log line
    log_avg_loss = 0.0  # summed (un-rescaled) loss since the last log line
    log_loss_denom = 0  # predicted target tokens since the last log line
    epoch_id = 0
    while (args.epochs < 0 or epoch_id < args.epochs
           ):  # when args.epochs < 0, the model will keep training
        n_epoch_train_iters = 0
        processed_batch_num = 0
        # One group holds one batch per device in ctx_l.
        train_multi_data_loader = grouper(train_data_loader, len(ctx_l))
        is_last_batch = False
        # The next group is always fetched one step ahead so the loop knows
        # when it is processing the final group of the epoch.
        sample_data_l = next(train_multi_data_loader)
        while not is_last_batch:
            processed_batch_num += len(sample_data_l)
            loss_l = []
            for sample_data, ctx in zip(sample_data_l, ctx_l):
                # presumably grouper fills an incomplete last group with
                # None -- verify against its implementation
                if sample_data is None:
                    continue
                src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data
                src_wc, tgt_wc, bs = src_valid_length.sum(
                ), tgt_valid_length.sum(), src_token_ids.shape[0]
                # One token per sentence is not predicted because the labels
                # are the target shifted by one position.
                loss_denom += tgt_wc - bs
                log_loss_denom += tgt_wc - bs
                log_wc += src_wc + tgt_wc
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                with mx.autograd.record():
                    # The decoder input drops the last target token and the
                    # labels drop the first one.
                    tgt_pred = model(src_token_ids, src_valid_length,
                                     tgt_token_ids[:, :-1],
                                     tgt_valid_length - 1)
                    tgt_labels = tgt_token_ids[:, 1:]
                    loss = label_smooth_loss(tgt_pred, tgt_labels)
                    # Mask the loss beyond each sequence's valid length.
                    loss = mx.npx.sequence_mask(
                        loss,
                        sequence_length=tgt_valid_length - 1,
                        use_sequence_length=True,
                        axis=1)
                    loss_l.append(loss.sum() / rescale_loss)
            for l in loss_l:
                l.backward()
            accum_count += 1
            try:
                sample_data_l = next(train_multi_data_loader)
            except StopIteration:
                is_last_batch = True
            if local_rank == 0 and num_params is None:
                num_params, num_fixed_params = count_parameters(
                    model.collect_params())
                logging.info(
                    'Total Number of Parameters (not-fixed/fixed): {}/{}'.
                    format(num_params, num_fixed_params))
            # Undo the rescaling so the logged loss is on the original scale.
            sum_loss = sum([l.as_in_ctx(mx.cpu())
                            for l in loss_l]) * rescale_loss
            log_avg_loss += sum_loss
            mx.npx.waitall()
            if accum_count == args.num_accumulated or is_last_batch:
                # Update the parameters
                n_train_iters += 1
                n_epoch_train_iters += 1
                # The step normalizer is the number of predicted target
                # tokens, divided by the same factor as the loss.
                trainer.step(loss_denom.asnumpy() / rescale_loss)
                accum_count = 0
                loss_denom = 0
                model.zero_grad()
                # Track the parameter average only during the last
                # num_averages epochs, or during the last
                # num_averages * save_interval_update updates when
                # max_update is set.
                if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \
                   (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update):
                    model_averager.step()
                if local_rank == 0 and \
                   (n_epoch_train_iters % args.log_interval == 0 or is_last_batch):
                    log_end_time = time.time()
                    log_wc = log_wc.asnumpy()
                    wps = log_wc / (log_end_time - log_start_time)
                    log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
                    logging.info(
                        '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                        'throughput={:.2f}K wps, wc={:.2f}K, LR={}'.format(
                            epoch_id, processed_batch_num * num_parts,
                            len(train_data_loader), log_avg_loss,
                            np.exp(log_avg_loss), wps / 1000, log_wc / 1000,
                            trainer.learning_rate))
                    # Reset the logging accumulators for the next window.
                    log_start_time = time.time()
                    log_avg_loss = 0
                    log_loss_denom = 0
                    log_wc = 0
                if local_rank == 0 and \
                   (args.max_update > 0 and n_train_iters % args.save_interval_update == 0):
                    model.save_parameters(os.path.join(
                        args.save_dir, 'update{:d}.params'.format(
                            n_train_iters // args.save_interval_update)),
                                          deduplicate=True)
                # This only leaves the inner loop; the same condition is
                # checked again after validation to leave the epoch loop.
                if args.max_update > 0 and n_train_iters >= args.max_update:
                    break
        if local_rank == 0 and args.epochs > 0:
            model.save_parameters(os.path.join(
                args.save_dir, 'epoch{:d}.params'.format(epoch_id)),
                                  deduplicate=True)
        # Validation is run and logged by every process (no local_rank guard).
        avg_valid_loss = validation(model, val_data_loader, ctx_l)
        logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'.format(
            epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))

        if args.max_update > 0 and n_train_iters >= args.max_update:
            break
        epoch_id += 1

    if args.num_averages > 0:
        model_averager.copy_back(
            model.collect_params())  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
Exemplo n.º 7
0
def evaluate(args):
    """Evaluate a trained Transformer on a parallel corpus, or translate only.

    In evaluation mode (``args.inference`` is false) the function computes
    the average negative log-likelihood of the reference translations,
    decodes every source sentence with beam search, and writes
    ``gt_sentences.txt``, ``pred_sentences.txt`` and ``results.json``
    (SacreBLEU score and NLL) into ``args.save_dir``.

    In inference mode only ``pred_sentences.txt`` is written, one batch at a
    time, and no reference corpus is needed.

    Parameters
    ----------
    args
        Parsed command-line namespace. Attributes read here include ``gpus``,
        the normalizer / tokenizer settings for both languages, ``cfg``,
        ``fp16``, ``param_path``, the beam-search settings (``beam_size``,
        ``stochastic``, ``temperature``, ``lp_alpha``, ``lp_k``,
        ``max_length_a``, ``max_length_b``), ``src_corpus``, ``tgt_corpus``,
        ``inference`` and ``save_dir``.

    Raises
    ------
    ValueError
        If evaluation mode is requested without ``args.tgt_corpus``.
    NotImplementedError
        If the model layout is neither ``'NT'`` nor ``'TN'``.
    """
    # Evaluation needs references for both NLL and BLEU. Without them the run
    # would only fail after the whole corpus had been decoded, so fail fast.
    if not args.inference and args.tgt_corpus is None:
        raise ValueError('A target corpus (args.tgt_corpus) is required unless '
                         'running in inference mode.')

    if args.gpus is None or args.gpus == '':
        ctx_l = [mx.cpu()]
    else:
        ctx_l = [mx.gpu(int(x)) for x in args.gpus.split(',')]
    src_normalizer = get_normalizer(args.src_normalizer, args.src_lang)
    # Use the target-side normalizer setting; fall back to the source-side
    # one only if the namespace does not define ``tgt_normalizer``.
    tgt_normalizer = get_normalizer(
        getattr(args, 'tgt_normalizer', args.src_normalizer), args.tgt_lang)
    base_src_tokenizer = get_base_tokenizer(args.src_base_tokenizer, args.src_lang)
    base_tgt_tokenizer = get_base_tokenizer(args.tgt_base_tokenizer, args.tgt_lang)

    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    if args.fp16:
        cfg.MODEL.dtype = 'float16'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    # Only cast to half precision when it was requested; the parameters are
    # converted to the model's dtype on load via ``cast_dtype=True``.
    if args.fp16:
        model.cast('float16')
    model.hybridize()
    model.load_parameters(args.param_path, ctx=ctx_l, cast_dtype=True)
    inference_model = TransformerInference(model=model)
    inference_model.hybridize()
    # Construct the BeamSearchSampler
    if args.stochastic:
        scorer = BeamSearchScorer(alpha=0.0,
                                  K=0.0,
                                  temperature=args.temperature,
                                  from_logits=False)
    else:
        scorer = BeamSearchScorer(alpha=args.lp_alpha,
                                  K=args.lp_k,
                                  from_logits=False)
    beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size,
                                            decoder=inference_model,
                                            vocab_size=len(tgt_vocab),
                                            eos_id=tgt_vocab.eos_id,
                                            scorer=scorer,
                                            stochastic=args.stochastic,
                                            max_length_a=args.max_length_a,
                                            max_length_b=args.max_length_b)

    logging.info(beam_search_sampler)
    all_src_token_ids, all_src_lines = process_corpus(
        args.src_corpus,
        sentence_normalizer=src_normalizer,
        base_tokenizer=base_src_tokenizer,
        bpe_tokenizer=src_tokenizer,
        add_bos=False,
        add_eos=True
    )
    if args.tgt_corpus is not None:
        all_tgt_token_ids, all_tgt_lines = process_corpus(
            args.tgt_corpus,
            sentence_normalizer=tgt_normalizer,
            base_tokenizer=base_tgt_tokenizer,
            bpe_tokenizer=tgt_tokenizer,
            add_bos=True,
            add_eos=True
        )
    else:
        # when applying inference, populate the fake tgt tokens
        all_tgt_token_ids = all_tgt_lines = [[] for _ in range(len(all_src_token_ids))]
    test_dataloader = gluon.data.DataLoader(
        list(zip(all_src_token_ids,
                 [len(ele) for ele in all_src_token_ids],
                 all_tgt_token_ids,
                 [len(ele) for ele in all_tgt_token_ids])),
        batch_size=32,
        batchify_fn=Tuple(Pad(), Stack(), Pad(), Stack()),
        shuffle=False)

    # Only the first device is used for evaluation.
    ctx = ctx_l[0]

    def _translate_batch(src_token_ids, src_valid_length):
        """Beam-search decode one batch; return the detokenized top hypotheses."""
        init_input = mx.np.array(
            [tgt_vocab.bos_id for _ in range(src_token_ids.shape[0])], ctx=ctx)
        if model.layout == 'NT':
            states = inference_model.init_states(src_token_ids, src_valid_length)
        elif model.layout == 'TN':
            states = inference_model.init_states(src_token_ids.T, src_valid_length)
        else:
            raise NotImplementedError
        samples, _, valid_length = beam_search_sampler(init_input, states, src_valid_length)
        sentences = []
        for j in range(samples.shape[0]):
            # Keep the first (index 0) hypothesis of each sample and strip its
            # first and last token before subword decoding.
            pred_tok_ids = samples[j, 0, :valid_length[j, 0].asnumpy()].asnumpy().tolist()
            bpe_decode_line = tgt_tokenizer.decode(pred_tok_ids[1:-1])
            sentences.append(base_tgt_tokenizer.decode(bpe_decode_line.split(' ')))
        return sentences

    pred_sentences = []
    start_eval_time = time.time()
    # evaluate
    if not args.inference:
        avg_nll_loss = 0
        ntokens = 0
        for src_token_ids, src_valid_length, tgt_token_ids, tgt_valid_length in test_dataloader:
            src_token_ids = mx.np.array(src_token_ids, ctx=ctx, dtype=np.int32)
            src_valid_length = mx.np.array(src_valid_length, ctx=ctx, dtype=np.int32)
            tgt_token_ids = mx.np.array(tgt_token_ids, ctx=ctx, dtype=np.int32)
            tgt_valid_length = mx.np.array(tgt_valid_length, ctx=ctx, dtype=np.int32)
            # Teacher-forced NLL: feed the target without its last token and
            # score the target shifted by one position.
            if model.layout == 'NT':
                tgt_pred = model(src_token_ids, src_valid_length, tgt_token_ids[:, :-1],
                                 tgt_valid_length - 1)
                pred_logits = mx.npx.log_softmax(tgt_pred, axis=-1)
                nll = - mx.npx.pick(pred_logits, tgt_token_ids[:, 1:])
                avg_nll_loss += mx.npx.sequence_mask(nll,
                                                     sequence_length=tgt_valid_length - 1,
                                                     use_sequence_length=True,
                                                     axis=1).sum().asnumpy()
            elif model.layout == 'TN':
                tgt_pred = model(src_token_ids.T, src_valid_length, tgt_token_ids.T[:-1, :],
                                 tgt_valid_length - 1)
                pred_logits = mx.npx.log_softmax(tgt_pred, axis=-1)
                nll = - mx.npx.pick(pred_logits, tgt_token_ids.T[1:, :])
                avg_nll_loss += mx.npx.sequence_mask(nll,
                                                     sequence_length=tgt_valid_length - 1,
                                                     use_sequence_length=True,
                                                     axis=0).sum().asnumpy()
            else:
                raise NotImplementedError
            ntokens += int((tgt_valid_length - 1).sum().asnumpy())
            for pred_sentence in _translate_batch(src_token_ids, src_valid_length):
                pred_sentences.append(pred_sentence)
                print(pred_sentence)
            print('Processed {}/{}'.format(len(pred_sentences), len(all_tgt_lines)))
        end_eval_time = time.time()
        avg_nll_loss = avg_nll_loss / ntokens

        with open(os.path.join(args.save_dir, 'gt_sentences.txt'), 'w', encoding='utf-8') as of:
            of.write('\n'.join(all_tgt_lines))
            of.write('\n')
        with open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
            of.write('\n'.join(pred_sentences))
            of.write('\n')

        sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences,
                                              ref_streams=[all_tgt_lines])
        logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} '
                     '({:2.1f} {:2.1f} {:2.1f} {:2.1f}) '
                     '(BP={:.3f}, ratio={:.3f}, syslen={}, reflen={}), '
                     'Avg NLL={}, Perplexity={}'
                     .format(end_eval_time - start_eval_time, len(all_tgt_lines),
                             sacrebleu_out.score,
                             *sacrebleu_out.precisions,
                             sacrebleu_out.bp, sacrebleu_out.sys_len / sacrebleu_out.ref_len,
                             sacrebleu_out.sys_len, sacrebleu_out.ref_len,
                             avg_nll_loss, np.exp(avg_nll_loss)))
        results = {'sacrebleu': sacrebleu_out.score,
                   'nll': avg_nll_loss}
        with open(os.path.join(args.save_dir, 'results.json'), 'w') as of:
            json.dump(results, of)
    # inference only
    else:
        processed_sentences = 0
        with open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
            for src_token_ids, src_valid_length, _, _ in tqdm(test_dataloader):
                src_token_ids = mx.np.array(src_token_ids, ctx=ctx, dtype=np.int32)
                src_valid_length = mx.np.array(src_valid_length, ctx=ctx, dtype=np.int32)
                # Write each batch as soon as it is decoded instead of
                # keeping every prediction in memory.
                batch_sentences = _translate_batch(src_token_ids, src_valid_length)
                of.write('\n'.join(batch_sentences))
                of.write('\n')
                processed_sentences += len(batch_sentences)
        end_eval_time = time.time()
        logging.info('Time Spent: {}, Inferred sentences: {}'
                     .format(end_eval_time - start_eval_time, processed_sentences))