def setup_logging(args, local_rank):
    """Setup logging configuration as well as random seed."""
    logging_config(args.output_dir,
                   name='finetune_squad{}'.format(args.version),  # avoid race
                   overwrite_handler=True,
                   console=(local_rank == 0))
    logging.info(args)
    set_seed(args.seed)
    logging.debug('Random seed set to {}'.format(args.seed))
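
# Hedged usage sketch: the log file is named per SQuAD version and only local
# rank 0 echoes to the console, so concurrent workers do not clobber each
# other's output. `parse_args` is assumed to be the script's own parser.
#
#     args = parse_args()
#     setup_logging(args, local_rank)
#     logging.info('written to the log file; on console only when local_rank == 0')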


def train(args):
    _, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    level = logging.DEBUG if args.verbose else logging.INFO
    logging_config(args.ckpt_dir,
                   name='pretrain_bert_' + str(rank),  # avoid race
                   level=level,
                   console=(local_rank == 0))
    logging.info(args)
    logging.debug('Random seed set to {}'.format(args.seed))
    set_seed(args.seed)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(
                     args.num_buckets, num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args.ckpt_dir,
                          'Loading', ctx_l)
    else:
        model.initialize(ctx=ctx_l)
    model.hybridize()

    if args.raw:
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            masked_lm_prob=args.masked_lm_prob,
            max_predictions_per_seq=args.max_predictions_per_seq,
            whole_word_mask=args.whole_word_mask,
            random_next_sentence=args.random_next_sentence,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            dataset_cached=args.dataset_cached,
            num_max_dataset_cached=args.num_max_dataset_cached)
    else:
        get_dataset_fn = get_pretrain_data_npz

    data_train = get_dataset_fn(args.data, args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)

    param_dict = model.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info('Using gradient accumulation. Effective global batch size = {}'
                     .format(num_accumulated * args.batch_size * len(ctx_l) * num_workers))
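        # With grad_req='add' (set just below), MXNet accumulates gradients
        # in-place across the num_accumulated backward passes instead of
        # overwriting them, so a single optimizer step covers
        # num_accumulated * batch_size * len(ctx_l) samples on each worker.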
        for p in params:
            p.grad_req = 'add'

    num_steps = args.num_steps
    warmup_steps = int(num_steps * args.warmup_ratio)
    log_interval = args.log_interval
    save_interval = args.ckpt_interval
    logging.info('#Total Training Steps={}, Warmup Steps={}, Save Interval={}'
                 .format(num_steps, warmup_steps, save_interval))

    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd}
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    elif args.comm_backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        states_option(args.start_step, trainer, args.ckpt_dir, local_rank,
                      'Loading')

    # backend specific implementation
    if args.comm_backend == 'byteps':
        trainer._init_params()
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # prepare the loss function
    nsp_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    nsp_loss_fn.hybridize()
    mlm_loss_fn.hybridize()

    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps
    running_mlm_loss, running_nsp_loss = 0., 0.
    running_num_tks = 0
    train_start_time = time.time()
    tic = time.time()

    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_steps:
        step_num += 1
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            mlm_loss_l = []
            nsp_loss_l = []
            loss_l = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for sample, ctx in zip(sample_l, ctx_l):
                # prepare data
                (input_id, masked_id, masked_position, masked_weight,
                 next_sentence_label, segment_id, valid_length) = sample
                input_id = input_id.as_in_ctx(ctx)
                masked_id = masked_id.as_in_ctx(ctx)
                masked_position = masked_position.as_in_ctx(ctx)
                masked_weight = masked_weight.as_in_ctx(ctx)
                next_sentence_label = next_sentence_label.as_in_ctx(ctx)
                segment_id = segment_id.as_in_ctx(ctx)
                valid_length = valid_length.as_in_ctx(ctx)
                with mx.autograd.record():
                    _, _, nsp_score, mlm_scores = model(
                        input_id, segment_id, valid_length, masked_position)
                    denominator = (masked_weight.sum() + 1e-8) * num_accumulated * len(ctx_l)
                    mlm_scores_r = mx.npx.reshape(mlm_scores, (-5, -1))
                    masked_id_r = masked_id.reshape((-1,))
                    mlm_loss = mlm_loss_fn(
                        mlm_scores_r, masked_id_r,
                        masked_weight.reshape((-1, 1))).sum() / denominator
                    denominator = num_accumulated * len(ctx_l)
                    nsp_loss = nsp_loss_fn(
                        nsp_score, next_sentence_label).mean() / denominator
                    mlm_loss_l.append(mlm_loss)
                    nsp_loss_l.append(nsp_loss)
                    loss_l.append(mlm_loss + nsp_loss)
                    mask_label_list.append(masked_id_r)
                    mask_pred_list.append(mlm_scores_r)
                    mask_weight_list.append(masked_weight.reshape((-1,)))
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(nsp_score)
                running_num_tks += valid_length.sum().as_in_ctx(mx.cpu())
            for loss in loss_l:
                loss.backward()
            running_mlm_loss += sum([ele.as_in_ctx(mx.cpu())
                                     for ele in mlm_loss_l]).asnumpy().item()
            running_nsp_loss += sum([ele.as_in_ctx(mx.cpu())
                                     for ele in nsp_loss_l]).asnumpy().item()
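            # Copying the per-device losses to CPU and calling .asnumpy()
            # synchronizes with MXNet's asynchronous engine, so the running
            # sums only include backward passes that have completed.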
            mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)
            nsp_metric.update(ns_label_list, ns_pred_list)

        # update
        trainer.allreduce_grads()
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)
        total_norm = total_norm / num_workers
        # update learning rate
        scheduled_lr = args.lr
        if step_num <= warmup_steps:
            scheduled_lr *= step_num / warmup_steps
        else:
            offset = (num_steps - step_num) / (num_steps - warmup_steps)
            scheduled_lr *= max(offset, 0)
        trainer.set_learning_rate(scheduled_lr)
        if args.comm_backend == 'horovod' or args.comm_backend == 'byteps':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers.
            # *num_workers* of Horovod is the number of GPUs.
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1.
            # *num_workers* of Trainer is the number of machines.
            trainer.update(num_workers, ignore_stale_grad=True)
        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()

        # saving
        if step_num % save_interval == 0 or step_num >= num_steps:
            states_option(step_num, trainer, args.ckpt_dir, local_rank,
                          'Saving')
            if local_rank == 0:
                parameters_option(step_num, model, args.ckpt_dir, 'Saving')

        # logging
        if step_num % log_interval == 0:
            running_mlm_loss /= log_interval
            running_nsp_loss /= log_interval
            toc = time.time()
            logging.info(
                '[step {}], Loss mlm/nsp={:.5f}/{:.3f}, Acc mlm/nsp={:.3f}/{:.3f}, '
                'LR={:.7f}, grad_norm={:.4f}. Time cost={:.2f} s, '
                'Throughput={:.1f}K tks/s, ETA={:.2f} h'.format(
                    step_num, running_mlm_loss, running_nsp_loss,
                    mlm_metric.get()[1], nsp_metric.get()[1],
                    trainer.learning_rate, total_norm, toc - tic,
                    running_num_tks.asnumpy().item() / (toc - tic) / 1000,
                    (num_steps - step_num) / (step_num / (toc - train_start_time)) / 3600))
            mlm_metric.reset()
            nsp_metric.reset()
            tic = time.time()
            running_mlm_loss = 0
            running_nsp_loss = 0
            running_num_tks = 0

    logging.info('Finish training step: %d', step_num)
    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time - train_start_time))

    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.ckpt_dir, model_name)
        final_save(model, save_dir, tokenizer, cfg)
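

# A minimal standalone sketch of the warmup-then-linear-decay schedule computed
# inline in the training loop above; the helper name is illustrative, not part
# of the script.
def linear_warmup_decay_lr(base_lr, step_num, warmup_steps, num_steps):
    """Learning rate at `step_num`: linear ramp to base_lr, then linear decay to 0."""
    if step_num <= warmup_steps:
        return base_lr * step_num / warmup_steps
    offset = (num_steps - step_num) / (num_steps - warmup_steps)
    return base_lr * max(offset, 0)

# e.g. with base_lr=1e-4, warmup_steps=100, num_steps=1000:
#   step 50 -> 5e-05 (warmup), step 100 -> 1e-04 (peak), step 550 -> 5e-05 (decay)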


def evaluation(writer, step_num, masked_input, eval_input):
    """Log MLM and replaced-token-detection (RTD) metrics to tensorboardX."""
    length_masks = masked_input.masks
    unmasked_tokens = masked_input.unmasked_tokens
    masked_weights = masked_input.masked_weights
    mlm_scores = eval_input.mlm_scores
    rtd_scores = eval_input.rtd_scores
    rtd_labels = eval_input.rtd_labels
    corrupted_tokens = eval_input.corrupted_tokens

    mlm_preds = mx.np.argmax(mlm_scores, axis=-1).astype(np.int32)
    rtd_probs = mx.npx.sigmoid(rtd_scores)
    rtd_preds = mx.np.round((mx.np.sign(rtd_scores) + 1) / 2).astype(np.int32)
    mlm_accuracy = accuracy(unmasked_tokens, mlm_preds, masked_weights)
    corrupted_mlm_accuracy = accuracy(unmasked_tokens, corrupted_tokens,
                                      masked_weights)
    rtd_accuracy = accuracy(rtd_labels, rtd_preds, length_masks)
    # precision restricts the mask to predicted positives, recall to actual ones
    rtd_precision = accuracy(rtd_labels, rtd_preds, length_masks * rtd_preds)
    rtd_recall = accuracy(rtd_labels, rtd_preds, length_masks * rtd_labels)
    rtd_auc = auc(rtd_labels, rtd_probs, length_masks)
    writer.add_scalars(
        'results', {
            'mlm_accuracy': mlm_accuracy.asnumpy().item(),
            'corrupted_mlm_accuracy': corrupted_mlm_accuracy.asnumpy().item(),
            'rtd_accuracy': rtd_accuracy.asnumpy().item(),
            'rtd_precision': rtd_precision.asnumpy().item(),
            'rtd_recall': rtd_recall.asnumpy().item(),
            'rtd_auc': rtd_auc
        }, step_num)


if __name__ == '__main__':
    os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
    os.environ['MXNET_USE_FUSION'] = '0'  # Manually disable pointwise fusion
    args = parse_args()
    logging_config(args.output_dir, name='pretrain_owt')
    logging.info(args)
    set_seed(args.seed)
    logging.debug('Random seed set to {}'.format(args.seed))
    if args.do_train:
        train(args)
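

# A minimal sketch of a masked accuracy in the spirit of the `accuracy` helper
# used above (an illustration; the script's own helper may differ): positions
# with a zero mask are excluded, which is why RTD precision masks by the
# predictions and RTD recall masks by the labels.
def masked_accuracy_sketch(labels, preds, masks):
    """Fraction of positions where labels == preds, among masked-in positions."""
    masks = masks.astype('float32')
    correct = (labels == preds).astype('float32') * masks
    return correct.sum() / mx.np.maximum(masks.sum(), 1)

# e.g. rtd_precision corresponds to
#     masked_accuracy_sketch(rtd_labels, rtd_preds, length_masks * rtd_preds)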


def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    logging_config(args.output_dir,
                   name='pretrain_owt_' + str(rank),  # avoid race
                   console=(local_rank == 0))
    logging.info(args)
    logging.debug('Random seed set to {}'.format(args.seed))
    set_seed(args.seed)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(
                     args.num_buckets, num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l,
                                                  args.max_seq_length,
                                                  args.hidden_dropout_prob,
                                                  args.attention_dropout_prob,
                                                  args.generator_units_scale,
                                                  args.generator_layers_scale)
    data_masker = ElectraMasker(tokenizer, args.max_seq_length,
                                mask_prob=args.mask_prob,
                                replace_prob=args.replace_prob)
    if args.from_raw_text:
        if args.cached_file_path and not os.path.exists(args.cached_file_path):
            os.mkdir(args.cached_file_path)
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            cached_file_path=args.cached_file_path)
        logging.info('Processing and loading the training dataset from raw text.')
    else:
        logging.info('Loading the training dataset from local Numpy file.')
        get_dataset_fn = get_pretrain_data_npz

    data_train = get_dataset_fn(args.data, args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)
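    # Each worker loads a disjoint shard of the corpus (num_parts=num_workers,
    # part_idx=rank), so ranks never iterate over the same samples.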

    logging.info('Creating distributed trainer...')
    param_dict = model.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info('Using gradient accumulation. Effective global batch size = {}'
                     .format(num_accumulated * args.batch_size * len(ctx_l) * num_workers))
        for p in params:
            p.grad_req = 'add'

    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    num_train_steps = args.num_train_steps
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    assert warmup_steps is not None, 'Must specify either warmup_steps or warmup_ratio'
    log_interval = args.log_interval
    save_interval = args.save_interval if args.save_interval is not None \
        else num_train_steps // 50
    logging.info('#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
        num_train_steps, warmup_steps, save_interval))

    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        # TODO(zheyuye), How about data splitting, where to start re-training
        state_path = states_option(args.start_step, trainer, args.output_dir,
                                   local_rank, 'Loading')
        param_path = parameters_option(args.start_step, model, args.output_dir,
                                       'Loading')

    # prepare the loss function
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    rtd_loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    mlm_loss_fn.hybridize()
    rtd_loss_fn.hybridize()

    # prepare the records writer
    writer = None
    # only one process on each worker will write the tensorboardX's records to avoid race
    if args.do_eval and local_rank == 0:
        from tensorboardX import SummaryWriter
        record_path = os.path.join(args.output_dir, 'records')
        logging.info('Evaluation records saved in {}'.format(record_path))
        writer = SummaryWriter(record_path)

    step_num = args.start_step
    finish_flag = False
    log_total_loss = 0
    log_mlm_loss = 0
    log_rtd_loss = 0
    log_sample_num = 0
    train_start_time = time.time()

    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_train_steps:
        tic = time.time()
        for accum_idx in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            mlm_loss_l = []
            rtd_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                # prepare data
                input_ids, segment_ids, valid_lengths = sample
                input_ids = input_ids.as_in_ctx(ctx)
                segment_ids = segment_ids.as_in_ctx(ctx)
                valid_lengths = valid_lengths.as_in_ctx(ctx)
                masked_input = data_masker.dynamic_masking(mx.nd, input_ids,
                                                           valid_lengths)
                masked_input_ids = masked_input.input_ids
                length_masks = masked_input.masks
                unmasked_tokens = masked_input.unmasked_tokens
                masked_positions = masked_input.masked_positions
                masked_weights = masked_input.masked_weights
                log_sample_num += len(masked_input_ids)
                with mx.autograd.record():
                    mlm_scores, rtd_scores, corrupted_tokens, labels = model(
                        masked_input_ids, segment_ids, valid_lengths,
                        unmasked_tokens, masked_positions)
                    denominator = (masked_weights.sum() + 1e-6) * num_accumulated * len(ctx_l)
                    mlm_loss = mlm_loss_fn(
                        mx.npx.reshape(mlm_scores, (-5, -1)),
                        unmasked_tokens.reshape((-1,)),
                        masked_weights.reshape((-1, 1))).sum() / denominator
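                    # The (-5, -1) spec of npx.reshape merges the batch and
                    # num_masked axes, giving (batch * num_masked, vocab_size)
                    # logits, and the masked-weight denominator turns the sum
                    # into a per-masked-token average; the RTD loss below is
                    # instead averaged over all valid (unpadded) positions.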
                    denominator = (length_masks.sum() + 1e-6) * num_accumulated * len(ctx_l)
                    rtd_loss = rtd_loss_fn(
                        rtd_scores, labels, length_masks).sum() / denominator
                    output = ElectraOutput(mlm_scores=mlm_scores,
                                           rtd_scores=rtd_scores,
                                           rtd_labels=labels,
                                           corrupted_tokens=corrupted_tokens)
                    mlm_loss_l.append(mlm_loss)
                    rtd_loss_l.append(rtd_loss)
                    loss = (args.gen_weight * mlm_loss
                            + args.disc_weight * rtd_loss)
                    loss_l.append(loss)

            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_mlm_loss += sum([ele.as_in_ctx(ctx_l[0])
                                 for ele in mlm_loss_l]).asnumpy()
            log_rtd_loss += sum([ele.as_in_ctx(ctx_l[0])
                                 for ele in rtd_loss_l]).asnumpy()
            log_total_loss += sum([ele.as_in_ctx(ctx_l[0])
                                   for ele in loss_l]).asnumpy()

        # update
        trainer.allreduce_grads()
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)
        if args.comm_backend == 'horovod':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1
            trainer.update(num_workers, ignore_stale_grad=True)
        total_norm = total_norm / num_workers
        step_num += 1
        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()

        # saving
        if step_num % save_interval == 0 or step_num >= num_train_steps:
            if is_master_node:
                states_option(step_num, trainer, args.output_dir, local_rank,
                              'Saving')
                if local_rank == 0:
                    param_path = parameters_option(step_num, model,
                                                   args.output_dir, 'Saving')

        # logging
        if step_num % log_interval == 0:
            # Output the loss of per step
            log_mlm_loss /= log_interval
            log_rtd_loss /= log_interval
            log_total_loss /= log_interval
            toc = time.time()
            logging.info('[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},'
                         ' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},'
                         ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
                             step_num, log_mlm_loss, log_rtd_loss,
                             log_total_loss, trainer.learning_rate, total_norm,
                             toc - tic, log_sample_num / (toc - tic),
                             (num_train_steps - step_num)
                             / (step_num / (toc - train_start_time)) / 3600))
            tic = time.time()
            if args.do_eval:
                evaluation(writer, step_num, masked_input, output)
                if writer is not None:
                    writer.add_scalars('loss',
                                       {'total_loss': log_total_loss,
                                        'mlm_loss': log_mlm_loss,
                                        'rtd_loss': log_rtd_loss},
                                       step_num)
            log_mlm_loss = 0
            log_rtd_loss = 0
            log_total_loss = 0
            log_sample_num = 0

    logging.info('Finish training step: %d', step_num)
    if is_master_node:
        state_path = states_option(step_num, trainer, args.output_dir,
                                   local_rank, 'Saving')
        if local_rank == 0:
            param_path = parameters_option(step_num, model, args.output_dir,
                                           'Saving')

    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_start_time))

    if writer is not None:
        writer.close()

    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.output_dir, model_name)
        final_save(model, save_dir, tokenizer)
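

# A minimal standalone sketch of the joint objective combined inside the loop
# above; the helper name and defaults are illustrative, not part of the script.
# Following the ELECTRA paper, the discriminator's replaced-token-detection
# loss is weighted far more heavily (lambda = 50) than the generator's MLM loss.
def electra_joint_loss(mlm_loss, rtd_loss, gen_weight=1.0, disc_weight=50.0):
    """Mirrors `args.gen_weight * mlm_loss + args.disc_weight * rtd_loss`."""
    return gen_weight * mlm_loss + disc_weight * rtd_loss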