def get_lr_scheduler(args, train_loader):
    """Construct the LR scheduler selected by ``args.optim_phase``.

    Parameters
    ----------
    args
        Parsed arguments; reads ``optim_phase`` and, depending on the phase,
        ``every_lr_decay_step``, ``lr_decay_epochs`` or ``epochs``.
    train_loader
        Training dataloader; its length converts decay epochs into update steps
        for the 'MultiFactor' phase.

    Returns
    -------
    The scheduler instance for the requested phase.

    Raises
    ------
    ValueError
        If ``args.optim_phase`` is not one of the supported phases.
    """
    phase = args.optim_phase
    if phase == 'Factor':
        # Decay by 10x every fixed number of steps.
        return FactorScheduler(step=args.every_lr_decay_step, factor=0.1)
    if phase == 'MultiFactor':
        # Translate epoch milestones into absolute update-step milestones.
        decay_steps = [len(train_loader) * epoch for epoch in args.lr_decay_epochs]
        return MultiFactorScheduler(step=decay_steps, factor=0.1)
    if phase == 'Poly':
        return PolyScheduler(max_update=args.epochs)
    if phase == 'Cosine':
        return CosineScheduler(max_update=args.epochs)
    raise ValueError('Invalid phase {}'.format(phase))
def get_optimizer(cfg, updates_per_epoch):
    """Build the optimizer name, its parameter dict and the total update count.

    Parameters
    ----------
    cfg
        Config with ``num_train_epochs``, ``warmup_portion``, ``lr_scheduler``
        ('triangular' | 'inv_sqrt' | 'constant' | 'cosine'), ``lr``, ``wd``,
        ``optimizer`` and ``optimizer_params`` (iterable of key/value pairs).
        Scheduler-specific fields (``begin_lr``, ``final_lr``) are only read
        when the chosen scheduler needs them.
    updates_per_epoch
        Number of optimizer updates per training epoch.

    Returns
    -------
    (optimizer, optimizer_params, max_update)

    Raises
    ------
    ValueError
        If ``cfg.lr_scheduler`` is unsupported, or the warmup length is not
        shorter than the total number of updates.
    """
    # Computed once here; the original recomputed these identically inside
    # the 'inv_sqrt' and 'cosine' branches.
    max_update = int(updates_per_epoch * cfg.num_train_epochs)
    warmup_steps = int(updates_per_epoch * cfg.num_train_epochs *
                       cfg.warmup_portion)
    if cfg.lr_scheduler == 'triangular':
        # Explicit raise instead of `assert` (asserts vanish under `python -O`).
        if warmup_steps >= max_update:
            raise ValueError('warmup_steps must be < max_update, got {} >= {}'
                             .format(warmup_steps, max_update))
        lr_scheduler = PolyScheduler(max_update=max_update,
                                     base_lr=cfg.lr,
                                     warmup_begin_lr=cfg.begin_lr,
                                     pwr=1,
                                     final_lr=cfg.final_lr,
                                     warmup_steps=warmup_steps,
                                     warmup_mode='linear')
    elif cfg.lr_scheduler == 'inv_sqrt':
        lr_scheduler = InverseSquareRootScheduler(warmup_steps=warmup_steps,
                                                  base_lr=cfg.lr,
                                                  warmup_init_lr=cfg.begin_lr)
    elif cfg.lr_scheduler == 'constant':
        # Constant LR: no scheduler object at all.
        lr_scheduler = None
    elif cfg.lr_scheduler == 'cosine':
        if warmup_steps >= max_update:
            raise ValueError('warmup_steps must be < max_update, got {} >= {}'
                             .format(warmup_steps, max_update))
        lr_scheduler = CosineScheduler(max_update=max_update,
                                       base_lr=cfg.lr,
                                       final_lr=cfg.final_lr,
                                       warmup_steps=warmup_steps,
                                       warmup_begin_lr=cfg.begin_lr)
    else:
        raise ValueError('Unsupported lr_scheduler="{}"'.format(
            cfg.lr_scheduler))
    optimizer_params = {
        'learning_rate': cfg.lr,
        'wd': cfg.wd,
        'lr_scheduler': lr_scheduler
    }
    optimizer = cfg.optimizer
    # cfg.optimizer_params is an iterable of (key, value) pairs; dict() replaces
    # the hand-written pairwise comprehension.
    optimizer_params.update(dict(cfg.optimizer_params))
    return optimizer, optimizer_params, max_update
def train(args):
    """Finetune a QA network on SQuAD with optional Horovod distribution.

    Builds the dataset/dataloader, sets up a PolyScheduler + optimizer,
    runs the training loop with gradient accumulation, periodically saves
    checkpoints (pruning old ones) and logs loss/throughput.

    Returns the path of the last saved parameter file.
    """
    # Initialize the communication backend (returns per-process topology and
    # the list of device contexts for this process).
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    setup_logging(args, local_rank)
    cfg, tokenizer, qa_net, use_segmentation = \
        get_network(args.model_name, ctx_l, args.classifier_dropout,
                    args.param_checkpoint, args.backbone_path)
    logging.info('Prepare training data')
    train_features = get_squad_features(args, tokenizer, segment='train')
    dataset_processor = SquadDatasetProcessor(
        tokenizer=tokenizer,
        doc_stride=args.doc_stride,
        max_seq_length=args.max_seq_length,
        max_query_length=args.max_query_length)
    logging.info('Processing the Training data:')
    train_dataset, num_answer_mismatch, num_unreliable \
        = dataset_processor.get_train(train_features, skip_unreliable=True)
    logging.info(
        'Done! #Unreliable Span={} / #Mismatched Answer={} / #Total={}'.format(
            num_unreliable, num_answer_mismatch, len(train_features)))
    # Get dataset statistics
    num_impossible = 0
    for sample in train_dataset:
        num_impossible += sample.is_impossible
    logging.info('Before Chunking, #Train/Is Impossible = {}/{}'.format(
        len(train_features),
        sum([ele.is_impossible for ele in train_features])))
    logging.info('After Chunking, #Train Sample/Is Impossible = {}/{}'.format(
        len(train_dataset), num_impossible))
    # Shuffle the dataset using a fixed seed across all workers
    # so every worker sees the same permutation before the SplitSampler
    # partitions it by rank.
    rs = np.random.RandomState(args.pre_shuffle_seed)
    rs.shuffle(train_dataset)
    sampler = SplitSampler(len(train_dataset),
                           num_parts=num_workers,
                           part_index=rank,
                           even_size=True)
    train_dataloader = mx.gluon.data.DataLoader(
        train_dataset,
        batchify_fn=dataset_processor.BatchifyFunction,
        batch_size=args.batch_size,
        num_workers=0,
        sampler=sampler)
    if 'electra' in args.model_name:
        # Froze parameters, does not work for albert model since parameters in all layers are shared
        if args.untunable_depth > 0:
            qa_net.backbone.frozen_params(args.untunable_depth)
        if args.layerwise_decay > 0:
            qa_net.backbone.apply_layerwise_decay(args.layerwise_decay)
    logging.info('Creating distributed trainer...')
    # Collect differentiable parameters
    param_dict = qa_net.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        # 'add' makes backward accumulate into the grad buffer instead of
        # overwriting it; buffers are zeroed after each optimizer step below.
        for p in params:
            p.grad_req = 'add'
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)
    # Ceil-divide: number of optimizer-visible iterations per epoch when each
    # iteration consumes len(ctx_l) batches.
    epoch_size = (len(train_dataloader) + len(ctx_l) - 1) // len(ctx_l)
    if args.num_train_steps is not None:
        num_train_steps = args.num_train_steps
    else:
        num_train_steps = int(args.epochs * epoch_size / args.num_accumulated)
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    assert warmup_steps is not None, 'Must specify either warmup_steps or warmup_ratio'
    log_interval = args.log_interval
    save_interval = args.save_interval if args.save_interval is not None\
        else epoch_size // args.num_accumulated
    logging.info(
        '#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
            num_train_steps, warmup_steps, save_interval))
    # set up optimization: linear warmup then linear decay to 0 (pwr=1).
    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    # NOTE(review): eval() on a CLI string is risky; ast.literal_eval would be
    # safer for parsing a "(beta1, beta2)" tuple — confirm args source is trusted.
    adam_betas = eval(args.adam_betas)
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': adam_betas[0],
            'beta2': adam_betas[1],
            'epsilon': args.adam_epsilon,
            'correct_bias': False,
        })
    elif args.optimizer == 'adam':
        optimizer_params.update({
            'beta1': adam_betas[0],
            'beta2': adam_betas[1],
            'epsilon': args.adam_epsilon,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    # Running sums for the periodic log line; reset after each report.
    log_span_loss = 0
    log_answerable_loss = 0
    log_total_loss = 0
    log_sample_num = 0
    global_tic = time.time()
    tic = time.time()
    # Each outer iteration consumes len(ctx_l) * num_accumulated batches,
    # i.e. one optimizer step.
    for step_num, batch_data in enumerate(
            grouper(repeat(train_dataloader), len(ctx_l) * num_accumulated)):
        for sample_l in grouper(batch_data, len(ctx_l)):
            loss_l = []
            span_loss_l = []
            answerable_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    # grouper pads the last group with None.
                    continue
                # Copy the data to device
                tokens = sample.data.as_in_ctx(ctx)
                log_sample_num += len(tokens)
                segment_ids = sample.segment_ids.as_in_ctx(
                    ctx) if use_segmentation else None
                valid_length = sample.valid_length.as_in_ctx(ctx)
                p_mask = sample.masks.as_in_ctx(ctx)
                gt_start = sample.gt_start.as_in_ctx(ctx).astype(np.int32)
                gt_end = sample.gt_end.as_in_ctx(ctx).astype(np.int32)
                is_impossible = sample.is_impossible.as_in_ctx(ctx).astype(
                    np.int32)
                batch_idx = mx.np.arange(tokens.shape[0],
                                         dtype=np.int32,
                                         ctx=ctx)
                p_mask = 1 - p_mask  # In the network, we use 1 --> no_mask, 0 --> mask
                with mx.autograd.record():
                    start_logits, end_logits, answerable_logits \
                        = qa_net(tokens, segment_ids, valid_length, p_mask,
                                 gt_start)
                    # Gather the (log-)probability of the ground-truth
                    # positions for each sample in the batch.
                    sel_start_logits = start_logits[batch_idx, gt_start]
                    sel_end_logits = end_logits[batch_idx, gt_end]
                    sel_answerable_logits = answerable_logits[batch_idx,
                                                              is_impossible]
                    span_loss = -0.5 * (sel_start_logits +
                                        sel_end_logits).mean()
                    answerable_loss = -0.5 * sel_answerable_logits.mean()
                    loss = span_loss + answerable_loss
                    loss_l.append(loss)
                    span_loss_l.append(span_loss)
                    answerable_loss_l.append(answerable_loss)
            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_span_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in span_loss_l]).asnumpy()
            log_total_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in loss_l]).asnumpy()
            log_answerable_loss += sum([
                ele.as_in_ctx(ctx_l[0]) for ele in answerable_loss_l
            ]).asnumpy()
        # update
        trainer.allreduce_grads()
        if args.max_grad_norm > 0:
            # Clip threshold is scaled by num_workers because gradients are
            # summed across workers before the scale-down in trainer.update.
            total_norm, ratio, is_finite = clip_grad_global_norm(
                params, args.max_grad_norm * num_workers)
        else:
            total_norm = grad_global_norm(params)
        if args.comm_backend == 'horovod':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1
            trainer.update(num_workers, ignore_stale_grad=True)
        total_norm = total_norm / num_workers
        if args.num_accumulated > 1:
            # set grad to zero for gradient accumulation
            qa_net.zero_grad()
        # saving
        # NOTE(review): `and` binds tighter than `or`, so on the final step
        # (step_num + 1 >= num_train_steps) EVERY local rank saves, not just
        # rank 0 — verify whether the intent was
        # `local_rank == 0 and (A or B)`.
        if local_rank == 0 and (step_num + 1) % save_interval == 0 or (
                step_num + 1) >= num_train_steps:
            version_prefix = 'squad' + args.version
            ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                 version_prefix,
                                                 (step_num + 1))
            params_saved = os.path.join(args.output_dir, ckpt_name)
            qa_net.save_parameters(params_saved)
            ckpt_candidates = [
                f for f in os.listdir(args.output_dir) if f.endswith('.params')
            ]
            # keep last `max_saved_ckpt` checkpoints
            # Sorting by (len, name) orders numeric step suffixes correctly
            # (e.g. step 99 before step 100); the oldest is removed.
            if len(ckpt_candidates) > args.max_saved_ckpt:
                ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
                os.remove(os.path.join(args.output_dir, ckpt_candidates[0]))
            logging.info('Params saved in: {}'.format(params_saved))
        # logging
        if (step_num + 1) % log_interval == 0:
            log_span_loss /= log_sample_num
            log_answerable_loss /= log_sample_num
            log_total_loss /= log_sample_num
            toc = time.time()
            logging.info(
                'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
                ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
                ' ETA={:.2f}h'.format(
                    (step_num + 1), num_train_steps, log_span_loss,
                    log_answerable_loss, log_total_loss,
                    trainer.learning_rate, total_norm, toc - tic,
                    log_sample_num / (toc - tic),
                    (num_train_steps - (step_num + 1)) /
                    ((step_num + 1) / (toc - global_tic)) / 3600))
            tic = time.time()
            log_span_loss = 0
            log_answerable_loss = 0
            log_total_loss = 0
            log_sample_num = 0
            # NOTE(review): num_samples_per_update is reset here but never
            # incremented anywhere in this function — looks like dead code.
            num_samples_per_update = 0
        if (step_num + 1) >= num_train_steps:
            toc = time.time()
            logging.info('Finish training step: {} within {} hours'.format(
                step_num + 1, (toc - global_tic) / 3600))
            break
    return params_saved
def train(args):
    """Finetune a sequence-classification network on a GLUE-style task.

    Sets up per-task logging, data loading, a PolyScheduler + trainer,
    then runs a fixed number of updates, logging average loss/grad-norm and
    saving per-epoch checkpoints on local rank 0.
    """
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    task = get_task(args.task_name)
    #setup_logging(args, local_rank)
    level = logging.INFO
    detail_dir = os.path.join(args.output_dir, args.task_name)
    if not os.path.exists(detail_dir):
        os.mkdir(detail_dir)
    logging_config(
        detail_dir,
        name='train_{}_{}_'.format(args.task_name, args.model_name) +
        str(rank),  # avoid race
        level=level,
        console=(local_rank == 0))
    logging.info(args)
    cfg, tokenizer, classify_net, use_segmentation = \
        get_network(args.model_name, ctx_l,
                    args.param_checkpoint,
                    args.backbone_path,
                    task)
    logging.info('Prepare training data')
    train_data, _ = get_task_data(args, tokenizer, segment='train', task=task)
    # Batchify: ((padded token ids, padded token types, lengths), labels).
    train_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()),
                              bf.Stack())
    epoch_num_updates = len(train_data) // args.batch_size
    max_update = epoch_num_updates * args.epochs
    warmup_steps = int(np.ceil(max_update * args.warmup_ratio))
    dataloader = DataLoader(train_data,
                            batch_size=args.batch_size,
                            batchify_fn=train_batchify,
                            num_workers=4,
                            shuffle=True)
    # Infinite iterator yielding groups of len(ctx_l) batches (one per device).
    dataloader = grouper(repeat(dataloader), len(ctx_l))
    param_dict = classify_net.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in classify_net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        # NOTE(review): grad_req is set to 'add' here, but the training loop
        # below performs one update per batch and never zeroes the gradients —
        # with num_accumulated > 1 gradients would accumulate indefinitely.
        # Verify whether accumulation is actually supported in this script.
        for p in params:
            p.grad_req = 'add'
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)
    # Linear warmup then linear decay to 0 (pwr=1).
    lr_scheduler = PolyScheduler(max_update=max_update,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0.0,
                                 pwr=1,
                                 final_lr=0.0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler
    }
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        # NOTE(review): the optimizer name is hardcoded to 'adamw' here while
        # the horovod branch honors args.optimizer — confirm this asymmetry
        # is intentional.
        trainer = mx.gluon.Trainer(classify_net.collect_params(), 'adamw',
                                   optimizer_params)
    # Regression task (STS) uses L2; all other tasks use softmax CE.
    if args.task_name == 'sts':
        loss_function = gluon.loss.L2Loss()
    else:
        loss_function = gluon.loss.SoftmaxCELoss()
    #prepare loss function
    log_loss = 0
    log_gnorm = 0
    log_step = 0
    if args.log_interval > 0:
        log_interval = args.log_interval
    else:
        # Default: log twice per epoch.
        log_interval = int(epoch_num_updates * 0.5)
    for i in range(max_update):
        sample_l = next(dataloader)
        loss_l = []
        for sample, ctx in zip(sample_l, ctx_l):
            (token_ids, token_types, valid_length), label = sample
            # Move to the corresponding context
            token_ids = mx.np.array(token_ids, ctx=ctx)
            token_types = mx.np.array(token_types, ctx=ctx)
            valid_length = mx.np.array(valid_length, ctx=ctx)
            label = mx.np.array(label, ctx=ctx)
            with mx.autograd.record():
                scores = classify_net(token_ids, token_types, valid_length)
                # Divide by len(ctx_l) so the summed per-device losses average
                # over devices.
                loss = loss_function(scores, label).mean() / len(ctx_l)
                loss_l.append(loss)
        for loss in loss_l:
            loss.backward()
        trainer.allreduce_grads()
        # Begin Norm Clipping
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm)
        trainer.update(1.0)
        step_loss = sum([loss.asnumpy() for loss in loss_l])
        log_loss += step_loss
        log_gnorm += total_norm
        log_step += 1
        if log_step >= log_interval or i == max_update - 1:
            logging.info(
                '[Iter {} / {}] avg {} = {:.2f}, avg gradient norm = {:.2f}'.
                format(i + 1, max_update, 'nll', log_loss / log_step,
                       log_gnorm / log_step))
            log_loss = 0
            log_gnorm = 0
            log_step = 0
        # Save on the last update and at (roughly) each epoch boundary.
        if local_rank == 0 and (i == max_update - 1 or i %
                                (max_update // args.epochs) == 0 and i > 0):
            ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                 args.task_name, (i + 1))
            params_saved = os.path.join(detail_dir, ckpt_name)
            classify_net.save_parameters(params_saved)
            logging.info('Params saved in: {}'.format(params_saved))
def train(args):
    """Pretrain a BERT-style model (MLM + NSP) with optional Horovod/BytePS.

    Loads raw-text or pre-masked npz data, runs `args.num_steps` optimizer
    steps with gradient accumulation, periodically checkpoints trainer state
    and parameters, and finally exports the model via `final_save`.
    """
    _, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    level = logging.DEBUG if args.verbose else logging.INFO
    logging_config(args.ckpt_dir,
                   name='pretrain_bert_' + str(rank),  # avoid race
                   level=level,
                   console=(local_rank == 0))
    logging.info(args)
    logging.debug('Random seed set to {}'.format(args.seed))
    set_seed(args.seed)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(
                     args.num_buckets, num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l,
                                                  args.max_seq_length)
    if args.raw:
        # Build batches on the fly from raw text (tokenization + masking).
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            masked_lm_prob=args.masked_lm_prob,
            max_predictions_per_seq=args.max_predictions_per_seq,
            whole_word_mask=args.whole_word_mask,
            random_next_sentence=args.random_next_sentence,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            dataset_cached=args.dataset_cached,
            num_max_dataset_cached=args.num_max_dataset_cached)
    else:
        # Pre-processed numpy shards.
        get_dataset_fn = get_pretrain_data_npz
    data_train = get_dataset_fn(args.data, args.batch_size, shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)
    param_dict = model.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info('Using gradient accumulation. '
                     'Effective global batch size = {}'
                     .format(num_accumulated * args.batch_size * len(ctx_l)
                             * num_workers))
        # 'add' accumulates gradients across micro-batches; zeroed after each
        # optimizer step below.
        for p in params:
            p.grad_req = 'add'
    num_steps = args.num_steps
    warmup_steps = int(num_steps * args.warmup_ratio)
    log_interval = args.log_interval
    save_interval = args.ckpt_interval
    logging.info('#Total Training Steps={}, Warmup Steps={}, Save Interval={}'
                 .format(num_steps, warmup_steps, save_interval))
    # Linear warmup then linear decay to 0 (pwr=1).
    lr_scheduler = PolyScheduler(max_update=num_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {'learning_rate': args.lr,
                        'wd': args.wd,
                        'lr_scheduler': lr_scheduler,
                        }
    if args.optimizer == 'adamw':
        optimizer_params.update({'beta1': 0.9,
                                 'beta2': 0.999,
                                 'epsilon': 1e-6,
                                 'correct_bias': False,
                                 })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    elif args.comm_backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        # Resume: load model parameters and trainer (optimizer) states.
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args.ckpt_dir, 'Loading')
        states_option(args.start_step, trainer, args.ckpt_dir, local_rank,
                      'Loading')
        if args.comm_backend == 'byteps':
            trainer._init_params()
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)
    # prepare the loss function
    nsp_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    nsp_loss_fn.hybridize()
    mlm_loss_fn.hybridize()
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()
    step_num = args.start_step
    running_mlm_loss, running_nsp_loss = 0., 0.
    running_num_tks = 0
    train_start_time = time.time()
    tic = time.time()
    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_steps:
        # One optimizer step = num_accumulated micro-steps over len(ctx_l)
        # device batches each.
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            mlm_loss_l = []
            nsp_loss_l = []
            loss_l = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for sample, ctx in zip(sample_l, ctx_l):
                # prepare data
                (input_id, masked_id, masked_position, masked_weight, \
                 next_sentence_label, segment_id, valid_length) = sample
                input_id = input_id.as_in_ctx(ctx)
                masked_id = masked_id.as_in_ctx(ctx)
                masked_position = masked_position.as_in_ctx(ctx)
                masked_weight = masked_weight.as_in_ctx(ctx)
                next_sentence_label = next_sentence_label.as_in_ctx(ctx)
                segment_id = segment_id.as_in_ctx(ctx)
                valid_length = valid_length.as_in_ctx(ctx)
                with mx.autograd.record():
                    _, _, nsp_score, mlm_scores = model(input_id, segment_id,
                                                        valid_length,
                                                        masked_position)
                    # Normalize MLM loss by the number of masked tokens
                    # (masked_weight.sum()), the accumulation count and the
                    # number of devices, so accumulated backward passes sum
                    # to a per-token average.
                    denominator = (masked_weight.sum() + 1e-8) * \
                        num_accumulated * len(ctx_l)
                    mlm_scores_r = mx.npx.reshape(mlm_scores, (-5, -1))
                    masked_id_r = masked_id.reshape((-1,))
                    mlm_loss = mlm_loss_fn(
                        mlm_scores_r,
                        masked_id_r,
                        masked_weight.reshape((-1, 1))).sum() / denominator
                    denominator = num_accumulated * len(ctx_l)
                    nsp_loss = nsp_loss_fn(
                        nsp_score, next_sentence_label).mean() / denominator
                    mlm_loss_l.append(mlm_loss)
                    nsp_loss_l.append(nsp_loss)
                    loss_l.append(mlm_loss + nsp_loss)
                    mask_label_list.append(masked_id_r)
                    mask_pred_list.append(mlm_scores_r)
                    mask_weight_list.append(masked_weight.reshape((-1,)))
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(nsp_score)
                running_num_tks += valid_length.sum().as_in_ctx(mx.cpu())
            for loss in loss_l:
                loss.backward()
            running_mlm_loss += sum([ele.as_in_ctx(mx.cpu())
                                     for ele in mlm_loss_l]).asnumpy().item()
            running_nsp_loss += sum([ele.as_in_ctx(mx.cpu())
                                     for ele in nsp_loss_l]).asnumpy().item()
            mlm_metric.update(mask_label_list, mask_pred_list,
                              mask_weight_list)
            nsp_metric.update(ns_label_list, ns_pred_list)
        # update
        trainer.allreduce_grads()
        # Clip threshold scaled by num_workers: gradients are summed across
        # workers and only divided down inside trainer.update.
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)
        total_norm = total_norm / num_workers
        if args.comm_backend == 'horovod' or args.comm_backend == 'byteps':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1
            trainer.update(num_workers, ignore_stale_grad=True)
        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()
        step_num += 1
        # saving
        if step_num % save_interval == 0 or step_num >= num_steps:
            states_option(step_num, trainer, args.ckpt_dir, local_rank,
                          'Saving')
            if local_rank == 0:
                parameters_option(step_num, model, args.ckpt_dir, 'Saving')
        # logging
        if step_num % log_interval == 0:
            running_mlm_loss /= log_interval
            running_nsp_loss /= log_interval
            toc = time.time()
            logging.info(
                '[step {}], Loss mlm/nsp={:.5f}/{:.3f}, Acc mlm/nsp={:.3f}/{:.3f}, '
                ' LR={:.7f}, grad_norm={:.4f}. Time cost={:.2f} s,'
                ' Throughput={:.1f}K tks/s, ETA={:.2f} h'.format(
                    step_num, running_mlm_loss, running_nsp_loss,
                    mlm_metric.get()[1], nsp_metric.get()[1],
                    trainer.learning_rate, total_norm, toc - tic,
                    running_num_tks.asnumpy().item() / (toc - tic) / 1000,
                    (num_steps - step_num) /
                    (step_num / (toc - train_start_time)) / 3600))
            mlm_metric.reset()
            nsp_metric.reset()
            tic = time.time()
            running_mlm_loss = 0
            running_nsp_loss = 0
            running_num_tks = 0
    logging.info('Finish training step: %d', step_num)
    # Block until all async MXNet ops complete before timing/export.
    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time -
                                              train_start_time))
    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.ckpt_dir, model_name)
        final_save(model, save_dir, tokenizer, cfg)
def train(args):
    """Pretrain an ELECTRA model (generator MLM + discriminator RTD).

    Applies dynamic masking per batch, trains for `args.num_train_steps`
    optimizer steps with gradient accumulation, checkpoints trainer state and
    parameters, optionally logs evaluation records to tensorboardX, and
    exports the final model.
    """
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(args.num_buckets,
                                                    num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(
        args.model_name, ctx_l, args.max_seq_length,
        args.hidden_dropout_prob, args.attention_dropout_prob,
        args.generator_units_scale, args.generator_layers_scale)
    # Performs per-batch dynamic masking (ELECTRA-style) on the fly.
    data_masker = ElectraMasker(tokenizer, args.max_seq_length,
                                args.mask_prob)
    if args.from_raw_text:
        if args.cached_file_path and not os.path.exists(args.cached_file_path):
            os.mkdir(args.cached_file_path)
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            cached_file_path=args.cached_file_path)
        logging.info(
            'Processing and loading the training dataset from raw text.')
    else:
        logging.info('Loading the training dataset from local Numpy file.')
        get_dataset_fn = get_pretrain_data_npz
    data_train = get_dataset_fn(args.data, args.batch_size, shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)
    logging.info('Creating distributed trainer...')
    param_dict = model.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    if args.num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(args.num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        # 'add' accumulates gradients across micro-batches; zeroed after each
        # optimizer step below.
        for p in params:
            p.grad_req = 'add'
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)
    num_train_steps = args.num_train_steps
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    assert warmup_steps is not None, 'Must specify either warmup_steps or warmup_ratio'
    log_interval = args.log_interval
    save_interval = args.save_interval if args.save_interval is not None\
        else num_train_steps // 50
    logging.info(
        '#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
            num_train_steps, warmup_steps, save_interval))
    # Linear warmup then linear decay to 0 (pwr=1).
    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        # TODO(zheyuye), How about data splitting, where to start re-training
        state_path = states_option(args.start_step, trainer, args.output_dir,
                                   local_rank, 'Loading')
        param_path = parameters_option(args.start_step, model,
                                       args.output_dir, 'Loading')
    # prepare the loss function
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    rtd_loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    mlm_loss_fn.hybridize()
    rtd_loss_fn.hybridize()
    # prepare the records writer
    writer = None
    if args.do_eval and local_rank == 0:
        # Local import keeps tensorboardX optional unless evaluation is on.
        from tensorboardX import SummaryWriter
        record_path = os.path.join(args.output_dir, 'records')
        logging.info('Evaluation records saved in {}'.format(record_path))
        writer = SummaryWriter(record_path)
    step_num = args.start_step
    finish_flag = False
    num_samples_per_update = 0
    # Per-step loss normalizer: devices * accumulation * workers.
    loss_denom = float(len(ctx_l) * args.num_accumulated * num_workers)
    log_total_loss = 0
    log_mlm_loss = 0
    log_rtd_loss = 0
    log_sample_num = 0
    train_start_time = time.time()
    if args.num_accumulated != 1:
        # set grad to zero for gradient accumulation
        model.zero_grad()
    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_train_steps:
        tic = time.time()
        for accum_idx in range(args.num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            mlm_loss_l = []
            rtd_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    # grouper pads the last group with None.
                    continue
                # prepare data
                input_ids, segment_ids, valid_lengths = sample
                input_ids = input_ids.as_in_ctx(ctx)
                segment_ids = segment_ids.as_in_ctx(ctx)
                valid_lengths = valid_lengths.as_in_ctx(ctx)
                masked_input = data_masker.dynamic_masking(
                    mx.nd, input_ids, valid_lengths)
                masked_input_ids = masked_input.input_ids
                length_masks = masked_input.masks
                unmasked_tokens = masked_input.unmasked_tokens
                masked_positions = masked_input.masked_positions
                masked_weights = masked_input.masked_weights
                log_sample_num += len(masked_input_ids)
                num_samples_per_update += len(masked_input_ids)
                with mx.autograd.record():
                    mlm_scores, rtd_scores, corrupted_tokens, labels = model(
                        masked_input_ids, segment_ids, valid_lengths,
                        unmasked_tokens, masked_positions)
                    # the official implementation takes the sum of each batch inside the loss function
                    # while SigmoidBinaryCrossEntropyLoss and SoftmaxCELoss takes the mean value
                    mlm_loss = mlm_loss_fn(
                        mlm_scores, unmasked_tokens, masked_weights.reshape(
                            -1)).mean() / (masked_weights.mean() + 1e-6)
                    rtd_loss = rtd_loss_fn(
                        rtd_scores, labels,
                        length_masks).mean() / (length_masks.mean() + 1e-6)
                    # Kept for the optional evaluation call below
                    # (uses the last processed batch's outputs).
                    output = ElectraOutput(
                        mlm_scores=mlm_scores,
                        rtd_scores=rtd_scores,
                        rtd_labels=labels,
                        corrupted_tokens=corrupted_tokens,
                    )
                    mlm_loss_l.append(mlm_loss)
                    rtd_loss_l.append(rtd_loss)
                    # Weighted generator/discriminator combination, averaged
                    # across devices * accumulation * workers via loss_denom.
                    loss = (args.gen_weight * mlm_loss +
                            args.disc_weight * rtd_loss) / loss_denom
                    loss_l.append(loss)
            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_mlm_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in mlm_loss_l]).asnumpy()
            log_rtd_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in rtd_loss_l]).asnumpy()
            log_total_loss += sum([ele.as_in_ctx(ctx_l[0])
                                   for ele in loss_l]).asnumpy() * loss_denom
        # update
        trainer.allreduce_grads()
        # Here, the accumulated gradients are
        # \sum_{n=1}^N g_n / loss_denom
        # Thus, in order to clip the average gradient
        # \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm
        # We need to change the ratio to be
        # \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_samples_per_update / loss_denom)
        total_norm = total_norm / (num_samples_per_update / loss_denom)
        trainer.update(num_samples_per_update / loss_denom,
                       ignore_stale_grad=True)
        step_num += 1
        if args.num_accumulated != 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()
        # saving
        if step_num % save_interval == 0 or step_num >= num_train_steps:
            if is_master_node:
                states_option(step_num, trainer, args.output_dir, local_rank,
                              'Saving')
                if local_rank == 0:
                    param_path = parameters_option(step_num, model,
                                                   args.output_dir, 'Saving')
        # logging
        if step_num % log_interval == 0 and local_rank == 0:
            # Output the loss of per step
            log_mlm_loss /= log_interval
            log_rtd_loss /= log_interval
            log_total_loss /= log_interval
            toc = time.time()
            logging.info(
                '[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},'
                ' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},'
                ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
                    step_num, log_mlm_loss, log_rtd_loss, log_total_loss,
                    trainer.learning_rate, total_norm, toc - tic,
                    log_sample_num / (toc - tic),
                    (num_train_steps - step_num) /
                    (step_num / (toc - train_start_time)) / 3600))
            tic = time.time()
            if args.do_eval:
                # Evaluate on the last batch's masked input/outputs only.
                evaluation(writer, step_num, masked_input, output)
                writer.add_scalars(
                    'loss', {
                        'total_loss': log_total_loss,
                        'mlm_loss': log_mlm_loss,
                        'rtd_loss': log_rtd_loss
                    }, step_num)
            log_mlm_loss = 0
            log_rtd_loss = 0
            log_total_loss = 0
            log_sample_num = 0
            num_samples_per_update = 0
    logging.info('Finish training step: %d', step_num)
    # Final checkpoint of trainer state and parameters.
    if is_master_node:
        state_path = states_option(step_num, trainer, args.output_dir,
                                   local_rank, 'Saving')
        if local_rank == 0:
            param_path = parameters_option(step_num, model, args.output_dir,
                                           'Saving')
    # Block until all async MXNet ops complete before timing/export.
    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_start_time))
    if writer is not None:
        writer.close()
    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.output_dir, model_name)
        final_save(model, save_dir, tokenizer)