def train(): """Training loop for language model. """ print(model) from_epoch = 0 model.initialize(mx.init.Xavier(factor_type='out'), ctx=context) trainer_params = {'learning_rate': args.lr, 'wd': 0, 'eps': args.eps} trainer = gluon.Trainer(model.collect_params(), 'adagrad', trainer_params) if args.from_epoch: from_epoch = args.from_epoch checkpoint_name = '%s.%s'%(args.save, format(from_epoch - 1, '02d')) model.load_parameters(checkpoint_name) trainer.load_states('%s.state'%args.save) print('Loaded parameters from checkpoint %s'%(checkpoint_name)) model.hybridize(static_alloc=True, static_shape=True) encoder_params = model.encoder.collect_params().values() embedding_params = list(model.embedding.collect_params().values()) parallel_model = ParallelBigRNN(model, loss) parallel = Parallel(len(context), parallel_model) for epoch in range(from_epoch, args.epochs): sys.stdout.flush() total_L = 0.0 start_epoch_time = time.time() start_log_interval_time = time.time() hiddens = [model.begin_state(batch_size=args.batch_size, func=mx.nd.zeros, ctx=ctx) for ctx in context] nbatch = 0 has_next = True train_data_iter = iter(train_data) data, target, mask, sample = next(train_data_iter) while has_next: nbatch += 1 hiddens = detach(hiddens) Ls = [] for _, batch in enumerate(zip(data, target, mask, sample, hiddens)): parallel.put(batch) for _ in range(len(data)): hidden, ls = parallel.get() # hidden states are ordered by context id index = context.index(hidden[0].context) hiddens[index] = hidden Ls.append(ls) # prefetch the next batch of data try: data, target, mask, sample = next(train_data_iter) except StopIteration: has_next = False # rescale embedding grad for ctx in context: x = embedding_params[0].grad(ctx) x[:] *= args.batch_size encoder_grad = [p.grad(ctx) for p in encoder_params] # perform gradient clipping per ctx gluon.utils.clip_global_norm(encoder_grad, args.clip) trainer.step(len(context)) total_L += sum([mx.nd.sum(L).asscalar() / args.bptt for L in Ls]) if nbatch % args.log_interval == 0: cur_L = total_L / args.log_interval / len(context) ppl = math.exp(cur_L) if cur_L < 100 else float('inf') print('[Epoch %d Batch %d] loss %.2f, ppl %.2f, ' 'throughput %.2f samples/s' %(epoch, nbatch, cur_L, ppl, train_batch_size*args.log_interval/(time.time()-start_log_interval_time))) total_L = 0.0 start_log_interval_time = time.time() sys.stdout.flush() end_epoch_time = time.time() print('Epoch %d took %.2f seconds.'%(epoch, end_epoch_time - start_epoch_time)) mx.nd.waitall() checkpoint_name = '%s.%s'%(args.save, format(epoch, '02d')) model.save_parameters(checkpoint_name) trainer.save_states('%s.state'%args.save)
def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx, store):
    """Training function."""
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    trainer = gluon.Trainer(model.collect_params(), 'bertadam', optim_params,
                            update_on_kvstore=False, kvstore=store)
    dynamic_loss_scale = args.dtype == 'float16'
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale)

    if args.ckpt_dir and args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states' % args.start_step)
        logging.info('Loading trainer state from %s', state_path)
        trainer.load_states(state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    for p in params:
        p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    local_mlm_loss = 0
    local_nsp_loss = 0
    local_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    parallel_model = ParallelBERT(model, mlm_loss, nsp_loss, vocab_size,
                                  store.num_workers * accumulate, trainer=fp16_trainer)
    num_ctxes = len(ctx)
    parallel = Parallel(num_ctxes, parallel_model)

    while step_num < num_train_steps:
        for dataloader in data_train:
            if step_num >= num_train_steps:
                break
            for data_batch in dataloader:
                if step_num >= num_train_steps:
                    break
                if batch_num % accumulate == 0:
                    step_num += 1
                    # zero grad
                    model.collect_params().zero_grad()
                    # update learning rate
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 12)

                if args.by_token:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip(ctx, data_batch)]
                else:
                    if data_batch[0].shape[0] < len(ctx):
                        continue
                    data_list = split_and_load(data_batch, ctx)

                ns_label_list, ns_pred_list = [], []
                mask_label_list, mask_pred_list, mask_weight_list = [], [], []

                # parallel forward / backward
                for data in data_list:
                    parallel.put(data)
                for _ in range(len(ctx)):
                    (_, next_sentence_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_length) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_mlm_loss += ls1.as_in_context(mx.cpu()) / num_ctxes
                    local_nsp_loss += ls2.as_in_context(mx.cpu()) / num_ctxes
                    local_num_tks += valid_length.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    fp16_trainer.step(1, max_norm=1)

                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

                # logging
                if (step_num + 1) % args.log_interval == 0 \
                        and (batch_num + 1) % accumulate == 0:
                    log(begin_time, local_num_tks, local_mlm_loss / accumulate,
                        local_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer)
                    begin_time = time.time()
                    local_mlm_loss = local_nsp_loss = local_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if args.ckpt_dir and (step_num + 1) % args.ckpt_interval == 0 \
                        and (batch_num + 1) % accumulate == 0:
                    save_params(step_num, args, model, trainer)

                batch_num += 1

    save_params(step_num, args, model, trainer)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
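
# The learning-rate update inside the BERT loop above is a linear warmup followed
# by a decay proportional to overall training progress. The same arithmetic,
# pulled out into a standalone sketch for clarity (the function name is
# illustrative and not part of this script):
def _lr_schedule_sketch(base_lr, step_num, num_warmup_steps, num_train_steps):
    """Return the learning rate for a given global step."""
    if step_num <= num_warmup_steps:
        # ramp linearly from 0 up to base_lr over the warmup steps
        return base_lr * step_num / num_warmup_steps
    # afterwards, subtract a term that grows with the fraction of total steps done
    return base_lr - base_lr * step_num / num_train_steps

# e.g. with base_lr=1e-4, 10k warmup steps, 1M total steps:
#   _lr_schedule_sketch(1e-4, 5_000, 10_000, 1_000_000)   -> 5e-5 (mid-warmup)
#   _lr_schedule_sketch(1e-4, 500_000, 10_000, 1_000_000) -> 5e-5 (halfway through decay)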