def evaluate(data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx, log_interval, dtype):
    """Evaluate the model on masked-LM and next-sentence-prediction data.

    Runs ``forward`` on every batch of every dataloader in ``data_eval``,
    logs running statistics every ``log_interval`` batches, and finally logs
    the average MLM / NSP loss and global accuracies over all batches.

    Parameters
    ----------
    data_eval : iterable
        Outer iterable of dataloaders; each yields batches that
        ``split_and_load`` shards across ``ctx``.
    model, nsp_loss, mlm_loss, vocab_size, dtype :
        Forwarded unchanged to ``forward``.
    ctx : list
        Devices each batch is split across.
    log_interval : int
        A running log line is emitted whenever ``(step_num + 1)`` is a
        multiple of this value.

    If ``data_eval`` yields no batches, a warning is logged and the function
    returns without reporting losses (there is nothing to average).
    """
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0
    running_mlm_loss = running_nsp_loss = 0
    total_mlm_loss = total_nsp_loss = 0
    running_num_tks = 0
    for _, dataloader in enumerate(data_eval):
        for _, data_batch in enumerate(dataloader):
            step_num += 1
            data_list = split_and_load(data_batch, ctx)
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for data in data_list:
                out = forward(data, model, mlm_loss, nsp_loss, vocab_size, dtype)
                # The first field (the combined loss) is not needed here.
                (_, next_sentence_label, classified, masked_id, decoded,
                 masked_weight, ls1, ls2, valid_length) = out
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)
                # Accumulate the running statistics on the CPU context.
                running_mlm_loss += ls1.as_in_context(mx.cpu())
                running_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_length.sum().as_in_context(mx.cpu())
            nsp_metric.update(ns_label_list, ns_pred_list)
            mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging: fold the running window into the totals, then reset it.
            if (step_num + 1) % (log_interval) == 0:
                total_mlm_loss += running_mlm_loss
                total_nsp_loss += running_nsp_loss
                log(begin_time, running_num_tks, running_mlm_loss, running_nsp_loss,
                    step_num, mlm_metric, nsp_metric, None, log_interval)
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0
                mlm_metric.reset_local()
                nsp_metric.reset_local()

    mx.nd.waitall()
    eval_end_time = time.time()

    if step_num == 0:
        # Nothing was evaluated; averaging below would divide by zero.
        logging.warning('Evaluation data is empty; no metrics to report.')
        return

    # Accumulate losses from the last, not-yet-logged window too. This is
    # unconditional: adding an empty window (0) is a no-op, and guarding on the
    # MLM remainder alone could drop a non-zero NSP remainder.
    total_mlm_loss += running_mlm_loss
    total_nsp_loss += running_nsp_loss
    total_mlm_loss /= step_num
    total_nsp_loss /= step_num
    logging.info(
        'mlm_loss={:.3f}\tmlm_acc={:.1f}\tnsp_loss={:.3f}\tnsp_acc={:.1f}\t'.format(
            total_mlm_loss.asscalar(), mlm_metric.get_global()[1] * 100,
            total_nsp_loss.asscalar(), nsp_metric.get_global()[1] * 100))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))
def train(data_train, dataset_eval, model, teacher_model, mlm_loss, teacher_ce_loss,
          teacher_mse_loss, vocab_size, ctx, teacher_ce_weight, distillation_temperature,
          mlm_weight, log_tb):
    """Train ``model`` with an MLM loss plus teacher distillation losses (Horovod).

    Configuration comes from the module-level ``args`` and from the globals
    ``num_workers``, ``local_rank``, ``rank``, ``is_master_node``, ``data_eval``
    and ``vocab``.

    Parameters
    ----------
    data_train : iterable
        Yields training batches; re-iterated until ``args.num_steps`` optimizer
        steps have been taken.
    dataset_eval :
        The passed-in value is never read; the name is rebound at each
        checkpoint to a dataset built from the global ``data_eval``.
    model :
        Student model; its parameters are broadcast from rank 0 and trained.
    teacher_model, mlm_loss, teacher_ce_loss, teacher_mse_loss, vocab_size,
    teacher_ce_weight, distillation_temperature, mlm_weight :
        Forwarded to ``forward`` (and to ``evaluate`` at checkpoints).
    ctx :
        The single device context this worker trains on.
    log_tb :
        Forwarded to ``log`` and ``evaluate`` (presumably a TensorBoard
        writer - TODO confirm).
    """
    params = model.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    mlm_metric = MaskedAccuracy()
    mlm_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers}
    else:
        loss_scale_param = None
    trainer = hvd.DistributedTrainer(params, 'bertadam', optim_params)

    if args.dtype == 'float16':
        fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                                   loss_scaler_params=loss_scale_param)

        def trainer_step():
            fp16_trainer.step(1, max_norm=1 * num_workers)
    else:
        def trainer_step():
            trainer.step(1)

    if args.start_step:
        out_dir = os.path.join(args.ckpt_dir, f"checkpoint_{args.start_step}")
        state_path = os.path.join(
            out_dir, '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']
    param_dict = model.collect_params()

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_teacher_ce_loss, running_teacher_mse_loss = 0, 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')

    # The bar counts optimizer steps (its total is num_train_steps), so it is
    # advanced once per step rather than once per micro-batch, and starts at
    # the resume step.
    pbar = tqdm(total=num_train_steps, initial=step_num, desc="Training:")

    while step_num < num_train_steps:
        for data_batch in data_train:
            sys.stdout.flush()
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                pbar.update(1)
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is
                # required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate: linear warmup, then linear decay
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))

            # load data
            if args.use_avg_len:
                data_list = [[[s.as_in_context(context) for s in seq] for seq in shard]
                             for context, shard in zip([ctx], data_batch)]
            else:
                data_list = list(split_and_load(data_batch, [ctx]))

            # forward
            with mx.autograd.record():
                (loss_val, ns_label, classified, masked_id, decoded, masked_weight,
                 mlm_loss_val, teacher_ce_loss_val, teacher_mse_loss_val,
                 valid_len) = forward(
                     data_list, model, mlm_loss, vocab_size, args.dtype,
                     mlm_weight=mlm_weight,
                     teacher_ce_loss=teacher_ce_loss,
                     teacher_mse_loss=teacher_mse_loss,
                     teacher_model=teacher_model,
                     teacher_ce_weight=teacher_ce_weight,
                     distillation_temperature=distillation_temperature)
                loss_val = loss_val / accumulate
                # backward
                if args.dtype == 'float16':
                    fp16_trainer.backward(loss_val)
                else:
                    loss_val.backward()

            running_mlm_loss += mlm_loss_val.as_in_context(mx.cpu())
            running_teacher_ce_loss += teacher_ce_loss_val.as_in_context(mx.cpu())
            running_teacher_mse_loss += teacher_mse_loss_val.as_in_context(mx.cpu())
            running_num_tks += valid_len.sum().as_in_context(mx.cpu())

            # update
            if (batch_num + 1) % accumulate == 0:
                # step() performs 3 things:
                # 1. allreduce gradients from all workers
                # 2. checking the global_norm of gradients and clip them if necessary
                # 3. averaging the gradients and apply updates
                trainer_step()

            mlm_metric.update([masked_id], [decoded], [masked_weight])

            # logging
            if step_num % args.log_interval == 0 and batch_num % accumulate == 0:
                log("train ", begin_time, running_num_tks,
                    running_mlm_loss / accumulate,
                    running_teacher_ce_loss / accumulate,
                    running_teacher_mse_loss / accumulate,
                    step_num, mlm_metric, trainer, args.log_interval,
                    model=model, log_tb=log_tb, is_master_node=is_master_node)
                begin_time = time.time()
                running_mlm_loss = running_teacher_ce_loss = 0
                running_teacher_mse_loss = running_num_tks = 0
                mlm_metric.reset_local()

            # saving checkpoints
            if step_num % args.ckpt_interval == 0 and batch_num % accumulate == 0:
                if is_master_node:
                    out_dir = os.path.join(args.ckpt_dir, f"checkpoint_{step_num}")
                    if not os.path.isdir(out_dir):
                        nlp.utils.mkdir(out_dir)
                    save_states(step_num, trainer, out_dir, local_rank)
                    if local_rank == 0:
                        save_parameters(step_num, model, out_dir)
                if data_eval:
                    dataset_eval = get_pretrain_data_npz(
                        data_eval, args.batch_size_eval, 1, False, False, 1)
                    evaluate(dataset_eval, model, mlm_loss, len(vocab), [ctx],
                             args.log_interval, args.dtype,
                             mlm_weight=mlm_weight,
                             teacher_ce_loss=teacher_ce_loss,
                             teacher_mse_loss=teacher_mse_loss,
                             teacher_model=teacher_model,
                             teacher_ce_weight=teacher_ce_weight,
                             distillation_temperature=distillation_temperature,
                             log_tb=log_tb)

            batch_num += 1
            del data_batch

    if is_master_node:
        out_dir = os.path.join(args.ckpt_dir, "checkpoint_last")
        if not os.path.isdir(out_dir):
            # Recursive, like the periodic checkpoints (os.mkdir fails when
            # args.ckpt_dir itself does not exist yet).
            nlp.utils.mkdir(out_dir)
        save_states(step_num, trainer, out_dir, local_rank)
        if local_rank == 0:
            # Keep parameters next to the trainer states, matching the
            # periodic checkpoint layout.
            save_parameters(step_num, model, out_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    pbar.close()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx, store):
    """Pre-train ``model`` on MLM + NSP across ``ctx`` using a KVStore trainer.

    Configuration comes from the module-level ``args``. Gradients are always
    accumulated (``grad_req='add'``) and zeroed at the start of each optimizer
    step; the update is applied every ``args.accumulate`` batches.

    Parameters
    ----------
    data_train : iterable
        Outer iterable of dataloaders; re-iterated until ``args.num_steps``
        optimizer steps have been taken.
    model, nsp_loss, mlm_loss, vocab_size :
        Wrapped together in ``ParallelBERT`` for forward/backward.
    ctx : list
        Devices; each batch is split into ``len(ctx)`` shards.
    store :
        KVStore passed to ``gluon.Trainer``; its ``num_workers`` is forwarded
        to ``ParallelBERT``.
    """
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    trainer = gluon.Trainer(model.collect_params(), 'bertadam', optim_params,
                            update_on_kvstore=False, kvstore=store)
    dynamic_loss_scale = args.dtype == 'float16'
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale)

    if args.ckpt_dir and args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states' % args.start_step)
        logging.info('Loading trainer state from %s', state_path)
        trainer.load_states(state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    for p in params:
        p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    local_mlm_loss = 0
    local_nsp_loss = 0
    local_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    parallel_model = ParallelBERT(model, mlm_loss, nsp_loss, vocab_size,
                                  store.num_workers * accumulate, trainer=fp16_trainer)
    num_ctxes = len(ctx)
    parallel = Parallel(num_ctxes, parallel_model)

    while step_num < num_train_steps:
        for _, dataloader in enumerate(data_train):
            if step_num >= num_train_steps:
                break
            for _, data_batch in enumerate(dataloader):
                if step_num >= num_train_steps:
                    break
                # Skip batches too small to give every device a shard. Checked
                # before the step bookkeeping below so a skipped batch does not
                # advance step_num (and with it the LR schedule).
                if not args.by_token and data_batch[0].shape[0] < len(ctx):
                    continue
                if batch_num % accumulate == 0:
                    step_num += 1
                    # zero grad
                    model.collect_params().zero_grad()
                    # update learning rate: linear warmup, then linear decay
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 12)
                if args.by_token:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip(ctx, data_batch)]
                else:
                    data_list = split_and_load(data_batch, ctx)

                ns_label_list, ns_pred_list = [], []
                mask_label_list, mask_pred_list, mask_weight_list = [], [], []

                # parallel forward / backward
                for data in data_list:
                    parallel.put(data)
                for _ in range(len(ctx)):
                    (_, next_sentence_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_length) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_mlm_loss += ls1.as_in_context(mx.cpu()) / num_ctxes
                    local_nsp_loss += ls2.as_in_context(mx.cpu()) / num_ctxes
                    local_num_tks += valid_length.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    fp16_trainer.step(1, max_norm=1)
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

                # logging
                if (step_num + 1) % (args.log_interval) == 0 \
                        and (batch_num + 1) % accumulate == 0:
                    log(begin_time, local_num_tks, local_mlm_loss / accumulate,
                        local_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer)
                    begin_time = time.time()
                    local_mlm_loss = local_nsp_loss = local_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if args.ckpt_dir and (step_num + 1) % (args.ckpt_interval) == 0 \
                        and (batch_num + 1) % accumulate == 0:
                    save_params(step_num, args, model, trainer)
                batch_num += 1

    # Guarded like the periodic checkpoints and the resume path: without a
    # checkpoint directory there is nowhere to save.
    if args.ckpt_dir:
        save_params(step_num, args, model, trainer)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
def evaluate(data_eval, model, mlm_loss, vocab_size, ctx, log_interval, dtype,
             mlm_weight=1.0, teacher_ce_loss=None, teacher_mse_loss=None,
             teacher_model=None, teacher_ce_weight=0.0, distillation_temperature=1.0,
             log_tb=None):
    """Evaluate the student model's MLM accuracy and distillation losses.

    Each item of ``data_eval`` is a batch whose shards are moved onto the
    matching device in ``ctx``. Running statistics are logged whenever
    ``(step_num + 1)`` is a multiple of ``log_interval``. At the end, the
    per-batch average MLM / teacher-CE / teacher-MSE losses and the global MLM
    accuracy are logged.

    Parameters
    ----------
    data_eval : iterable
        Yields one batch (a sequence of per-device shards) per item.
    model, mlm_loss, vocab_size, dtype, mlm_weight, teacher_ce_loss,
    teacher_mse_loss, teacher_model, teacher_ce_weight, distillation_temperature :
        Forwarded to ``forward`` (with ``is_eval=True``).
    ctx : list
        Devices, zipped with the shards of each batch.
    log_interval : int
        Logging period in batches.
    log_tb :
        Forwarded to ``log`` (presumably a TensorBoard writer - TODO confirm).

    If ``data_eval`` yields nothing, a warning is logged and the function
    returns without reporting losses.
    """
    logging.info('Running evaluation ... ')
    mlm_metric = MaskedAccuracy()
    mlm_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0
    running_mlm_loss = 0
    total_mlm_loss = 0
    running_teacher_ce_loss = running_teacher_mse_loss = 0
    total_teacher_ce_loss = total_teacher_mse_loss = 0
    running_num_tks = 0

    for _, batch in tqdm(enumerate(data_eval), desc="Evaluation"):
        step_num += 1
        data_list = [[seq.as_in_context(context) for seq in shard]
                     for context, shard in zip(ctx, batch)]
        mask_label_list, mask_pred_list, mask_weight_list = [], [], []
        for data in data_list:
            out = forward(data, model, mlm_loss, vocab_size, dtype, is_eval=True,
                          mlm_weight=mlm_weight,
                          teacher_ce_loss=teacher_ce_loss,
                          teacher_mse_loss=teacher_mse_loss,
                          teacher_model=teacher_model,
                          teacher_ce_weight=teacher_ce_weight,
                          distillation_temperature=distillation_temperature)
            # The combined loss, next-sentence label and classifier output
            # (first three fields) are not used during this evaluation.
            (_, _, _, masked_id, decoded, masked_weight, mlm_loss_val,
             teacher_ce_loss_val, teacher_mse_loss_val, valid_length) = out
            mask_label_list.append(masked_id)
            mask_pred_list.append(decoded)
            mask_weight_list.append(masked_weight)
            running_mlm_loss += mlm_loss_val.as_in_context(mx.cpu())
            running_num_tks += valid_length.sum().as_in_context(mx.cpu())
            running_teacher_ce_loss += teacher_ce_loss_val.as_in_context(mx.cpu())
            running_teacher_mse_loss += teacher_mse_loss_val.as_in_context(mx.cpu())
        mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

        # logging: fold the running window into the totals, then reset it.
        if (step_num + 1) % (log_interval) == 0:
            total_mlm_loss += running_mlm_loss
            total_teacher_ce_loss += running_teacher_ce_loss
            total_teacher_mse_loss += running_teacher_mse_loss
            log("eval ", begin_time, running_num_tks, running_mlm_loss,
                running_teacher_ce_loss, running_teacher_mse_loss, step_num,
                mlm_metric, None, log_interval, model=model, log_tb=log_tb)
            begin_time = time.time()
            running_mlm_loss = running_num_tks = 0
            running_teacher_ce_loss = running_teacher_mse_loss = 0
            mlm_metric.reset_local()

    mx.nd.waitall()
    eval_end_time = time.time()

    if step_num == 0:
        # Nothing was evaluated; averaging below would divide by zero.
        logging.warning('Evaluation data is empty; no metrics to report.')
        return

    # Accumulate losses from the last, not-yet-logged window too. This is
    # unconditional: adding an empty window (0) is a no-op, and guarding on the
    # MLM remainder alone could drop non-zero teacher-loss remainders.
    total_mlm_loss += running_mlm_loss
    total_teacher_ce_loss += running_teacher_ce_loss
    total_teacher_mse_loss += running_teacher_mse_loss
    total_mlm_loss /= step_num
    total_teacher_ce_loss /= step_num
    total_teacher_mse_loss /= step_num
    logging.info(
        'Eval mlm_loss={:.3f}\tmlm_acc={:.1f}\tteacher_ce={:.2e}\tteacher_mse={:.2e}'.format(
            total_mlm_loss.asscalar(), mlm_metric.get_global()[1] * 100,
            total_teacher_ce_loss.asscalar(), total_teacher_mse_loss.asscalar()))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))
def evaluate(data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Evaluate the model on masked-LM and next-sentence-prediction data.

    Runs ``forward`` on every batch of every dataloader in ``data_eval``,
    logs running statistics whenever ``(step_num + 1)`` is a multiple of the
    module-level ``args.log_interval``, and finally logs the per-batch average
    MLM / NSP loss and global accuracies.

    Parameters
    ----------
    data_eval : iterable
        Outer iterable of dataloaders; each yields batches that
        ``split_and_load`` shards across ``ctx``.
    model, nsp_loss, mlm_loss, vocab_size :
        Forwarded unchanged to ``forward``.
    ctx : list
        Devices each batch is split across.

    If ``data_eval`` yields no batches, a warning is logged and the function
    returns without reporting losses.
    """
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0

    # Total loss for the whole dataset
    total_mlm_loss = total_nsp_loss = 0

    # Running loss, reset when a log is emitted
    running_mlm_loss = running_nsp_loss = 0
    running_num_tks = 0
    for _, dataloader in enumerate(data_eval):
        for _, data_batch in enumerate(dataloader):
            step_num += 1

            data_list = split_and_load(data_batch, ctx)
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            # Run inference on the batch, collect the predictions and losses
            batch_mlm_loss = batch_nsp_loss = 0
            for shard in data_list:
                out = forward(shard, model, mlm_loss, nsp_loss, vocab_size)
                # The first field (the combined loss) is not needed here.
                (_, next_sentence_label, classified, masked_id, decoded,
                 masked_weight, ls1, ls2, valid_length) = out
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)
                batch_mlm_loss += ls1.as_in_context(mx.cpu())
                batch_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_length.sum().as_in_context(mx.cpu())

            running_mlm_loss += batch_mlm_loss
            running_nsp_loss += batch_nsp_loss
            total_mlm_loss += batch_mlm_loss
            total_nsp_loss += batch_nsp_loss

            nsp_metric.update(ns_label_list, ns_pred_list)
            mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # Log and reset running loss
            if (step_num + 1) % (args.log_interval) == 0:
                log(begin_time, running_num_tks, running_mlm_loss, running_nsp_loss,
                    step_num, mlm_metric, nsp_metric, None)
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0
                # EvalMetric exposes reset_local() (paired with get_global());
                # it clears only the running window, keeping global stats.
                mlm_metric.reset_local()
                nsp_metric.reset_local()

    mx.nd.waitall()
    eval_end_time = time.time()

    if step_num == 0:
        # Nothing was evaluated; averaging below would divide by zero.
        logging.warning('Evaluation data is empty; no metrics to report.')
        return

    total_mlm_loss /= step_num
    total_nsp_loss /= step_num
    logging.info(
        'mlm_loss={:.3f}\tmlm_acc={:.1f}\tnsp_loss={:.3f}\tnsp_acc={:.1f}\t'.format(
            total_mlm_loss.asscalar(), mlm_metric.get_global()[1] * 100,
            total_nsp_loss.asscalar(), nsp_metric.get_global()[1] * 100))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))