def test_average_sgd_tracker():
    samples = [mx.np.random.normal(0, 1, (10, 3)) for _ in range(10)]
    no_moving_avg_param_l = []
    with_moving_avg_param_l = []
    moving_avg_param = None
    net_final_moving_avg_param = None
    for use_moving_avg in [False, True]:
        net = HybridSequential()
        net.add(nn.Dense(10), nn.Dense(3))
        net.initialize(init=mx.init.One())
        net.hybridize()
        trainer = mx.gluon.Trainer(net.collect_params(), 'adam')
        if use_moving_avg:
            model_averager = AverageSGDTracker(net.collect_params())
        for sample in samples:
            out = sample ** 3 + sample
            with mx.autograd.record():
                pred = net(sample)
                loss = ((out - pred) ** 2).mean()
            loss.backward()
            trainer.step(1.0)
            if use_moving_avg:
                model_averager.step()
                print(model_averager.average_params)
            if use_moving_avg:
                with_moving_avg_param_l.append({k: v.data().asnumpy()
                                                for k, v in net.collect_params().items()})
            else:
                no_moving_avg_param_l.append({k: v.data().asnumpy()
                                              for k, v in net.collect_params().items()})
        if use_moving_avg:
            model_averager.copy_back()
            moving_avg_param = {k: v.asnumpy()
                                for k, v in model_averager.average_params.items()}
            net_final_moving_avg_param = {k: v.data().asnumpy()
                                          for k, v in net.collect_params().items()}
    # Match the recorded params
    calculated_moving_param = {k: np.zeros(v.shape)
                               for k, v in no_moving_avg_param_l[0].items()}
    for step, (no_moving_avg_param, with_moving_avg_param) in \
            enumerate(zip(no_moving_avg_param_l, with_moving_avg_param_l)):
        decay = 1.0 / (step + 1)
        assert len(no_moving_avg_param) == len(with_moving_avg_param)
        for k in with_moving_avg_param:
            assert_allclose(no_moving_avg_param[k], with_moving_avg_param[k])
            calculated_moving_param[k] += decay * (with_moving_avg_param[k]
                                                   - calculated_moving_param[k])
    assert len(moving_avg_param) == len(net_final_moving_avg_param)
    for k in moving_avg_param:
        assert_allclose(moving_avg_param[k], calculated_moving_param[k], 1E-5, 1E-5)
        assert_allclose(moving_avg_param[k], net_final_moving_avg_param[k], 1E-5, 1E-5)
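
# For reference, a minimal NumPy sketch (a hypothetical helper, not part of the library API)
# of the running-average rule that the assertions above reconstruct:
#     avg_t = avg_{t-1} + (w_t - avg_{t-1}) / t
# i.e. the plain arithmetic mean of the parameter snapshots taken after every optimizer step.
def _running_average_sketch(param_snapshots):
    """Incrementally compute the arithmetic mean of a list of parameter dicts."""
    import numpy as np  # local import to keep the sketch self-contained
    avg = {k: np.zeros_like(v) for k, v in param_snapshots[0].items()}
    for step, snapshot in enumerate(param_snapshots):
        decay = 1.0 / (step + 1)
        for k, v in snapshot.items():
            avg[k] += decay * (v - avg[k])
    return avg
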
def train(args):
    store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    src_tokenizer = create_tokenizer(args.src_tokenizer,
                                     args.src_subword_model_path,
                                     args.src_vocab_path)
    tgt_tokenizer = create_tokenizer(args.tgt_tokenizer,
                                     args.tgt_subword_model_path,
                                     args.tgt_vocab_path)
    src_vocab = src_tokenizer.vocab
    tgt_vocab = tgt_tokenizer.vocab
    train_src_data, train_tgt_data = load_dataset_with_cache(
        args.train_src_corpus, args.train_tgt_corpus,
        src_tokenizer, tgt_tokenizer, args.overwrite_cache)
    dev_src_data, dev_tgt_data = load_dataset_with_cache(
        args.dev_src_corpus, args.dev_tgt_corpus,
        src_tokenizer, tgt_tokenizer, args.overwrite_cache)
    data_train = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))
    ])
    data_val = gluon.data.SimpleDataset([
        (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
        for i, (src_tokens, tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))
    ])
    # Construct the model + loss function
    if args.cfg.endswith('.yml'):
        cfg = TransformerModel.get_cfg().clone_merge(args.cfg)
    else:
        cfg = TransformerModel.get_cfg(args.cfg)
    cfg.defrost()
    cfg.MODEL.src_vocab_size = len(src_vocab)
    cfg.MODEL.tgt_vocab_size = len(tgt_vocab)
    if args.fp16:
        raise NotImplementedError
        # cfg.MODEL.dtype = 'float16'
    cfg.freeze()
    model = TransformerModel.from_cfg(cfg)
    model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l)
    model.hybridize()
    if local_rank == 0:
        logging.info(model)
        with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f:
            cfg_f.write(cfg.dump())
    label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=len(tgt_vocab),
                                                    alpha=args.label_smooth_alpha,
                                                    from_logits=False)
    label_smooth_loss.hybridize()
    rescale_loss = 100.0

    if args.comm_backend == 'horovod':
        hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    # Construct the trainer
    # TODO(sxjscience) Support AMP
    if args.lr is None:
        base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt(args.warmup_steps)
    else:
        base_lr = args.lr
    lr_scheduler = InverseSquareRootScheduler(warmup_steps=args.warmup_steps,
                                              base_lr=base_lr,
                                              warmup_init_lr=args.warmup_init_lr)
    trainer_settings = (model.collect_params(), 'adam',
                        {'learning_rate': args.lr, 'beta1': 0.9, 'beta2': 0.98,
                         'epsilon': 1e-9, 'lr_scheduler': lr_scheduler})
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(*trainer_settings)
    else:
        trainer = gluon.Trainer(*trainer_settings)

    # Load Data
    if args.sampler == 'BoundedBudgetSampler':
        train_batch_sampler = BoundedBudgetSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            max_num_tokens=args.max_num_tokens,
            max_num_sentences=args.max_num_sentences,
            seed=args.seed,
            num_parts=num_parts,
            part_index=rank)
    elif args.sampler == 'FixedBucketSampler':
        if args.comm_backend == 'horovod':
            raise NotImplementedError(
                'FixedBucketSampler does not support horovod at present')
        if args.bucket_scheme == 'constant':
            bucket_scheme = ConstWidthBucket()
        elif args.bucket_scheme == 'linear':
            bucket_scheme = LinearWidthBucket()
        elif args.bucket_scheme == 'exp':
            bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
        else:
            raise NotImplementedError
        # TODO(sxjscience) Support auto-bucket-size tuning
        train_batch_sampler = FixedBucketSampler(
            lengths=[(ele[2], ele[3]) for ele in data_train],
            batch_size=args.batch_size,
            num_buckets=args.num_buckets,
            ratio=args.bucket_ratio,
            shuffle=True,
            use_average_length=True,
            bucket_scheme=bucket_scheme,
            seed=args.seed)
    else:
        raise NotImplementedError

    if local_rank == 0:
        logging.info(train_batch_sampler)

    batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack())
    train_data_loader = gluon.data.DataLoader(data_train,
                                              batch_sampler=train_batch_sampler,
                                              batchify_fn=batchify_fn,
                                              num_workers=0)
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_size=args.val_batch_size,
                                            batchify_fn=batchify_fn,
                                            num_workers=0,
                                            shuffle=False)

    for v in model.collect_params().values():
        if v.grad_req != 'null':
            v.grad_req = 'add'
    model.zero_grad()
    model_averager = AverageSGDTracker(model.collect_params())
    log_start_time = time.time()
    num_params, num_fixed_params = None, None
    # TODO(sxjscience) Add a log metric class
    accum_count = 0
    loss_denom = 0
    n_train_iters = 0
    log_wc = 0
    log_avg_loss = 0.0
    log_loss_denom = 0
    epoch_id = 0
    while args.epochs < 0 or epoch_id < args.epochs:
        # when args.epochs < 0, the model will keep training
        n_epoch_train_iters = 0
        processed_batch_num = 0
        train_multi_data_loader = grouper(train_data_loader, len(ctx_l))
        is_last_batch = False
        sample_data_l = next(train_multi_data_loader)
        while not is_last_batch:
            processed_batch_num += len(sample_data_l)
            loss_l = []
            for sample_data, ctx in zip(sample_data_l, ctx_l):
                if sample_data is None:
                    continue
                src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids \
                    = sample_data
                src_wc, tgt_wc, bs = src_valid_length.sum(), tgt_valid_length.sum(), \
                    src_token_ids.shape[0]
                loss_denom += tgt_wc - bs
                log_loss_denom += tgt_wc - bs
                log_wc += src_wc + tgt_wc
                src_token_ids = src_token_ids.as_in_ctx(ctx)
                tgt_token_ids = tgt_token_ids.as_in_ctx(ctx)
                src_valid_length = src_valid_length.as_in_ctx(ctx)
                tgt_valid_length = tgt_valid_length.as_in_ctx(ctx)
                with mx.autograd.record():
                    tgt_pred = model(src_token_ids, src_valid_length,
                                     tgt_token_ids[:, :-1], tgt_valid_length - 1)
                    tgt_labels = tgt_token_ids[:, 1:]
                    loss = label_smooth_loss(tgt_pred, tgt_labels)
                    loss = mx.npx.sequence_mask(loss,
                                                sequence_length=tgt_valid_length - 1,
                                                use_sequence_length=True,
                                                axis=1)
                    loss_l.append(loss.sum() / rescale_loss)
            for l in loss_l:
                l.backward()
            accum_count += 1
            try:
                sample_data_l = next(train_multi_data_loader)
            except StopIteration:
                is_last_batch = True
            if local_rank == 0 and num_params is None:
                num_params, num_fixed_params = count_parameters(model.collect_params())
                logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}'
                             .format(num_params, num_fixed_params))
            sum_loss = sum([l.as_in_ctx(mx.cpu()) for l in loss_l]) * rescale_loss
            log_avg_loss += sum_loss
            mx.npx.waitall()
            if accum_count == args.num_accumulated or is_last_batch:
                # Update the parameters
                n_train_iters += 1
                n_epoch_train_iters += 1
                trainer.step(loss_denom.asnumpy() / rescale_loss)
                accum_count = 0
                loss_denom = 0
                model.zero_grad()
                if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \
                        (args.max_update > 0 and n_train_iters >=
                         args.max_update - args.num_averages * args.save_interval_update):
                    model_averager.step()
                if local_rank == 0 and \
                        (n_epoch_train_iters % args.log_interval == 0 or is_last_batch):
                    log_end_time = time.time()
                    log_wc = log_wc.asnumpy()
                    wps = log_wc / (log_end_time - log_start_time)
                    log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
                    logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                                 'throughput={:.2f}K wps, wc={:.2f}K, LR={}'
                                 .format(epoch_id, processed_batch_num * num_parts,
                                         len(train_data_loader), log_avg_loss,
                                         np.exp(log_avg_loss), wps / 1000,
                                         log_wc / 1000, trainer.learning_rate))
                    log_start_time = time.time()
                    log_avg_loss = 0
                    log_loss_denom = 0
                    log_wc = 0
                if local_rank == 0 and \
                        (args.max_update > 0 and
                         n_train_iters % args.save_interval_update == 0):
                    model.save_parameters(
                        os.path.join(args.save_dir,
                                     'update{:d}.params'.format(
                                         n_train_iters // args.save_interval_update)),
                        deduplicate=True)
                if args.max_update > 0 and n_train_iters >= args.max_update:
                    break
        if local_rank == 0 and args.epochs > 0:
            model.save_parameters(
                os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)),
                deduplicate=True)
            avg_valid_loss = validation(model, val_data_loader, ctx_l)
            logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
                         .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))
        if args.max_update > 0 and n_train_iters >= args.max_update:
            break
        epoch_id += 1

    if args.num_averages > 0:
        model_averager.copy_back(model.collect_params())  # TODO(sxjscience) Rewrite using update
        model.save_parameters(os.path.join(args.save_dir, 'average.params'),
                              deduplicate=True)
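
# The multi-device loop in train() above relies on a `grouper` helper that chunks the data
# loader into one group per context and pads the final, incomplete group with None, which
# is why the inner loop skips None entries. A minimal sketch consistent with that usage
# (an illustrative stand-in; the actual helper is assumed to come from the project's
# utility module):
import itertools

def grouper_sketch(iterable, n, fillvalue=None):
    """Collect items into fixed-length groups, padding the last group with `fillvalue`."""
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)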