Пример #1
0
def test_average_sgd_tracker():
    """Check ``AverageSGDTracker`` against a manually computed running mean.

    The same two-layer network is trained twice on identical samples, once
    without and once with an ``AverageSGDTracker`` attached.  The test then
    asserts that:

    * attaching the tracker does not change the per-step parameter values;
    * the tracker's ``average_params`` match the running mean of the per-step
      parameters, recomputed here with NumPy;
    * after ``copy_back()`` the network parameters equal those averages.
    """
    samples = [mx.np.random.normal(0, 1, (10, 3)) for _ in range(10)]
    no_moving_avg_param_l = []
    with_moving_avg_param_l = []
    moving_avg_param = None
    net_final_moving_avg_param = None
    for use_moving_avg in [False, True]:
        # Both runs use the same initializer and the same samples so that
        # their per-step parameters can be compared one-to-one below.
        net = HybridSequential()
        net.add(nn.Dense(10), nn.Dense(3))
        net.initialize(init=mx.init.One())
        net.hybridize()
        trainer = mx.gluon.Trainer(net.collect_params(), 'adam')
        if use_moving_avg:
            model_averager = AverageSGDTracker(net.collect_params())
        for sample in samples:
            out = sample ** 3 + sample
            with mx.autograd.record():
                pred = net(sample)
                loss = ((out - pred) ** 2).mean()
            loss.backward()
            trainer.step(1.0)
            if use_moving_avg:
                model_averager.step()
            # Record the parameters after this update (and after the tracker
            # step, when one is attached).
            snapshot = {k: v.data().asnumpy() for k, v in net.collect_params().items()}
            if use_moving_avg:
                with_moving_avg_param_l.append(snapshot)
            else:
                no_moving_avg_param_l.append(snapshot)
        if use_moving_avg:
            model_averager.copy_back()
            moving_avg_param = {k: v.asnumpy() for k, v in model_averager.average_params.items()}
            net_final_moving_avg_param = {k: v.data().asnumpy() for k, v in net.collect_params().items()}
    # Match the recorded params
    calculated_moving_param = {k: np.zeros(v.shape) for k, v in no_moving_avg_param_l[0].items()}
    for step, (no_moving_avg_param, with_moving_avg_param) in enumerate(zip(no_moving_avg_param_l,
                                                                            with_moving_avg_param_l)):
        # With decay = 1 / (step + 1) this incremental update yields the plain
        # arithmetic mean of all snapshots seen so far.
        decay = 1.0 / (step + 1)
        assert len(no_moving_avg_param) == len(with_moving_avg_param)
        for k in with_moving_avg_param:
            assert_allclose(no_moving_avg_param[k], with_moving_avg_param[k])
            calculated_moving_param[k] += decay * (with_moving_avg_param[k] - calculated_moving_param[k])
    assert len(moving_avg_param) == len(net_final_moving_avg_param)
    for k in moving_avg_param:
        assert_allclose(moving_avg_param[k], calculated_moving_param[k], 1E-5, 1E-5)
        assert_allclose(moving_avg_param[k], net_final_moving_avg_param[k], 1E-5, 1E-5)
Пример #2
0
def train(args):
    """Train a ``TransformerModel`` on a parallel corpus as configured by ``args``.

    The function builds the tokenizers, datasets, model, loss, trainer and
    data loaders, then runs a training loop with gradient accumulation,
    periodic logging, checkpointing and per-epoch validation.

    Parameters
    ----------
    args
        Namespace-like object.  The attributes read here are:
        ``comm_backend``, ``gpus``, ``src_tokenizer``, ``tgt_tokenizer``,
        ``src_subword_model_path``, ``tgt_subword_model_path``,
        ``src_vocab_path``, ``tgt_vocab_path``, ``train_src_corpus``,
        ``train_tgt_corpus``, ``dev_src_corpus``, ``dev_tgt_corpus``,
        ``overwrite_cache``, ``cfg``, ``fp16``, ``magnitude``, ``save_dir``,
        ``label_smooth_alpha``, ``lr``, ``num_units``, ``warmup_steps``,
        ``warmup_init_lr``, ``sampler``, ``max_num_tokens``,
        ``max_num_sentences``, ``seed``, ``bucket_scheme``, ``batch_size``,
        ``num_buckets``, ``bucket_ratio``, ``val_batch_size``, ``epochs``,
        ``num_accumulated``, ``num_averages``, ``max_update``,
        ``save_interval_update`` and ``log_interval``.

    Raises
    ------
    NotImplementedError
        If ``args.fp16`` is set, if ``args.sampler`` or ``args.bucket_scheme``
        is not one of the handled values, or if ``FixedBucketSampler`` is
        requested together with the horovod backend.

    Notes
    -----
    Files written under ``args.save_dir``: ``config.yml``,
    ``update{N}.params`` (when ``args.max_update > 0``), ``epoch{N}.params``
    (when ``args.epochs > 0``) and ``average.params`` (when
    ``args.num_averages > 0``).
    """
    store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus, args.train_tgt_corpus, src_tokenizer,
        tgt_tokenizer, args.overwrite_cache)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer,
        args.overwrite_cache)
    # Each sample is (src_tokens, tgt_tokens, src_length, tgt_length, index).
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    data_val = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens,
                tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ])
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    if args.fp16:
        # float16 training is not supported yet; once it is, this branch is
        # where cfg.MODEL.dtype = 'float16' would be set.
        raise NotImplementedError
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    if local_rank == 0:
        logging.info(model)
    with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
        cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()
    # The summed loss is divided by this constant before backward() and the
    # same factor is removed again in trainer.step() and in the logged loss.
    rescale_loss = 100.0

    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    # Construct the trainer
    # TODO(sxjscience) Support AMP
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    # Pass the resolved base_lr (not the possibly-None args.lr) so that the
    # optimizer and the scheduler always agree on the base learning rate.
    trainer_settings = (model.collect_params(), 'adam', {
        'learning_rate': base_lr,
        'beta1': 0.9,
        'beta2': 0.98,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler
    })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(*trainer_settings)
    else:
        trainer = gluon.Trainer(*trainer_settings)
    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            seed=args.seed,
            num_parts=num_parts,
            part_index=rank)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')

        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            batch_size=args.batch_size,
            num_buckets=args.num_buckets,
            ratio=args.bucket_ratio,
            shuffle=True,
            use_average_length=True,
            bucket_scheme=bucket_scheme,
            seed=args.seed)
    else:
        raise NotImplementedError

    if local_rank == 0:
        logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)

    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    # Gradients are accumulated over several backward() calls and cleared
    # explicitly with zero_grad() after every parameter update.
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    model.zero_grad()
    model_averager = AverageSGDTracker(model.collect_params())
    log_start_time = time.time()
    num_params, num_fixed_params = None, None
    # TODO(sxjscience) Add a log metric class
    accum_count = 0
    loss_denom = 0
    n_train_iters = 0
    log_wc = 0
    log_avg_loss = 0.0
    log_loss_denom = 0
    epoch_id = 0
    # when args.epochs < 0, the model will keep training
    while args.epochs < 0 or epoch_id < args.epochs:
        n_epoch_train_iters = 0
        processed_batch_num = 0
        train_multi_data_loader = grouper(train_data_loader, len(ctx_l))
        is_last_batch = False
        # The next group is always fetched one step ahead so that the last
        # group of the epoch can be detected before the update is applied.
        sample_data_l = next(train_multi_data_loader)
        while not is_last_batch:
            processed_batch_num += len(sample_data_l)
            loss_l = []
            for sample_data, ctx in zip(sample_data_l, ctx_l):
                if sample_data is None:
                    continue
                src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data
                src_wc = src_valid_length.sum()
                tgt_wc = tgt_valid_length.sum()
                bs = src_token_ids.shape[0]
                # One target position per sentence is input-only (labels are
                # tgt_token_ids[:, 1:]), hence the "- bs".
                loss_denom += tgt_wc - bs
                log_loss_denom += tgt_wc - bs
                log_wc += src_wc + tgt_wc
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                with mx.autograd.record():
                    tgt_pred = model(src_token_ids, src_valid_length,
                                     tgt_token_ids[:, :-1],
                                     tgt_valid_length - 1)
                    tgt_labels = tgt_token_ids[:, 1:]
                    loss = label_smooth_loss(tgt_pred, tgt_labels)
                    loss = mx.npx.sequence_mask(
                        loss,
                        sequence_length=tgt_valid_length - 1,
                        use_sequence_length=True,
                        axis=1)
                    loss_l.append(loss.sum() / rescale_loss)
            for l in loss_l:
                l.backward()
            accum_count += 1
            try:
                sample_data_l = next(train_multi_data_loader)
            except StopIteration:
                is_last_batch = True
            if local_rank == 0 and num_params is None:
                num_params, num_fixed_params = count_parameters(
                    model.collect_params())
                logging.info(
                    'Total Number of Parameters (not-fixed/fixed): {}/{}'.
                    format(num_params, num_fixed_params))
            sum_loss = sum([l.as_in_ctx(mx.cpu())
                            for l in loss_l]) * rescale_loss
            log_avg_loss += sum_loss
            mx.npx.waitall()
            if accum_count == args.num_accumulated or is_last_batch:
                # Update the parameters
                n_train_iters += 1
                n_epoch_train_iters += 1
                # The step argument normalizes the accumulated gradient by the
                # number of predicted tokens (adjusted for rescale_loss).
                trainer.step(loss_denom.asnumpy() / rescale_loss)
                accum_count = 0
                loss_denom = 0
                model.zero_grad()
                # Only track the parameter average near the end of training.
                if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \
                   (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update):
                    model_averager.step()
                if local_rank == 0 and \
                   (n_epoch_train_iters % args.log_interval == 0 or is_last_batch):
                    log_end_time = time.time()
                    log_wc = log_wc.asnumpy()
                    wps = log_wc / (log_end_time - log_start_time)
                    log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
                    logging.info(
                        '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                        'throughput={:.2f}K wps, wc={:.2f}K, LR={}'.format(
                            epoch_id, processed_batch_num * num_parts,
                            len(train_data_loader), log_avg_loss,
                            np.exp(log_avg_loss), wps / 1000, log_wc / 1000,
                            trainer.learning_rate))
                    log_start_time = time.time()
                    log_avg_loss = 0
                    log_loss_denom = 0
                    log_wc = 0
                if local_rank == 0 and \
                   (args.max_update > 0 and n_train_iters % args.save_interval_update == 0):
                    model.save_parameters(os.path.join(
                        args.save_dir, 'update{:d}.params'.format(
                            n_train_iters // args.save_interval_update)),
                                          deduplicate=True)
                if args.max_update > 0 and n_train_iters >= args.max_update:
                    # Leaves the batch loop only; the epoch loop is left by
                    # the matching check after validation below.
                    break
        if local_rank == 0 and args.epochs > 0:
            model.save_parameters(os.path.join(
                args.save_dir, 'epoch{:d}.params'.format(epoch_id)),
                                  deduplicate=True)
        avg_valid_loss = validation(model, val_data_loader, ctx_l)
        logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'.format(
            epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))

        if args.max_update > 0 and n_train_iters >= args.max_update:
            break
        epoch_id += 1

    if args.num_averages > 0:
        model_averager.copy_back(
            model.collect_params())  # TODO(sxjscience) Rewrite using update
        # NOTE(review): unlike the other checkpoints this save is not guarded
        # by local_rank == 0 -- confirm that is intended for multi-process runs.
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)