Example #1
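# Excerpted from a larger training script: assumes module-level imports
# (sys, time, math, mxnet as mx, mxnet.gluon as gluon) and globals such as
# `args`, `model`, `loss`, `context`, `train_data`, `train_batch_size`,
# `ParallelBigRNN`, `Parallel`, and `detach` defined elsewhere.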
def train():
    """Training loop for language model.
    """
    print(model)
    from_epoch = 0
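    # Xavier initialization scaled by fan-out, optimized with Adagrad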
    model.initialize(mx.init.Xavier(factor_type='out'), ctx=context)
    trainer_params = {'learning_rate': args.lr, 'wd': 0, 'eps': args.eps}
    trainer = gluon.Trainer(model.collect_params(), 'adagrad', trainer_params)
    if args.from_epoch:
        from_epoch = args.from_epoch
        checkpoint_name = '%s.%s'%(args.save, format(from_epoch - 1, '02d'))
        model.load_parameters(checkpoint_name)
        trainer.load_states('%s.state'%args.save)
        print('Loaded parameters from checkpoint %s'%(checkpoint_name))

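    # hybridize with static memory allocation/shapes for faster repeated
    # execution, then wrap the model for data-parallel forward/backward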
    model.hybridize(static_alloc=True, static_shape=True)
    encoder_params = model.encoder.collect_params().values()
    embedding_params = list(model.embedding.collect_params().values())
    parallel_model = ParallelBigRNN(model, loss)
    parallel = Parallel(len(context), parallel_model)
    for epoch in range(from_epoch, args.epochs):
        sys.stdout.flush()
        total_L = 0.0
        start_epoch_time = time.time()
        start_log_interval_time = time.time()
        hiddens = [model.begin_state(batch_size=args.batch_size,
                                     func=mx.nd.zeros, ctx=ctx) for ctx in context]
        nbatch = 0
        has_next = True
        train_data_iter = iter(train_data)
        data, target, mask, sample = next(train_data_iter)

        while has_next:
            nbatch += 1
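            # detach hidden states from the previous batch's graph (truncated BPTT)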
            hiddens = detach(hiddens)
            Ls = []
            for batch in zip(data, target, mask, sample, hiddens):
                parallel.put(batch)

            for _ in range(len(data)):
                hidden, ls = parallel.get()
                # hidden states are ordered by context id
                index = context.index(hidden[0].context)
                hiddens[index] = hidden
                Ls.append(ls)

            # prefetch the next batch of data
            try:
                data, target, mask, sample = next(train_data_iter)
            except StopIteration:
                has_next = False

            # rescale embedding grad
            for ctx in context:
                x = embedding_params[0].grad(ctx)
                x[:] *= args.batch_size
                encoder_grad = [p.grad(ctx) for p in encoder_params]
                # perform gradient clipping per ctx
                gluon.utils.clip_global_norm(encoder_grad, args.clip)

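            # one optimizer update; gradients are rescaled by the number of contexts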
            trainer.step(len(context))

            total_L += sum([mx.nd.sum(L).asscalar() / args.bptt for L in Ls])

            if nbatch % args.log_interval == 0:
                cur_L = total_L / args.log_interval / len(context)
                ppl = math.exp(cur_L) if cur_L < 100 else float('inf')
                print('[Epoch %d Batch %d] loss %.2f, ppl %.2f, '
                      'throughput %.2f samples/s'
                      %(epoch, nbatch, cur_L, ppl,
                        train_batch_size*args.log_interval/(time.time()-start_log_interval_time)))
                total_L = 0.0
                start_log_interval_time = time.time()
                sys.stdout.flush()

        end_epoch_time = time.time()
        print('Epoch %d took %.2f seconds.'%(epoch, end_epoch_time - start_epoch_time))
        mx.nd.waitall()
        checkpoint_name = '%s.%s'%(args.save, format(epoch, '02d'))
        model.save_parameters(checkpoint_name)
        trainer.save_states('%s.state'%args.save)
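
The `detach()` helper used above is defined elsewhere in the script. A minimal
sketch of the usual implementation, assuming the hidden states are (possibly
nested) lists of NDArrays:

def detach(hidden):
    """Recursively detach hidden states from the previous computation graph
    so gradients do not flow across truncated-BPTT boundaries."""
    if isinstance(hidden, (tuple, list)):
        return [detach(h) for h in hidden]
    return hidden.detach()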

Example #2
def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx, store):
    """Training function."""
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    trainer = gluon.Trainer(model.collect_params(),
                            'bertadam',
                            optim_params,
                            update_on_kvstore=False,
                            kvstore=store)
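    # mixed-precision wrapper; dynamic loss scaling is only used for float16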
    dynamic_loss_scale = args.dtype == 'float16'
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale)

    if args.ckpt_dir and args.start_step:
        state_path = os.path.join(args.ckpt_dir,
                                  '%07d.states' % args.start_step)
        logging.info('Loading trainer state from %s', state_path)
        trainer.load_states(state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [
        p for p in model.collect_params().values() if p.grad_req != 'null'
    ]

    # Do not apply weight decay on LayerNorm and bias terms
    for v in model.collect_params('.*beta|.*gamma|.*bias').values():
        v.wd_mult = 0.0
    for p in params:
        p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    local_mlm_loss = 0
    local_nsp_loss = 0
    local_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    parallel_model = ParallelBERT(model,
                                  mlm_loss,
                                  nsp_loss,
                                  vocab_size,
                                  store.num_workers * accumulate,
                                  trainer=fp16_trainer)
    num_ctxes = len(ctx)
    parallel = Parallel(num_ctxes, parallel_model)

    while step_num < num_train_steps:
        for dataloader in data_train:
            if step_num >= num_train_steps:
                break
            for data_batch in dataloader:
                if step_num >= num_train_steps:
                    break
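                # a new optimizer step begins every `accumulate` batches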
                if batch_num % accumulate == 0:
                    step_num += 1
                    # zero grad
                    model.collect_params().zero_grad()
                    # update learning rate
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 12)
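                # place each shard of the batch on its corresponding device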
                if args.by_token:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip(ctx, data_batch)]
                else:
                    if data_batch[0].shape[0] < len(ctx):
                        continue
                    data_list = split_and_load(data_batch, ctx)

                ns_label_list, ns_pred_list = [], []
                mask_label_list, mask_pred_list, mask_weight_list = [], [], []

                # parallel forward / backward
                for data in data_list:
                    parallel.put(data)
                for _ in range(len(ctx)):
                    (_, next_sentence_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_length) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_mlm_loss += ls1.as_in_context(mx.cpu()) / num_ctxes
                    local_nsp_loss += ls2.as_in_context(mx.cpu()) / num_ctxes
                    local_num_tks += valid_length.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    fp16_trainer.step(1, max_norm=1)
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list,
                                  mask_weight_list)
                # logging
                if (step_num + 1) % args.log_interval == 0 and \
                        (batch_num + 1) % accumulate == 0:
                    log(begin_time, local_num_tks, local_mlm_loss / accumulate,
                        local_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer)
                    begin_time = time.time()
                    local_mlm_loss = local_nsp_loss = local_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if args.ckpt_dir and (step_num + 1) % args.ckpt_interval == 0 \
                   and (batch_num + 1) % accumulate == 0:
                    save_params(step_num, args, model, trainer)
                batch_num += 1
    save_params(step_num, args, model, trainer)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_begin_time))
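
The learning-rate update inside the loop is a linear warmup over num_warmup_steps
followed by a linear decay toward zero at num_train_steps. A standalone sketch of
that schedule (the function name is ours, not part of the script):

def warmup_linear_decay(lr, step_num, num_warmup_steps, num_train_steps):
    """Mirror of the schedule applied via trainer.set_learning_rate() above."""
    if step_num <= num_warmup_steps:
        return lr * step_num / num_warmup_steps
    return lr - lr * step_num / num_train_steps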