def test_transformer_param_number(cfg_name, gt_num_params, gt_num_fixed_params):
    cfg = TransformerModel.get_cfg(cfg_name)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = 32768
    cfg.MODEL.tgt_vocab_size = 32768
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize()
    num_params, num_fixed_params = count_parameters(model.collect_params())
    assert num_params == gt_num_params
    assert num_fixed_params == gt_num_fixed_params
    num_params2, num_fixed_params2 = count_parameters(
        deduplicate_param_dict(model.collect_params()))
    assert num_params2 == gt_num_params
    assert num_fixed_params2 == gt_num_fixed_params
def test_get_backbone(name, ctx):
    with tempfile.TemporaryDirectory() as root, ctx:
        model_cls, cfg, tokenizer, local_params_path, _ = get_backbone(name, root=root)
        net = model_cls.from_cfg(cfg)
        net.load_parameters(local_params_path)
        net.hybridize()
        num_params, num_fixed_params = count_parameters(net.collect_params())
        assert num_params > 0
        # Test for model export + save
        if 'gpt2' in name:
            pytest.skip('Skipping GPT-2 test')
        batch_size = 1
        sequence_length = 4
        inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
        token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
        valid_length = mx.np.random.randint(1, sequence_length, (batch_size,))
        if 'roberta' in name or 'xlmr' in name:
            out = net(inputs, valid_length)
        elif 'bart' in name:
            out = net(inputs, valid_length, inputs, valid_length)
        elif 'gpt2' in name:
            states = net.init_states(batch_size=batch_size, ctx=ctx)
            out, new_states = net(inputs, states)
            out_np = out.asnumpy()
        else:
            out = net(inputs, token_types, valid_length)
        mx.npx.waitall()
        net.export(os.path.join(root, 'model'))
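# --- Hedged sketch (not in the original file): how this test is typically driven. ---
# The backbone tests are assumed to be parametrized over the registered model names;
# the helper `list_backbone_names` and the exact decorator used upstream are assumptions,
# shown here only to illustrate how `name` and `ctx` would be supplied.
import mxnet as mx
import pytest
from gluonnlp.models import list_backbone_names

@pytest.mark.parametrize('name', list_backbone_names())
@pytest.mark.parametrize('ctx', [mx.cpu()])
def test_get_backbone_example(name, ctx):
    # Delegates to the test above; mx.Context objects can be used as context managers.
    test_get_backbone(name, ctx)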
def get_network(model_name, ctx_l,
                checkpoint_path=None,
                backbone_path=None,
                task=None):
    """Get the network that fine-tunes the text prediction (classification) task."""
    use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
    Model, cfg, tokenizer, download_params_path, _ = \
        get_backbone(model_name, load_backbone=not backbone_path)
    backbone = Model.from_cfg(cfg)
    # Load local backbone parameters if backbone_path is provided.
    # Otherwise, download backbone parameters from the gluon zoo.
    backbone_params_path = backbone_path if backbone_path else download_params_path
    if checkpoint_path is None:
        backbone.load_parameters(backbone_params_path, ignore_extra=True,
                                 ctx=ctx_l, cast_dtype=True)
        num_params, num_fixed_params \
            = count_parameters(deduplicate_param_dict(backbone.collect_params()))
        logging.info(
            'Loading Backbone Model from {}, with total/fixed parameters={}/{}'.format(
                backbone_params_path, num_params, num_fixed_params))
    classify_net = TextPredictionNet(backbone, task.class_num)
    if checkpoint_path is None:
        # Ignore the UserWarning during initialization;
        # there is no need to re-initialize the parameters of the backbone.
        classify_net.initialize(ctx=ctx_l)
    else:
        classify_net.load_parameters(checkpoint_path, ctx=ctx_l, cast_dtype=True)
    classify_net.hybridize()
    return cfg, tokenizer, classify_net, use_segmentation
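# --- Hedged usage sketch (not in the original file) for get_network above. ---
# The backbone name is assumed to be a registered GluonNLP model; the `task` object is a
# stand-in that only needs a `class_num` attribute, mirroring how the function uses it.
def _example_build_classifier():
    import types
    import mxnet as mx
    ctx_l = [mx.cpu()]
    task = types.SimpleNamespace(class_num=2)   # hypothetical binary-classification task
    cfg, tokenizer, classify_net, use_segmentation = get_network(
        'google_en_uncased_bert_base', ctx_l, task=task)
    return cfg, tokenizer, classify_net, use_segmentation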
def train(args):
    _, num_parts, rank, local_rank, _, ctx_l = init_comm(args.comm_backend, args.gpus)
    if args.comm_backend == 'horovod':
        logging_config(args.save_dir,
                       name=f'train_transformer_rank{rank}_local{local_rank}_{num_parts}',
                       console=(rank == 0))
        logging.info(args)
    else:
        logging_config(args.save_dir, name='train_transformer', console=True)
        logging.info(args)
    use_amp = args.fp16
    if use_amp:
        from mxnet import amp
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    base_tgt_tokenizer = MosesTokenizer(args.tgt_lang)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus, args.train_tgt_corpus,
        src_tokenizer, tgt_tokenizer,
        args.overwrite_cache, local_rank,
        max_src_length=args.max_src_length,
        max_tgt_length=args.max_tgt_length,
        pretokenized=not args.tokenize)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus, args.dev_tgt_corpus,
        src_tokenizer, tgt_tokenizer,
        args.overwrite_cache, local_rank,
        pretokenized=not args.tokenize)
    tgt_detok_sentences = []
    tgt_raw_sentences = []
    with open(args.dev_tgt_corpus, 'r') as in_f:
        for line in in_f:
            tgt_detok_sentences.append(
                base_tgt_tokenizer.decode(
                    tgt_tokenizer.decode(line.split()).split()))
    with open(args.dev_tgt_raw_corpus, 'r') as in_f:
        for line in in_f:
            tgt_raw_sentences.append(line.strip())
    data_train = gluon.data.SimpleDataset(
        [(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
         for i, (src_tokens, tgt_tokens)
         in enumerate(zip(train_src_data, train_tgt_data))])
    val_samples = [(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
                   for i, (src_tokens, tgt_tokens)
                   in enumerate(zip(dev_src_data, dev_tgt_data))]
    if args.comm_backend == 'horovod':
        slice_begin = rank * (len(val_samples) // num_parts)
        slice_end = min((rank + 1) * (len(val_samples) // num_parts), len(val_samples))
        data_val = gluon.data.SimpleDataset(val_samples[slice_begin:slice_end])
    else:
        data_val = gluon.data.SimpleDataset(val_samples)
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    cfg.MODEL.layout = 'TN'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    # Do not apply weight decay to LayerNorm parameters and biases
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    param_dict = deduplicate_param_dict(model.collect_params())
    inference_model = TransformerInference(model=model)
    inference_model.hybridize()
    if local_rank == 0:
        logging.info(model)
        with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
            cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=len(tgt_vocab),
                                                    alpha=args.label_smooth_alpha,
                                                    from_logits=False)
    label_smooth_loss.hybridize()
    # Construct the beam search sampler
    scorer = BeamSearchScorer(alpha=args.lp_alpha,
                              K=args.lp_k,
                              from_logits=False)
    beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size,
                                            decoder=inference_model,
                                            vocab_size=len(tgt_vocab),
                                            eos_id=tgt_vocab.eos_id,
                                            scorer=scorer,
                                            stochastic=False,
                                            max_length_a=args.max_length_a,
                                            max_length_b=args.max_length_b)
    logging.info(beam_search_sampler)
    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # Construct the trainer
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(warmup_steps=args.warmup_steps,
                                              base_lr=base_lr,
                                              warmup_init_lr=args.warmup_init_lr)
    optimizer_params = {'learning_rate': args.lr,
                        'beta1': 0.9,
                        'beta2': 0.997,
                        'epsilon': 1e-9,
                        'lr_scheduler': lr_scheduler,
                        'wd': args.wd}
    user_provided_optimizer_params = json.loads(args.optimizer_params)
    optimizer_params.update(user_provided_optimizer_params)
    if args.fp16:
        optimizer_params.update({'multi_precision': True})
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
    else:
        trainer = gluon.Trainer(param_dict, args.optimizer, optimizer_params,
                                update_on_kvstore=False)

    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            shuffle=True,
            seed=args.seed)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError('FixedBucketSampler does not support horovod at present')
        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            batch_size=args.batch_size,
            num_buckets=args.num_buckets,
            ratio=args.bucket_ratio,
            shuffle=True,
            use_average_length=True,
            bucket_scheme=bucket_scheme,
            seed=args.seed)
    else:
        raise NotImplementedError

    num_updates_per_epoch = int(math.ceil(
        len(train_batch_sampler) / (num_parts * len(ctx_l) * args.num_accumulated)))
    # Convert the batch sampler to multiple shards
    if num_parts > 1:
        train_batch_sampler = ShardedIterator(train_batch_sampler,
                                              num_parts=num_parts,
                                              part_index=rank,
                                              even_size=True,
                                              seed=args.seed + 1000 * rank)
    logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack())
    train_data_loader = gluon.data.DataLoader(data_train,
                                              batch_sampler=train_batch_sampler,
                                              batchify_fn=batchify_fn,
                                              num_workers=0)
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    model_averager = AverageSGDTracker(param_dict)
    log_start_time = time.time()
    num_params, num_fixed_params = None, None
    # TODO(sxjscience) Add a log metric class
    log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    # Maintain the denominator of the loss.
    log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
    log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
    log_avg_grad_norm = 0
    log_iter_num = 0

    if local_rank == 0:
        writer = SummaryWriter(logdir=os.path.join(args.save_dir, 'tensorboard'))
    if use_amp:
        amp.init_trainer(trainer)
    train_multi_data_loader = grouper(repeat(train_data_loader), len(ctx_l))
    # When args.epochs < 0, the model will keep training
    if args.epochs < 0:
        if args.max_update > 0:
            total_train_iters = args.max_update
            if args.num_averages > 0:
                assert args.num_averages <= total_train_iters // args.save_interval_update
                avg_start_iter = (total_train_iters // args.save_interval_update
                                  - args.num_averages) * args.save_interval_update
            else:
                avg_start_iter = -1
        else:
            total_train_iters = np.inf
            avg_start_iter = -1
    else:
        total_train_iters = args.epochs * num_updates_per_epoch
        if args.num_averages > 0:
            assert args.num_averages <= args.epochs
            avg_start_iter = (args.epochs - args.num_averages) * num_updates_per_epoch
        else:
            avg_start_iter = -1

    # Here, we are manually setting up the scale to 1.0 because
    # in horovod, the scale can be the number of workers:
    # See the code here: https://github.com/horovod/horovod/blob/125115583b7029196e2ec530decd4209459d5479/horovod/mxnet/__init__.py#L141
    # Since we will need to use the dynamic scaling in amp, we will manually call amp.unscale().
    # A scale that is larger than 1.0 can be problematic in this case.
    trainer._scale = 1.0
    if args.max_num_tokens > 0:
        const_scale = args.max_num_tokens
    else:
        const_scale = 100

    train_start_time = time.time()
    for train_iter in range(total_train_iters):
        model.zero_grad()
        loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
        for i in range(args.num_accumulated):
            loss_l = []
            sample_data_l = next(train_multi_data_loader)
            for j, (sample_data, ctx) in enumerate(zip(sample_data_l, ctx_l)):
                src_token_ids, tgt_token_ids, src_valid_length, \
                    tgt_valid_length, sample_ids = sample_data
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                src_wc, tgt_wc, bs = src_valid_length.sum(), \
                    tgt_valid_length.sum(), src_token_ids.shape[0]
                log_wc_l[j] += src_wc + tgt_wc
                log_tgt_wc_l[j] += tgt_wc
                token_count = (tgt_valid_length - 1).sum()
                loss_denom_l[j] += token_count / const_scale
                log_avg_loss_denom_l[j] += token_count / const_scale
                with mx.autograd.record():
                    if model.layout == 'NT':
                        tgt_pred = model(src_token_ids, src_valid_length,
                                         tgt_token_ids[:, :-1],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids[:, 1:]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(loss,
                                                    sequence_length=tgt_valid_length - 1,
                                                    use_sequence_length=True,
                                                    axis=1)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                    elif model.layout == 'TN':
                        tgt_pred = model(src_token_ids.T, src_valid_length,
                                         tgt_token_ids.T[:-1, :],
                                         tgt_valid_length - 1)
                        tgt_labels = tgt_token_ids.T[1:, :]
                        loss = label_smooth_loss(tgt_pred, tgt_labels)
                        loss = mx.npx.sequence_mask(loss,
                                                    sequence_length=tgt_valid_length - 1,
                                                    use_sequence_length=True,
                                                    axis=0)
                        loss = loss.sum() / const_scale
                        loss_l.append(loss)
                log_avg_loss_l[j] += loss
            if use_amp:
                with mx.autograd.record():
                    with amp.scale_loss(loss_l, trainer) as amp_loss_l:
                        for loss in amp_loss_l:
                            loss.backward()
            else:
                with mx.autograd.record():
                    for loss in loss_l:
                        loss.backward()

        # Print the total number of parameters
        if local_rank == 0 and num_params is None:
            num_params, num_fixed_params = count_parameters(param_dict)
            logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}'
                         .format(num_params, num_fixed_params))
        # All-Reduce the gradient
        trainer.allreduce_grads()
        if args.comm_backend == 'horovod':
            # All-Reduce the loss denominator
            assert len(loss_denom_l) == 1
            loss_denom = hvd.allreduce(loss_denom_l[0], average=False).asnumpy()
        else:
            loss_denom = sum([ele.asnumpy() for ele in loss_denom_l])
        if use_amp:
            # We need to first unscale the gradient and then perform allreduce.
            grad_scale = trainer.amp_loss_scale * loss_denom
        else:
            grad_scale = loss_denom
        if args.max_grad_norm is not None:
            total_norm, ratio, is_finite \
                = clip_grad_global_norm(params, args.max_grad_norm * grad_scale)
            total_norm = total_norm / grad_scale
        else:
            total_norm = grad_global_norm(params)
            total_norm = total_norm / grad_scale
        log_avg_grad_norm += total_norm
        log_iter_num += 1

        trainer.update(loss_denom, ignore_stale_grad=True)

        if avg_start_iter > 0 and train_iter >= avg_start_iter:
            model_averager.step()

        if ((train_iter + 1) % args.log_interval == 0
                or train_iter + 1 == total_train_iters):
            if args.comm_backend == 'horovod':
                # Use allreduce to get the total number of tokens and the loss
                log_wc = hvd.allreduce(log_wc_l[0], average=False).asnumpy()
                log_tgt_wc = hvd.allreduce(log_tgt_wc_l[0], average=False).asnumpy()
                log_avg_loss = hvd.allreduce(log_avg_loss_l[0] / log_avg_loss_denom_l[0],
                                             average=True)
                log_avg_loss = log_avg_loss.asnumpy()
            else:
                log_wc = sum([ele.asnumpy() for ele in log_wc_l])
                log_tgt_wc = sum([ele.asnumpy() for ele in log_tgt_wc_l])
                log_avg_loss = \
                    sum([log_avg_loss_l[i].asnumpy() / log_avg_loss_denom_l[i].asnumpy()
                         for i in range(len(log_avg_loss_l))]) / len(log_avg_loss_l)
            log_avg_grad_norm = log_avg_grad_norm / log_iter_num
            log_end_time = time.time()
            wps = log_wc / (log_end_time - log_start_time)
            epoch_id = train_iter // num_updates_per_epoch
            logging.info('[Epoch {} Iter {}/{}, Overall {}/{}] loss={:.4f}, ppl={:.4f}, '
                         'throughput={:.2f}K wps, total wc={:.2f}K, wpb={:.2f}K,'
                         ' LR={}, gnorm={:.4f}, ETA={:.2f}h'.format(
                             epoch_id, train_iter % num_updates_per_epoch + 1,
                             num_updates_per_epoch, train_iter + 1, total_train_iters,
                             log_avg_loss, np.exp(log_avg_loss),
                             wps / 1000, log_wc / 1000,
                             log_tgt_wc / 1000 / log_iter_num,
                             trainer.learning_rate, log_avg_grad_norm,
                             (log_end_time - train_start_time)
                             / (train_iter + 1) * (total_train_iters - train_iter - 1) / 3600))
            if local_rank == 0:
                writer.add_scalar('throughput_wps', wps, train_iter)
                writer.add_scalar('train_loss', log_avg_loss, train_iter)
                writer.add_scalar('lr', trainer.learning_rate, train_iter)
                writer.add_scalar('grad_norm', log_avg_grad_norm, train_iter)
            # Reinitialize the log variables
            log_start_time = time.time()
            log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l]
            log_avg_grad_norm = 0
            log_iter_num = 0
            log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]
            log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l]

        if (args.max_update > 0 and (train_iter + 1) % args.save_interval_update == 0) \
                or ((train_iter + 1) % num_updates_per_epoch == 0) \
                or train_iter + 1 == total_train_iters:
            epoch_id = (train_iter + 1) // num_updates_per_epoch
            if local_rank == 0:
                if args.max_update <= 0:
                    model.save_parameters(
                        os.path.join(args.save_dir, 'epoch{}.params'.format(epoch_id)),
                        deduplicate=True)
                else:
                    model.save_parameters(
                        os.path.join(args.save_dir, 'iter{}.params'.format(train_iter + 1)),
                        deduplicate=True)

            avg_val_loss, ntokens, pred_sentences, pred_lengths, sentence_ids \
                = validation(model, val_data_loader, inference_model, beam_search_sampler,
                             tgt_tokenizer, ctx_l)
            if args.comm_backend == 'horovod':
                flatten_pred_sentences = np.concatenate(pred_sentences, axis=0)
                all_val_loss = hvd.allgather(mx.np.array([avg_val_loss * ntokens],
                                                         dtype=np.float32, ctx=ctx_l[0]))
                all_ntokens = hvd.allgather(mx.np.array([ntokens],
                                                        dtype=np.int64, ctx=ctx_l[0]))
                flatten_pred_sentences = hvd.allgather(mx.np.array(flatten_pred_sentences,
                                                                   dtype=np.int32, ctx=ctx_l[0]))
                pred_lengths = hvd.allgather(mx.np.array(pred_lengths,
                                                         dtype=np.int64, ctx=ctx_l[0]))
                sentence_ids = hvd.allgather(mx.np.array(sentence_ids,
                                                         dtype=np.int64, ctx=ctx_l[0]))
                avg_val_loss = all_val_loss.asnumpy().sum() / all_ntokens.asnumpy().sum()
                flatten_pred_sentences = flatten_pred_sentences.asnumpy()
                pred_lengths = pred_lengths.asnumpy()
                sentence_ids = sentence_ids.asnumpy()
                pred_sentences = [None for _ in range(len(sentence_ids))]
                ptr = 0
                assert sentence_ids.min() == 0 and sentence_ids.max() == len(sentence_ids) - 1
                for sentence_id, length in zip(sentence_ids, pred_lengths):
                    pred_sentences[sentence_id] = flatten_pred_sentences[ptr:(ptr + length)]
                    ptr += length
            if local_rank == 0:
                # Perform detokenization
                pred_sentences_bpe_decode = []
                pred_sentences_raw = []
                for sentence in pred_sentences:
                    bpe_decode_sentence = tgt_tokenizer.decode(sentence.tolist())
                    raw_sentence = base_tgt_tokenizer.decode(bpe_decode_sentence.split())
                    pred_sentences_bpe_decode.append(bpe_decode_sentence)
                    pred_sentences_raw.append(raw_sentence)
                detok_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_bpe_decode,
                    ref_streams=[tgt_detok_sentences])
                raw_sacrebleu_out = sacrebleu.corpus_bleu(
                    sys_stream=pred_sentences_raw,
                    ref_streams=[tgt_raw_sentences])
                with open(os.path.join(args.save_dir,
                                       f'epoch{epoch_id}_dev_prediction.txt'), 'w') as of:
                    for line in pred_sentences_raw:
                        of.write(line + '\n')
                logging.info('[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, '
                             'SacreBLEU={}, Detok SacreBLEU={}'.format(
                                 epoch_id, train_iter, total_train_iters,
                                 avg_val_loss, np.exp(avg_val_loss),
                                 raw_sacrebleu_out.score,
                                 detok_sacrebleu_out.score))
                writer.add_scalar('valid_loss', avg_val_loss, train_iter)
                writer.add_scalar('valid_bleu', raw_sacrebleu_out.score, train_iter)

    if args.num_averages > 0:
        model_averager.copy_back(param_dict)  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
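# --- Hedged numeric sketch (not in the original file) of the loss/gradient scaling used in
# train() above. Each per-device loss is divided by `const_scale`, and `trainer.update()` is
# called with `loss_denom = total_tokens / const_scale`, so the effective update uses the
# gradient of the token-summed loss divided by the total token count: `const_scale` cancels.
# The function below only demonstrates the scalar arithmetic; the names are illustrative.
def _show_const_scale_cancels(token_loss_sum=480.0, total_tokens=240, const_scale=100):
    loss_denom = total_tokens / const_scale          # value passed to trainer.update()
    scaled_loss = token_loss_sum / const_scale       # value .backward() is called on
    effective = scaled_loss / loss_denom             # == token_loss_sum / total_tokens
    assert abs(effective - token_loss_sum / total_tokens) < 1e-9
    return effective                                 # per-token loss, independent of const_scale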
def get_network(model_name, ctx_l,
                dropout=0.1,
                checkpoint_path=None,
                backbone_path=None,
                dtype='float32'):
    """Get the network that fine-tunes the Question Answering Task

    Parameters
    ----------
    model_name : str
        The model name of the backbone model
    ctx_l :
        Context list of training devices like [mx.gpu(0), mx.gpu(1)]
    dropout : float
        Dropout probability of the task-specific layer
    checkpoint_path : str
        Path to a fine-tuned checkpoint
    backbone_path : str
        Path to the backbone model to be loaded in qa_net
    dtype : str
        Data type of the backbone parameters

    Returns
    -------
    cfg
    tokenizer
    qa_net
    use_segmentation
    """
    # Create the network
    use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
    Model, cfg, tokenizer, download_params_path, _ = \
        get_backbone(model_name, load_backbone=not backbone_path)
    backbone = Model.from_cfg(cfg, use_pooler=False, dtype=dtype)
    # Load local backbone parameters if backbone_path is provided.
    # Otherwise, download backbone parameters from the gluon zoo.
    backbone_params_path = backbone_path if backbone_path else download_params_path
    if checkpoint_path is None:
        backbone.load_parameters(backbone_params_path, ignore_extra=True,
                                 ctx=ctx_l, cast_dtype=True)
        num_params, num_fixed_params \
            = count_parameters(deduplicate_param_dict(backbone.collect_params()))
        logging.info(
            'Loading Backbone Model from {}, with total/fixed parameters={}/{}'.format(
                backbone_params_path, num_params, num_fixed_params))
    qa_net = ModelForQAConditionalV1(backbone=backbone,
                                     dropout_prob=dropout,
                                     use_segmentation=use_segmentation,
                                     weight_initializer=TruncNorm(stdev=0.02))
    if checkpoint_path is None:
        # Ignore the UserWarning during initialization;
        # there is no need to re-initialize the parameters of the backbone.
        qa_net.initialize(ctx=ctx_l)
    else:
        qa_net.load_parameters(checkpoint_path, ctx=ctx_l, cast_dtype=True)
    qa_net.hybridize()
    return cfg, tokenizer, qa_net, use_segmentation
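# --- Hedged usage sketch (not in the original file) for the QA get_network above. ---
# The backbone name is assumed to be a registered GluonNLP model; no checkpoint is passed,
# so the backbone weights are downloaded and the QA head is freshly initialized on CPU.
def _example_build_qa_net():
    import mxnet as mx
    ctx_l = [mx.cpu()]
    cfg, tokenizer, qa_net, use_segmentation = get_network(
        'google_en_uncased_bert_base', ctx_l, dropout=0.1, dtype='float32')
    return cfg, tokenizer, qa_net, use_segmentation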