def forward_backward(self, x):
    """Run one forward pass under autograd recording and back-propagate its loss.

    The first value returned by the module-level ``forward`` helper (the total
    loss) is divided by ``self._rescale_factor`` before back-propagation.
    When ``args.dtype`` is ``'float16'`` the backward pass is delegated to
    ``self._trainer.backward`` (presumably an FP16Trainer-style wrapper that
    handles loss scaling -- confirm against the class constructor); otherwise
    ``backward()`` is called on the loss directly.

    Args:
        x: one batch of model inputs, forwarded unchanged to ``forward``.

    Returns:
        The nine values produced by ``forward``, in the same order, with the
        first one (the loss) replaced by its rescaled version.
    """
    with mx.autograd.record():
        (total_loss, nsp_label, nsp_pred, mlm_label, mlm_pred,
         mlm_weight, mlm_loss_value, nsp_loss_value, valid_length) = forward(
             x, self._model, self._mlm_loss, self._nsp_loss,
             self._vocab_size, args.dtype)
        # Rescale inside the recording scope so the division is part of the
        # autograd graph that backward() differentiates through.
        total_loss = total_loss / self._rescale_factor
        if args.dtype != 'float16':
            total_loss.backward()
        else:
            self._trainer.backward(total_loss)
    return (total_loss, nsp_label, nsp_pred, mlm_label, mlm_pred,
            mlm_weight, mlm_loss_value, nsp_loss_value, valid_length)
def train(data_train, data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Training function.

    Runs Horovod data-parallel pre-training with masked-LM and
    next-sentence-prediction losses until ``args.num_steps`` optimizer steps
    have been taken. Gradients can be accumulated over ``args.accumulate``
    batches per step. Checkpoints and optional evaluation run every
    ``args.ckpt_interval`` steps, and a final checkpoint is written at the end.

    Relies on module-level names defined elsewhere in this file (``args``,
    ``hvd``, ``nlp``, ``FP16Trainer``, ``num_workers``, ``rank``,
    ``local_rank``, ``is_master_node``, ``forward``, ``log``, ``profile``,
    ``split_and_load``, ``get_dummy_dataloader``, ``save_states``,
    ``save_parameters``, ``get_pretrain_data_npz``, ``evaluate``).

    Args:
        data_train: iterable of data loaders. Each loader yields training
            batches. It is iterated repeatedly until the step budget is used up.
        data_eval: evaluation data source passed to ``get_pretrain_data_npz``.
            A falsy value disables evaluation.
        model: Gluon block being trained.
        nsp_loss: next-sentence-prediction loss block.
        mlm_loss: masked-LM loss block.
        vocab_size: vocabulary size, passed to ``forward`` and ``evaluate``.
        ctx: the single device context this worker trains on.

    Raises:
        RuntimeError: if a full pass over ``data_train`` yields no batches
            while training steps remain. Without this check the loop would
            never terminate.
    """
    # Start every worker from rank 0's initial weights.
    hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers}
    else:
        loss_scale_param = None
    trainer = hvd.DistributedTrainer(model.collect_params(), 'bertadam', optim_params)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        # Resuming: each local rank restores its own optimizer-state file.
        state_path = os.path.join(args.ckpt_dir,
                                  '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    num_warmup_steps = int(num_train_steps * args.warmup_ratio)
    param_dict = model.collect_params()
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for v in model.collect_params('.*beta|.*gamma|.*bias').values():
        v.wd_mult = 0.0
    if accumulate > 1:
        # Sum gradients over the accumulation window instead of overwriting them.
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    while step_num < num_train_steps:
        pass_start_batch = batch_num
        for dataloader in data_train:
            if step_num >= num_train_steps:
                break

            # create dummy data loader if needed
            if args.dummy_data_len:
                target_shape = (args.batch_size, args.dummy_data_len)
                dataloader = get_dummy_dataloader(dataloader, target_shape)

            for data_batch in dataloader:
                if step_num >= num_train_steps:
                    break
                if batch_num % accumulate == 0:
                    # First batch of a new accumulation window, i.e. a new step.
                    step_num += 1
                    # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                    if accumulate > 1:
                        param_dict.zero_grad()
                    # Learning rate: linear warmup to `lr` over num_warmup_steps,
                    # then linear decay lr * (1 - step / num_train_steps).
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 14, profile_name=args.profile + str(rank))

                # load data
                if args.use_avg_len:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip([ctx], data_batch)]
                else:
                    data_list = list(split_and_load(data_batch, [ctx]))
                data = data_list[0]

                # forward
                with mx.autograd.record():
                    (ls, ns_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_len) = forward(
                         data, model, mlm_loss, nsp_loss, vocab_size, args.dtype)
                    # Average over the window so the summed gradients match one
                    # batch of accumulate times the size.
                    ls = ls / accumulate
                    # backward
                    if args.dtype == 'float16':
                        fp16_trainer.backward(ls)
                    else:
                        ls.backward()

                running_mlm_loss += ls1.as_in_context(mx.cpu())
                running_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_len.sum().as_in_context(mx.cpu())

                # True on the last batch of an accumulation window.
                is_window_end = (batch_num + 1) % accumulate == 0

                # update
                if is_window_end:
                    # step() performs 3 things:
                    # 1. allreduce gradients from all workers
                    # 2. checking the global_norm of gradients and clip them if necessary
                    # 3. averaging the gradients and apply updates
                    # Clipping happens before averaging (per the list above), so the
                    # norm threshold is scaled by the number of workers.
                    fp16_trainer.step(1, max_norm=1 * num_workers)

                nsp_metric.update([ns_label], [classified])
                mlm_metric.update([masked_id], [decoded], [masked_weight])

                # logging
                if (step_num + 1) % args.log_interval == 0 and is_window_end:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    begin_time = time.time()
                    running_mlm_loss = running_nsp_loss = running_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if (step_num + 1) % args.ckpt_interval == 0 and is_window_end:
                    if is_master_node:
                        save_states(step_num, trainer, args.ckpt_dir, local_rank)
                        if local_rank == 0:
                            save_parameters(step_num, model, args.ckpt_dir)
                    if data_eval:
                        # eval data is always based on a fixed npz file.
                        dataset_eval = get_pretrain_data_npz(data_eval, args.batch_size_eval,
                                                             1, False, False, 1)
                        # Use the vocab_size argument (as forward() does) rather
                        # than an undeclared `vocab` global.
                        evaluate(dataset_eval, model, nsp_loss, mlm_loss, vocab_size, [ctx],
                                 args.log_interval, args.dtype)

                batch_num += 1

        if batch_num == pass_start_batch:
            # No batch was consumed in a full pass while steps remain: the outer
            # while-loop would otherwise spin forever.
            raise RuntimeError(
                'data_train yielded no batches at step %d of %d; it may be empty '
                'or a one-shot iterator that is already exhausted'
                % (step_num, num_train_steps))

    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost=%.1fs', train_end_time - train_begin_time)