def train(args):
    _, num_parts, rank, local_rank, _, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    if args.comm_backend == 'horovod':
        logging_config(
            args.save_dir,
            name=f'train_transformer_rank{rank}_local{local_rank}_{num_parts}',
            console=(rank == 0))
        logging.info(args)
    else:
        logging_config(args.save_dir, name='train_transformer', console=True)
        logging.info(args)
    use_amp = args.fp16
    if use_amp:
        from mxnet import amp
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    base_tgt_tokenizer = MosesTokenizer(args.tgt_lang)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus, args.train_tgt_corpus, src_tokenizer,
        tgt_tokenizer, args.overwrite_cache, local_rank,
        max_src_length=args.max_src_length,
        max_tgt_length=args.max_tgt_length,
        pretokenized=not args.tokenize)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer,
        tgt_tokenizer, args.overwrite_cache, local_rank,
        pretokenized=not args.tokenize)
    tgt_detok_sentences = []
    tgt_raw_sentences = []
    with open(args.dev_tgt_corpus, 'r') as in_f:
        for line in in_f:
            tgt_detok_sentences.append(
                base_tgt_tokenizer.decode(
                    tgt_tokenizer.decode(line.split()).split()))
    with open(args.dev_tgt_raw_corpus, 'r') as in_f:
        for line in in_f:
            tgt_raw_sentences.append(line.strip())
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens, tgt_tokens) in enumerate(
            zip(train_src_data, train_tgt_data))
    ])
    val_samples = [
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens, tgt_tokens) in enumerate(
            zip(dev_src_data, dev_tgt_data))
    ]
    if args.comm_backend == 'horovod':
        slice_begin = rank * (len(val_samples) // num_parts)
        slice_end = min((rank + 1) * (len(val_samples) // num_parts),
                        len(val_samples))
        data_val = gluon.data.SimpleDataset(val_samples[slice_begin:slice_end])
    else:
        data_val = gluon.data.SimpleDataset(val_samples)
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    cfg.MODEL.layout = 'TN'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    param_dict = deduplicate_param_dict(model.collect_params())
    inference_model = TransformerInference(model=model)
    inference_model.hybridize()
    if local_rank == 0:
        logging.info(model)
        with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
            cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(
        num_labels=len(tgt_vocab),
        alpha=args.label_smooth_alpha,
        from_logits=False)
    label_smooth_loss.hybridize()
    # Construct the beam search sampler
    scorer = BeamSearchScorer(alpha=args.lp_alpha,
                              K=args.lp_k,
                              from_logits=False)
    beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size,
                                            decoder=inference_model,
                                            vocab_size=len(tgt_vocab),
                                            eos_id=tgt_vocab.eos_id,
                                            scorer=scorer,
                                            stochastic=False,
                                            max_length_a=args.max_length_a,
                                            max_length_b=args.max_length_b)
    logging.info(beam_search_sampler)
    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)
    # Construct the trainer
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(
            args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(
        warmup_steps=args.warmup_steps,
        base_lr=base_lr,
        warmup_init_lr=args.warmup_init_lr)
    optimizer_params = {
        'learning_rate': args.lr,
        'beta1': 0.9,
        'beta2': 0.997,
        'epsilon': 1e-9,
        'lr_scheduler': lr_scheduler,
        'wd': args.wd
    }
    user_provided_optimizer_params = json.loads(args.optimizer_params)
    optimizer_params.update(user_provided_optimizer_params)
    if args.fp16:
        optimizer_params.update({'multi_precision': True})
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = gluon.Trainer(param_dict,
                                args.optimizer,
                                optimizer_params,
                                update_on_kvstore=False)
    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            shuffle=True,
            seed=args.seed)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')
        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            batch_size=args.batch_size,
            num_buckets=args.num_buckets,
            ratio=args.bucket_ratio,
            shuffle=True,
            use_average_length=True,
            bucket_scheme=bucket_scheme,
            seed=args.seed)
    else:
        raise NotImplementedError
    num_updates_per_epoch = int(
        math.ceil(
            len(train_batch_sampler) /
            (num_parts * len(ctx_l) * args.num_accumulated)))
    # Convert the batch sampler to multiple shards
    if num_parts > 1:
        train_batch_sampler = ShardedIterator(train_batch_sampler,
                                              num_parts=num_parts,
                                              part_index=rank,
                                              even_size=True,
                                              seed=args.seed + 1000 * rank)
    logging.info(train_batch_sampler)
    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(),
                           bf.Stack())
    train_data_loader = gluon.data.DataLoader(
        data_train,
        batch_sampler=train_batch_sampler,
        batchify_fn=batchify_fn,
        num_workers=0)
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    model_averager = AverageSGDTracker(param_dict)
    log_start_time = time.time()
    num_params, num_fixed_params = None, None
    # TODO(sxjscience) Add a log metric class
    log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    # Maintain the denominator of the loss.
    log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_avg_grad_norm = 0
    log_iter_num = 0
    if local_rank == 0:
        writer = SummaryWriter(
            logdir=os.path.join(args.save_dir, 'tensorboard'))
    if use_amp:
        amp.init_trainer(trainer)
    train_multi_data_loader = grouper(repeat(train_data_loader), len(ctx_l))
    # when args.epochs < 0, the model will keep training
    if args.epochs < 0:
        if args.max_update > 0:
            total_train_iters = args.max_update
            if args.num_averages > 0:
                assert args.num_averages <= \
                    total_train_iters // args.save_interval_update
                avg_start_iter = (
                    total_train_iters // args.save_interval_update -
                    args.num_averages) * args.save_interval_update
            else:
                avg_start_iter = -1
        else:
            total_train_iters = np.inf
            avg_start_iter = -1
    else:
        total_train_iters = args.epochs * num_updates_per_epoch
        if args.num_averages > 0:
            assert args.num_averages <= args.epochs
            avg_start_iter = (args.epochs -
                              args.num_averages) * num_updates_per_epoch
        else:
            avg_start_iter = -1
    # Here, we are manually setting up the scale to 1.0 because
    # in horovod, the scale can be the number of workers:
    # See the code here: https://github.com/horovod/horovod/blob/125115583b7029196e2ec530decd4209459d5479/horovod/mxnet/__init__.py#L141
    # Since we will need to use the dynamic scaling in amp, we will manually call amp.unscale().
    # A scale that is larger than 1.0 can be problematic in this case.
    trainer._scale = 1.0
    if args.max_num_tokens > 0:
        const_scale = args.max_num_tokens
    else:
        const_scale = 100
    train_start_time = time.time()
    for train_iter in range(total_train_iters):
        model.zero_grad()
        loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
        for i in range(args.num_accumulated):
            loss_l = []
            sample_data_l = next(train_multi_data_loader)
            for j, (sample_data, ctx) in enumerate(zip(sample_data_l, ctx_l)):
                src_token_ids, tgt_token_ids, src_valid_length,\
                    tgt_valid_length, sample_ids = sample_data
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                src_wc, tgt_wc, bs = src_valid_length.sum(), \
                    tgt_valid_length.sum(), src_token_ids.shape[0]
                log_wc_l[j] += src_wc + tgt_wc
                log_tgt_wc_l[j] += tgt_wc
                token_count = (tgt_valid_length - 1).sum()
                loss_denom_l[j] += token_count / const_scale
                log_avg_loss_denom_l[j] += token_count / const_scale
                with mx.autograd.record():
                    if model.layout == 'NT':
                        tgt_pred = model(src_token_ids, src_valid_length,
                                         tgt_token_ids[:, :-1],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids[:, 1:]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=1)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                    elif model.layout == 'TN':
                        tgt_pred = model(src_token_ids.T, src_valid_length,
                                         tgt_token_ids.T[:-1, :],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids.T[1:, :]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(
                            loss,
                            sequence_length=tgt_valid_length - 1,
                            use_sequence_length=True,
                            axis=0)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                log_avg_loss_l[j] += loss
            if use_amp:
                with mx.autograd.record():
                    with amp.scale_loss(loss_l, trainer) as amp_loss_l:
                        for loss in amp_loss_l:
                            loss.backward()
            else:
                with mx.autograd.record():
                    for loss in loss_l:
                        loss.backward()
        # Print the total number of parameters
        if local_rank == 0 and num_params is None:
            num_params, num_fixed_params = count_parameters(param_dict)
            logging.info(
                'Total Number of Parameters (not-fixed/fixed): {}/{}'.format(
                    num_params, num_fixed_params))
        # All-Reduce the gradient
        trainer.allreduce_grads()
        if args.comm_backend == 'horovod':
            # All-Reduce the loss denominator
            assert len(loss_denom_l) == 1
            loss_denom = hvd.allreduce(loss_denom_l[0],
                                       average=False).asnumpy()
        else:
            loss_denom = sum([ele.asnumpy() for ele in loss_denom_l])
        if use_amp:
            # We need to first unscale the gradient and then perform allreduce.
            grad_scale = trainer.amp_loss_scale * loss_denom
        else:
            grad_scale = loss_denom
        if args.max_grad_norm is not None:
            total_norm, ratio, is_finite\
                = clip_grad_global_norm(params,
                                        args.max_grad_norm * grad_scale)
            total_norm = total_norm / grad_scale
        else:
            total_norm = grad_global_norm(params)
            total_norm = total_norm / grad_scale
        log_avg_grad_norm += total_norm
        log_iter_num += 1
        trainer.update(loss_denom, ignore_stale_grad=True)
        if avg_start_iter > 0 and train_iter >= avg_start_iter:
            model_averager.step()
        if ((train_iter + 1) % args.log_interval == 0
                or train_iter + 1 == total_train_iters):
            if args.comm_backend == 'horovod':
                # Use allreduce to get the total number of tokens and loss
                log_wc = hvd.allreduce(log_wc_l[0], average=False).asnumpy()
                log_tgt_wc = hvd.allreduce(log_tgt_wc_l[0],
                                           average=False).asnumpy()
                log_avg_loss = hvd.allreduce(log_avg_loss_l[0] /
                                             log_avg_loss_denom_l[0],
                                             average=True)
                log_avg_loss = log_avg_loss.asnumpy()
            else:
                log_wc = sum([ele.asnumpy() for ele in log_wc_l])
                log_tgt_wc = sum([ele.asnumpy() for ele in log_tgt_wc_l])
                log_avg_loss =\
                    sum([log_avg_loss_l[i].asnumpy() /
                         log_avg_loss_denom_l[i].asnumpy()
                         for i in range(len(log_avg_loss_l))]) \
                    / len(log_avg_loss_l)
            log_avg_grad_norm = log_avg_grad_norm / log_iter_num
            log_end_time = time.time()
            wps = log_wc / (log_end_time - log_start_time)
            epoch_id = train_iter // num_updates_per_epoch
            logging.info(
                '[Epoch {} Iter {}/{}, Overall {}/{}] loss={:.4f}, ppl={:.4f}, '
                'throughput={:.2f}K wps, total wc={:.2f}K, wpb={:.2f}K,'
                ' LR={}, gnorm={:.4f}, ETA={:.2f}h'.format(
                    epoch_id, train_iter % num_updates_per_epoch + 1,
                    num_updates_per_epoch, train_iter + 1, total_train_iters,
                    log_avg_loss, np.exp(log_avg_loss), wps / 1000,
                    log_wc / 1000, log_tgt_wc / 1000 / log_iter_num,
                    trainer.learning_rate, log_avg_grad_norm,
                    (log_end_time - train_start_time) / (train_iter + 1) *
                    (total_train_iters - train_iter - 1) / 3600))
            if local_rank == 0:
                writer.add_scalar('throughput_wps', wps, train_iter)
                writer.add_scalar('train_loss', log_avg_loss, train_iter)
                writer.add_scalar('lr', trainer.learning_rate, train_iter)
                writer.add_scalar('grad_norm', log_avg_grad_norm, train_iter)
            # Reinitialize the log variables
            log_start_time = time.time()
            log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_grad_norm = 0
            log_iter_num = 0
            log_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]
            log_tgt_wc_l = [
                mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l
            ]
        if (args.max_update > 0
                and (train_iter + 1) % args.save_interval_update == 0) \
                or ((train_iter + 1) % num_updates_per_epoch == 0) \
                or train_iter + 1 == total_train_iters:
            epoch_id = (train_iter + 1) // num_updates_per_epoch
            if local_rank == 0:
                if args.max_update <= 0:
                    model.save_parameters(os.path.join(
                        args.save_dir, 'epoch{}.params'.format(epoch_id)),
                                          deduplicate=True)
                else:
                    model.save_parameters(os.path.join(
                        args.save_dir,
                        'iter{}.params'.format(train_iter + 1)),
                                          deduplicate=True)
            avg_val_loss, ntokens, pred_sentences, pred_lengths, sentence_ids\
                = validation(model, val_data_loader, inference_model,
                             beam_search_sampler, tgt_tokenizer, ctx_l)
            if args.comm_backend == 'horovod':
                flatten_pred_sentences = np.concatenate(pred_sentences, axis=0)
                all_val_loss = hvd.allgather(
                    mx.np.array([avg_val_loss * ntokens],
                                dtype=np.float32,
                                ctx=ctx_l[0]))
                all_ntokens = hvd.allgather(
                    mx.np.array([ntokens], dtype=np.int64, ctx=ctx_l[0]))
                flatten_pred_sentences = hvd.allgather(
                    mx.np.array(flatten_pred_sentences,
                                dtype=np.int32,
                                ctx=ctx_l[0]))
                pred_lengths = hvd.allgather(
                    mx.np.array(pred_lengths, dtype=np.int64, ctx=ctx_l[0]))
                sentence_ids = hvd.allgather(
                    mx.np.array(sentence_ids, dtype=np.int64, ctx=ctx_l[0]))
                avg_val_loss = all_val_loss.asnumpy().sum() \
                    / all_ntokens.asnumpy().sum()
                flatten_pred_sentences = flatten_pred_sentences.asnumpy()
                pred_lengths = pred_lengths.asnumpy()
                sentence_ids = sentence_ids.asnumpy()
                pred_sentences = [None for _ in range(len(sentence_ids))]
                ptr = 0
                assert sentence_ids.min() == 0 and \
                    sentence_ids.max() == len(sentence_ids) - 1
                for sentence_id, length in zip(sentence_ids, pred_lengths):
                    pred_sentences[sentence_id] = \
                        flatten_pred_sentences[ptr:(ptr + length)]
                    ptr += length
            if local_rank == 0:
                # Perform detokenization
                pred_sentences_bpe_decode = []
                pred_sentences_raw = []
                for sentence in pred_sentences:
                    bpe_decode_sentence = tgt_tokenizer.decode(
                        sentence.tolist())
                    raw_sentence = base_tgt_tokenizer.decode(
                        bpe_decode_sentence.split())
                    pred_sentences_bpe_decode.append(bpe_decode_sentence)
                    pred_sentences_raw.append(raw_sentence)
                detok_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_bpe_decode,
                    ref_streams=[tgt_detok_sentences])
                raw_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_raw,
                    ref_streams=[tgt_raw_sentences])
                with open(
                        os.path.join(args.save_dir,
                                     f'epoch{epoch_id}_dev_prediction.txt'),
                        'w') as of:
                    for line in pred_sentences_raw:
                        of.write(line + '\n')
                logging.info(
                    '[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, '
                    'SacreBLEU={}, Detok SacreBLEU={}'.format(
                        epoch_id, train_iter, total_train_iters, avg_val_loss,
                        np.exp(avg_val_loss), raw_sacrebleu_out.score,
                        detok_sacrebleu_out.score))
                writer.add_scalar('valid_loss', avg_val_loss, train_iter)
                writer.add_scalar('valid_bleu', raw_sacrebleu_out.score,
                                  train_iter)
    if args.num_averages > 0:
        model_averager.copy_back(
            param_dict)  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
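
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script). The machine
# translation trainer above uses `InverseSquareRootScheduler`, which is
# assumed here to follow the usual Transformer schedule: a linear warmup from
# `warmup_init_lr` to `base_lr` over `warmup_steps`, then decay proportional
# to 1/sqrt(step). The exact behaviour should be checked against the GluonNLP
# implementation; this function only documents the assumed intent.
def _inverse_sqrt_lr_sketch(step, base_lr, warmup_steps, warmup_init_lr=0.0):
    """Assumed learning rate at update `step` (1-indexed)."""
    if step < warmup_steps:
        # Linear warmup from warmup_init_lr to base_lr.
        return warmup_init_lr + (base_lr - warmup_init_lr) * \
            step / max(warmup_steps, 1)
    # Inverse square-root decay, continuous at step == warmup_steps.
    return base_lr * math.sqrt(warmup_steps) / math.sqrt(step)
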
def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    setup_logging(args, local_rank)
    cfg, tokenizer, qa_net, use_segmentation = \
        get_network(args.model_name, ctx_l, args.classifier_dropout,
                    args.param_checkpoint, args.backbone_path)
    logging.info('Prepare training data')
    train_features = get_squad_features(args, tokenizer, segment='train')
    dataset_processor = SquadDatasetProcessor(
        tokenizer=tokenizer,
        doc_stride=args.doc_stride,
        max_seq_length=args.max_seq_length,
        max_query_length=args.max_query_length)
    logging.info('Processing the Training data:')
    train_dataset, num_answer_mismatch, num_unreliable \
        = dataset_processor.get_train(train_features, skip_unreliable=True)
    logging.info(
        'Done! #Unreliable Span={} / #Mismatched Answer={} / #Total={}'.format(
            num_unreliable, num_answer_mismatch, len(train_features)))
    # Get dataset statistics
    num_impossible = 0
    for sample in train_dataset:
        num_impossible += sample.is_impossible
    logging.info('Before Chunking, #Train/Is Impossible = {}/{}'.format(
        len(train_features),
        sum([ele.is_impossible for ele in train_features])))
    logging.info('After Chunking, #Train Sample/Is Impossible = {}/{}'.format(
        len(train_dataset), num_impossible))
    # Shuffle the dataset
    rs = np.random.RandomState(args.pre_shuffle_seed)
    rs.shuffle(train_dataset)
    sampler = SplitSampler(len(train_dataset),
                           num_parts=num_workers,
                           part_index=rank,
                           even_size=True)
    train_dataloader = mx.gluon.data.DataLoader(
        train_dataset,
        batchify_fn=dataset_processor.BatchifyFunction,
        batch_size=args.batch_size,
        num_workers=0,
        sampler=sampler)
    if 'electra' in args.model_name:
        # Freeze parameters; this does not work for the ALBERT model
        # since parameters in all layers are shared.
        if args.untunable_depth > 0:
            qa_net.backbone.frozen_params(args.untunable_depth)
        if args.layerwise_decay > 0:
            qa_net.backbone.apply_layerwise_decay(args.layerwise_decay)
    logging.info('Creating distributed trainer...')
    # Collect differentiable parameters
    param_dict = qa_net.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)
    epoch_size = (len(train_dataloader) + len(ctx_l) - 1) // len(ctx_l)
    if args.num_train_steps is not None:
        num_train_steps = args.num_train_steps
    else:
        num_train_steps = int(args.epochs * epoch_size / args.num_accumulated)
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    assert warmup_steps is not None, \
        'Must specify either warmup_steps or warmup_ratio'
    log_interval = args.log_interval
    save_interval = args.save_interval if args.save_interval is not None\
        else epoch_size // args.num_accumulated
    logging.info(
        '#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
            num_train_steps, warmup_steps, save_interval))
    # set up optimization
    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    adam_betas = eval(args.adam_betas)
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': adam_betas[0],
            'beta2': adam_betas[1],
            'epsilon': args.adam_epsilon,
            'correct_bias': False,
        })
    elif args.optimizer == 'adam':
        optimizer_params.update({
            'beta1': adam_betas[0],
            'beta2': adam_betas[1],
            'epsilon': args.adam_epsilon,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    log_span_loss = 0
    log_answerable_loss = 0
    log_total_loss = 0
    log_sample_num = 0
    global_tic = time.time()
    tic = time.time()
    for step_num, batch_data in enumerate(
            grouper(repeat(train_dataloader), len(ctx_l) * num_accumulated)):
        for sample_l in grouper(batch_data, len(ctx_l)):
            loss_l = []
            span_loss_l = []
            answerable_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                # Copy the data to device
                tokens = sample.data.as_in_ctx(ctx)
                log_sample_num += len(tokens)
                segment_ids = sample.segment_ids.as_in_ctx(
                    ctx) if use_segmentation else None
                valid_length = sample.valid_length.as_in_ctx(ctx)
                p_mask = sample.masks.as_in_ctx(ctx)
                gt_start = sample.gt_start.as_in_ctx(ctx).astype(np.int32)
                gt_end = sample.gt_end.as_in_ctx(ctx).astype(np.int32)
                is_impossible = sample.is_impossible.as_in_ctx(ctx).astype(
                    np.int32)
                batch_idx = mx.np.arange(tokens.shape[0],
                                         dtype=np.int32,
                                         ctx=ctx)
                # In the network, we use 1 --> no_mask, 0 --> mask
                p_mask = 1 - p_mask
                with mx.autograd.record():
                    start_logits, end_logits, answerable_logits \
                        = qa_net(tokens, segment_ids, valid_length, p_mask,
                                 gt_start)
                    sel_start_logits = start_logits[batch_idx, gt_start]
                    sel_end_logits = end_logits[batch_idx, gt_end]
                    sel_answerable_logits = answerable_logits[batch_idx,
                                                              is_impossible]
                    span_loss = -0.5 * (sel_start_logits +
                                        sel_end_logits).mean()
                    answerable_loss = -0.5 * sel_answerable_logits.mean()
                    loss = span_loss + answerable_loss
                    loss_l.append(loss)
                    span_loss_l.append(span_loss)
                    answerable_loss_l.append(answerable_loss)
            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_span_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in span_loss_l]).asnumpy()
            log_total_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in loss_l]).asnumpy()
            log_answerable_loss += sum([
                ele.as_in_ctx(ctx_l[0]) for ele in answerable_loss_l
            ]).asnumpy()
        # update
        trainer.allreduce_grads()
        if args.max_grad_norm > 0:
            total_norm, ratio, is_finite = clip_grad_global_norm(
                params, args.max_grad_norm * num_workers)
        else:
            total_norm = grad_global_norm(params)
        if args.comm_backend == 'horovod':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1
            trainer.update(num_workers, ignore_stale_grad=True)
        total_norm = total_norm / num_workers
        if args.num_accumulated > 1:
            # set grad to zero for gradient accumulation
            qa_net.zero_grad()
        # saving
        if local_rank == 0 and (step_num + 1) % save_interval == 0 or (
                step_num + 1) >= num_train_steps:
            version_prefix = 'squad' + args.version
            ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                 version_prefix,
                                                 (step_num + 1))
            params_saved = os.path.join(args.output_dir, ckpt_name)
            qa_net.save_parameters(params_saved)
            ckpt_candidates = [
                f for f in os.listdir(args.output_dir) if f.endswith('.params')
            ]
            # keep last `max_saved_ckpt` checkpoints
            if len(ckpt_candidates) > args.max_saved_ckpt:
                ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
                os.remove(os.path.join(args.output_dir, ckpt_candidates[0]))
            logging.info('Params saved in: {}'.format(params_saved))
        # logging
        if (step_num + 1) % log_interval == 0:
            log_span_loss /= log_sample_num
            log_answerable_loss /= log_sample_num
            log_total_loss /= log_sample_num
            toc = time.time()
            logging.info(
                'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
                ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f},'
                ' Throughput={:.2f} samples/s ETA={:.2f}h'.format(
                    (step_num + 1), num_train_steps, log_span_loss,
                    log_answerable_loss, log_total_loss,
                    trainer.learning_rate, total_norm, toc - tic,
                    log_sample_num / (toc - tic),
                    (num_train_steps - (step_num + 1)) /
                    ((step_num + 1) / (toc - global_tic)) / 3600))
            tic = time.time()
            log_span_loss = 0
            log_answerable_loss = 0
            log_total_loss = 0
            log_sample_num = 0
            num_samples_per_update = 0
        if (step_num + 1) >= num_train_steps:
            toc = time.time()
            logging.info('Finish training step: {} within {} hours'.format(
                step_num + 1, (toc - global_tic) / 3600))
            break
    return params_saved
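
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script). The training loops in
# this file rely on `repeat` / `grouper` helpers to turn a finite DataLoader
# into an endless stream of fixed-size groups of batches, one per device (and
# per accumulation step). Minimal versions with the assumed semantics are
# shown below; the last group is padded with None, which the loops above skip
# via `if sample is None: continue`.
def _repeat_sketch(iterable):
    """Cycle through `iterable` indefinitely, restarting it when exhausted."""
    while True:
        for item in iterable:
            yield item


def _grouper_sketch(iterable, n, fillvalue=None):
    """Yield tuples of `n` consecutive items, padding the final group."""
    group = []
    for item in iterable:
        group.append(item)
        if len(group) == n:
            yield tuple(group)
            group = []
    if group:
        group.extend([fillvalue] * (n - len(group)))
        yield tuple(group)
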
def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    task = get_task(args.task_name)
    #setup_logging(args, local_rank)
    level = logging.INFO
    detail_dir = os.path.join(args.output_dir, args.task_name)
    if not os.path.exists(detail_dir):
        os.mkdir(detail_dir)
    logging_config(
        detail_dir,
        name='train_{}_{}_'.format(args.task_name, args.model_name) +
        str(rank),  # avoid race
        level=level,
        console=(local_rank == 0))
    logging.info(args)
    cfg, tokenizer, classify_net, use_segmentation = \
        get_network(args.model_name, ctx_l, args.param_checkpoint,
                    args.backbone_path, task)
    logging.info('Prepare training data')
    train_data, _ = get_task_data(args, tokenizer, segment='train', task=task)
    train_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()),
                              bf.Stack())
    epoch_num_updates = len(train_data) // args.batch_size
    max_update = epoch_num_updates * args.epochs
    warmup_steps = int(np.ceil(max_update * args.warmup_ratio))
    dataloader = DataLoader(train_data,
                            batch_size=args.batch_size,
                            batchify_fn=train_batchify,
                            num_workers=4,
                            shuffle=True)
    dataloader = grouper(repeat(dataloader), len(ctx_l))
    param_dict = classify_net.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in classify_net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)
    lr_scheduler = PolyScheduler(max_update=max_update,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0.0,
                                 pwr=1,
                                 final_lr=0.0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler
    }
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(classify_net.collect_params(), 'adamw',
                                   optimizer_params)
    # prepare loss function
    if args.task_name == 'sts':
        loss_function = gluon.loss.L2Loss()
    else:
        loss_function = gluon.loss.SoftmaxCELoss()
    log_loss = 0
    log_gnorm = 0
    log_step = 0
    if args.log_interval > 0:
        log_interval = args.log_interval
    else:
        log_interval = int(epoch_num_updates * 0.5)
    for i in range(max_update):
        sample_l = next(dataloader)
        loss_l = []
        for sample, ctx in zip(sample_l, ctx_l):
            (token_ids, token_types, valid_length), label = sample
            # Move to the corresponding context
            token_ids = mx.np.array(token_ids, ctx=ctx)
            token_types = mx.np.array(token_types, ctx=ctx)
            valid_length = mx.np.array(valid_length, ctx=ctx)
            label = mx.np.array(label, ctx=ctx)
            with mx.autograd.record():
                scores = classify_net(token_ids, token_types, valid_length)
                loss = loss_function(scores, label).mean() / len(ctx_l)
                loss_l.append(loss)
        for loss in loss_l:
            loss.backward()
        trainer.allreduce_grads()
        # Begin Norm Clipping
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm)
        trainer.update(1.0)
        step_loss = sum([loss.asnumpy() for loss in loss_l])
        log_loss += step_loss
        log_gnorm += total_norm
        log_step += 1
        if log_step >= log_interval or i == max_update - 1:
            logging.info(
                '[Iter {} / {}] avg {} = {:.2f}, avg gradient norm = {:.2f}'.
                format(i + 1, max_update, 'nll', log_loss / log_step,
                       log_gnorm / log_step))
            log_loss = 0
            log_gnorm = 0
            log_step = 0
        if local_rank == 0 and (i == max_update - 1 or i %
                                (max_update // args.epochs) == 0 and i > 0):
            ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                 args.task_name, (i + 1))
            params_saved = os.path.join(detail_dir, ckpt_name)
            classify_net.save_parameters(params_saved)
            logging.info('Params saved in: {}'.format(params_saved))
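
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script). With the arguments
# used above (pwr=1, final_lr=0, warmup_mode='linear'), the PolyScheduler is
# assumed to act as a linear warmup followed by a linear decay to `final_lr`
# over the remaining updates. The exact formula should be checked against the
# MXNet scheduler; this only documents the assumed shape of the schedule.
def _poly_lr_sketch(step, base_lr, max_update, warmup_steps,
                    final_lr=0.0, pwr=1):
    """Assumed learning rate at update `step` for the configuration above."""
    if step < warmup_steps:
        # Linear warmup from 0 to base_lr.
        return base_lr * step / max(warmup_steps, 1)
    frac = 1.0 - (step - warmup_steps) / max(max_update - warmup_steps, 1)
    frac = max(frac, 0.0)
    return final_lr + (base_lr - final_lr) * frac ** pwr
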
def train(args):
    _, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    level = logging.DEBUG if args.verbose else logging.INFO
    logging_config(
        args.ckpt_dir,
        name='pretrain_bert_' + str(rank),  # avoid race
        level=level,
        console=(local_rank == 0))
    logging.info(args)
    logging.debug('Random seed set to {}'.format(args.seed))
    set_seed(args.seed)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(args.num_buckets,
                                                    num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args.ckpt_dir, 'Loading',
                          ctx_l)
    else:
        model.initialize(ctx=ctx_l)
    model.hybridize()
    if args.raw:
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            masked_lm_prob=args.masked_lm_prob,
            max_predictions_per_seq=args.max_predictions_per_seq,
            whole_word_mask=args.whole_word_mask,
            random_next_sentence=args.random_next_sentence,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            dataset_cached=args.dataset_cached,
            num_max_dataset_cached=args.num_max_dataset_cached)
    else:
        get_dataset_fn = get_pretrain_data_npz
    data_train = get_dataset_fn(args.data,
                                args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)
    param_dict = model.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'
    num_steps = args.num_steps
    warmup_steps = int(num_steps * args.warmup_ratio)
    log_interval = args.log_interval
    save_interval = args.ckpt_interval
    logging.info(
        '#Total Training Steps={}, Warmup Steps={}, Save Interval={}'.format(
            num_steps, warmup_steps, save_interval))
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd}
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    elif args.comm_backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        states_option(args.start_step, trainer, args.ckpt_dir, local_rank,
                      'Loading')
    # backend specific implementation
    if args.comm_backend == 'byteps':
        trainer._init_params()
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)
    # prepare the loss function
    nsp_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    nsp_loss_fn.hybridize()
    mlm_loss_fn.hybridize()
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()
    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps
    running_mlm_loss, running_nsp_loss = 0., 0.
    running_num_tks = 0
    train_start_time = time.time()
    tic = time.time()
    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_steps:
        step_num += 1
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            mlm_loss_l = []
            nsp_loss_l = []
            loss_l = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for sample, ctx in zip(sample_l, ctx_l):
                # prepare data
                (input_id, masked_id, masked_position, masked_weight,
                 next_sentence_label, segment_id, valid_length) = sample
                input_id = input_id.as_in_ctx(ctx)
                masked_id = masked_id.as_in_ctx(ctx)
                masked_position = masked_position.as_in_ctx(ctx)
                masked_weight = masked_weight.as_in_ctx(ctx)
                next_sentence_label = next_sentence_label.as_in_ctx(ctx)
                segment_id = segment_id.as_in_ctx(ctx)
                valid_length = valid_length.as_in_ctx(ctx)
                with mx.autograd.record():
                    _, _, nsp_score, mlm_scores = model(
                        input_id, segment_id, valid_length, masked_position)
                    denominator = (masked_weight.sum() +
                                   1e-8) * num_accumulated * len(ctx_l)
                    mlm_scores_r = mx.npx.reshape(mlm_scores, (-5, -1))
                    masked_id_r = masked_id.reshape((-1, ))
                    mlm_loss = mlm_loss_fn(
                        mlm_scores_r, masked_id_r,
                        masked_weight.reshape((-1, 1))).sum() / denominator
                    denominator = num_accumulated * len(ctx_l)
                    nsp_loss = nsp_loss_fn(
                        nsp_score, next_sentence_label).mean() / denominator
                    mlm_loss_l.append(mlm_loss)
                    nsp_loss_l.append(nsp_loss)
                    loss_l.append(mlm_loss + nsp_loss)
                    mask_label_list.append(masked_id_r)
                    mask_pred_list.append(mlm_scores_r)
                    mask_weight_list.append(masked_weight.reshape((-1, )))
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(nsp_score)
                running_num_tks += valid_length.sum().as_in_ctx(mx.cpu())
            for loss in loss_l:
                loss.backward()
            running_mlm_loss += sum([
                ele.as_in_ctx(mx.cpu()) for ele in mlm_loss_l
            ]).asnumpy().item()
            running_nsp_loss += sum([
                ele.as_in_ctx(mx.cpu()) for ele in nsp_loss_l
            ]).asnumpy().item()
            mlm_metric.update(mask_label_list, mask_pred_list,
                              mask_weight_list)
            nsp_metric.update(ns_label_list, ns_pred_list)
        # update
        trainer.allreduce_grads()
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)
        total_norm = total_norm / num_workers
        # update learning rate
        scheduled_lr = args.lr
        if step_num <= warmup_steps:
            scheduled_lr *= step_num / warmup_steps
        else:
            offset = (num_steps - step_num) / (num_steps - warmup_steps)
            scheduled_lr *= max(offset, 0)
        trainer.set_learning_rate(scheduled_lr)
        if args.comm_backend == 'horovod' or args.comm_backend == 'byteps':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers.
            # *num_workers* of Horovod is the number of GPUs.
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1.
            # *num_workers* of Trainer is the number of machines.
            trainer.update(num_workers, ignore_stale_grad=True)
        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()
        # saving
        if step_num % save_interval == 0 or step_num >= num_steps:
            states_option(step_num, trainer, args.ckpt_dir, local_rank,
                          'Saving')
            if local_rank == 0:
                parameters_option(step_num, model, args.ckpt_dir, 'Saving')
        # logging
        if step_num % log_interval == 0:
            running_mlm_loss /= log_interval
            running_nsp_loss /= log_interval
            toc = time.time()
            logging.info(
                '[step {}], Loss mlm/nsp={:.5f}/{:.3f}, Acc mlm/nsp={:.3f}/{:.3f}, '
                ' LR={:.7f}, grad_norm={:.4f}. Time cost={:.2f} s,'
                ' Throughput={:.1f}K tks/s, ETA={:.2f} h'.format(
                    step_num, running_mlm_loss, running_nsp_loss,
                    mlm_metric.get()[1], nsp_metric.get()[1],
                    trainer.learning_rate, total_norm, toc - tic,
                    running_num_tks.asnumpy().item() / (toc - tic) / 1000,
                    (num_steps - step_num) /
                    (step_num / (toc - train_start_time)) / 3600))
            mlm_metric.reset()
            nsp_metric.reset()
            tic = time.time()
            running_mlm_loss = 0
            running_nsp_loss = 0
            running_num_tks = 0
    logging.info('Finish training step: %d', step_num)
    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time -
                                              train_start_time))
    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.ckpt_dir, model_name)
        final_save(model, save_dir, tokenizer, cfg)
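
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script). `MaskedAccuracy`
# above is assumed to compute a weight-masked accuracy: predictions at
# positions with zero weight (padding slots among the masked-LM targets) do
# not contribute to either the numerator or the denominator. A minimal numpy
# version of that assumed metric:
def _masked_accuracy_sketch(labels, scores, weights=None):
    """labels: (N,), scores: (N, num_classes), weights: (N,) or None."""
    preds = scores.argmax(axis=-1)
    correct = (preds == labels).astype('float32')
    if weights is None:
        return float(correct.mean())
    weights = weights.astype('float32')
    return float((correct * weights).sum() / (weights.sum() + 1e-12))
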
def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(args.num_buckets,
                                                    num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l,
                                                  args.max_seq_length,
                                                  args.hidden_dropout_prob,
                                                  args.attention_dropout_prob,
                                                  args.generator_units_scale,
                                                  args.generator_layers_scale)
    data_masker = ElectraMasker(tokenizer, args.max_seq_length,
                                args.mask_prob)
    if args.from_raw_text:
        if args.cached_file_path and not os.path.exists(
                args.cached_file_path):
            os.mkdir(args.cached_file_path)
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            cached_file_path=args.cached_file_path)
        logging.info(
            'Processing and loading the training dataset from raw text.')
    else:
        logging.info('Loading the training dataset from local Numpy file.')
        get_dataset_fn = get_pretrain_data_npz
    data_train = get_dataset_fn(args.data,
                                args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)
    logging.info('Creating distributed trainer...')
    param_dict = model.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    if args.num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(args.num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)
    num_train_steps = args.num_train_steps
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    assert warmup_steps is not None, \
        'Must specify either warmup_steps or warmup_ratio'
    log_interval = args.log_interval
    save_interval = args.save_interval if args.save_interval is not None\
        else num_train_steps // 50
    logging.info(
        '#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
            num_train_steps, warmup_steps, save_interval))
    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        # TODO(zheyuye), How about data splitting, where to start re-training
        state_path = states_option(args.start_step, trainer, args.output_dir,
                                   local_rank, 'Loading')
        param_path = parameters_option(args.start_step, model,
                                       args.output_dir, 'Loading')
    # prepare the loss function
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    rtd_loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    mlm_loss_fn.hybridize()
    rtd_loss_fn.hybridize()
    # prepare the records writer
    writer = None
    if args.do_eval and local_rank == 0:
        from tensorboardX import SummaryWriter
        record_path = os.path.join(args.output_dir, 'records')
        logging.info('Evaluation records saved in {}'.format(record_path))
        writer = SummaryWriter(record_path)
    step_num = args.start_step
    finish_flag = False
    num_samples_per_update = 0
    loss_denom = float(len(ctx_l) * args.num_accumulated * num_workers)
    log_total_loss = 0
    log_mlm_loss = 0
    log_rtd_loss = 0
    log_sample_num = 0
    train_start_time = time.time()
    if args.num_accumulated != 1:
        # set grad to zero for gradient accumulation
        model.zero_grad()
    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_train_steps:
        tic = time.time()
        for accum_idx in range(args.num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            mlm_loss_l = []
            rtd_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                # prepare data
                input_ids, segment_ids, valid_lengths = sample
                input_ids = input_ids.as_in_ctx(ctx)
                segment_ids = segment_ids.as_in_ctx(ctx)
                valid_lengths = valid_lengths.as_in_ctx(ctx)
                masked_input = data_masker.dynamic_masking(
                    mx.nd, input_ids, valid_lengths)
                masked_input_ids = masked_input.input_ids
                length_masks = masked_input.masks
                unmasked_tokens = masked_input.unmasked_tokens
                masked_positions = masked_input.masked_positions
                masked_weights = masked_input.masked_weights
                log_sample_num += len(masked_input_ids)
                num_samples_per_update += len(masked_input_ids)
                with mx.autograd.record():
                    mlm_scores, rtd_scores, corrupted_tokens, labels = model(
                        masked_input_ids, segment_ids, valid_lengths,
                        unmasked_tokens, masked_positions)
                    # The official implementation takes the sum of each batch
                    # inside the loss function, while
                    # SigmoidBinaryCrossEntropyLoss and SoftmaxCELoss take the
                    # mean value.
                    mlm_loss = mlm_loss_fn(
                        mlm_scores, unmasked_tokens,
                        masked_weights.reshape(-1)).mean() / (
                            masked_weights.mean() + 1e-6)
                    rtd_loss = rtd_loss_fn(rtd_scores, labels,
                                           length_masks).mean() / (
                                               length_masks.mean() + 1e-6)
                    output = ElectraOutput(
                        mlm_scores=mlm_scores,
                        rtd_scores=rtd_scores,
                        rtd_labels=labels,
                        corrupted_tokens=corrupted_tokens,
                    )
                    mlm_loss_l.append(mlm_loss)
                    rtd_loss_l.append(rtd_loss)
                    loss = (args.gen_weight * mlm_loss +
                            args.disc_weight * rtd_loss) / loss_denom
                    loss_l.append(loss)
            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_mlm_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in mlm_loss_l]).asnumpy()
            log_rtd_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in rtd_loss_l]).asnumpy()
            log_total_loss += sum(
                [ele.as_in_ctx(ctx_l[0])
                 for ele in loss_l]).asnumpy() * loss_denom
        # update
        trainer.allreduce_grads()
        # Here, the accumulated gradients are
        #   \sum_{n=1}^N g_n / loss_denom
        # Thus, in order to clip the average gradient
        #   \frac{1}{N} \sum_{n=1}^N g_n --> clip to args.max_grad_norm
        # We need to change the ratio to be
        #   \sum_{n=1}^N g_n / loss_denom --> clip to
        #       args.max_grad_norm * N / loss_denom
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_samples_per_update / loss_denom)
        total_norm = total_norm / (num_samples_per_update / loss_denom)
        trainer.update(num_samples_per_update / loss_denom,
                       ignore_stale_grad=True)
        step_num += 1
        if args.num_accumulated != 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()
        # saving
        if step_num % save_interval == 0 or step_num >= num_train_steps:
            if is_master_node:
                states_option(step_num, trainer, args.output_dir, local_rank,
                              'Saving')
                if local_rank == 0:
                    param_path = parameters_option(step_num, model,
                                                   args.output_dir, 'Saving')
        # logging
        if step_num % log_interval == 0 and local_rank == 0:
            # Output the loss of per step
            log_mlm_loss /= log_interval
            log_rtd_loss /= log_interval
            log_total_loss /= log_interval
            toc = time.time()
            logging.info(
                '[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},'
                ' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},'
                ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
                    step_num, log_mlm_loss, log_rtd_loss, log_total_loss,
                    trainer.learning_rate, total_norm, toc - tic,
                    log_sample_num / (toc - tic),
                    (num_train_steps - step_num) /
                    (step_num / (toc - train_start_time)) / 3600))
            tic = time.time()
            if args.do_eval:
                evaluation(writer, step_num, masked_input, output)
                writer.add_scalars(
                    'loss', {
                        'total_loss': log_total_loss,
                        'mlm_loss': log_mlm_loss,
                        'rtd_loss': log_rtd_loss
                    }, step_num)
            log_mlm_loss = 0
            log_rtd_loss = 0
            log_total_loss = 0
            log_sample_num = 0
            num_samples_per_update = 0
    logging.info('Finish training step: %d', step_num)
    if is_master_node:
        state_path = states_option(step_num, trainer, args.output_dir,
                                   local_rank, 'Saving')
        if local_rank == 0:
            param_path = parameters_option(step_num, model, args.output_dir,
                                           'Saving')
    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_start_time))
    if writer is not None:
        writer.close()
    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.output_dir, model_name)
        final_save(model, save_dir, tokenizer)
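
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script). The ELECTRA objective
# used above combines the generator's masked-LM loss with the discriminator's
# replaced-token-detection (RTD) loss:
#     total = gen_weight * L_mlm + disc_weight * L_rtd
# The real RTD term is SigmoidBinaryCrossEntropyLoss on `rtd_scores` vs.
# `labels`, masked by `length_masks`; a minimal numpy version of that masked
# binary cross-entropy, for reference only:
def _rtd_loss_sketch(rtd_logits, rtd_labels, length_masks):
    """Mean binary cross-entropy over valid (non-padding) token positions."""
    probs = 1.0 / (1.0 + np.exp(-rtd_logits))
    bce = -(rtd_labels * np.log(probs + 1e-12) +
            (1 - rtd_labels) * np.log(1 - probs + 1e-12))
    return float((bce * length_masks).sum() / (length_masks.sum() + 1e-6))
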