def train():
    """Train the model, validate every epoch and keep the best checkpoint.

    Configuration and collaborators are read from module-level names
    (``args``, ``model``, ``ctx``, ``loss_function``, ``dataprocessor``,
    ``evaluate``, ``compute_bleu``, the datasets and reference sentences).

    For every epoch the function:

    * runs one pass over the training loader, clipping the global gradient
      norm to ``args.clip`` and logging loss, perplexity, gradient norm and
      throughput every ``args.log_interval`` batches;
    * evaluates on the validation set, and on the test set when
      ``args.validate_on_test_data`` is truthy, writing the translations to
      ``args.save_dir``;
    * saves ``valid_best.params`` whenever the validation BLEU improves;
    * multiplies the learning rate by ``args.lr_update_factor`` once
      ``epoch_id + 1`` has reached two thirds of ``args.epochs``.

    After the last epoch ``valid_best.params`` is reloaded if it exists, and
    the validation and test results are logged and written out.
    """
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr})

    train_data_loader, val_data_loader, test_data_loader = \
        dataprocessor.make_dataloader(data_train, data_val, data_test, args)

    best_params_path = os.path.join(args.save_dir, 'valid_best.params')
    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_loss = 0
        log_denom = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) \
                in enumerate(train_data_loader):
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                # Inputs drop the last target token, labels drop the first.
                out, _ = model(src_seq, tgt_seq[:, :-1],
                               src_valid_length, tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
                loss = loss * (tgt_seq.shape[1] - 1)
                # Statistics for logging: loss scaled by the batch size over
                # the number of predicted target tokens.
                log_loss += loss * tgt_seq.shape[0]
                log_denom += (tgt_valid_length - 1).sum()
                loss = loss / (tgt_valid_length - 1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            log_loss = log_loss.asscalar()
            log_denom = log_denom.asscalar()
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info(
                    '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                    'throughput={:.2f}K wps, wc={:.2f}K'.format(
                        epoch_id, batch_id + 1, len(train_data_loader),
                        log_loss / log_denom,
                        np.exp(log_loss / log_denom),
                        log_avg_gnorm / args.log_interval,
                        wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_loss = 0
                log_denom = 0
                log_avg_gnorm = 0
                log_wc = 0

        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        # Format only the file name: formatting the joined path would also
        # interpret any '{' / '}' contained in args.save_dir.
        dataprocessor.write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir, 'epoch{:d}_valid_out.txt'.format(epoch_id)))

        if args.validate_on_test_data:
            test_loss, test_translation_out = evaluate(test_data_loader)
            test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                       test_translation_out)
            logging.info(
                '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
            dataprocessor.write_sentences(
                test_translation_out,
                os.path.join(args.save_dir, 'epoch{:d}_test_out.txt'.format(epoch_id)))

        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            logging.info('Save best parameters to {}'.format(best_params_path))
            model.save_parameters(best_params_path)

        if epoch_id + 1 >= (args.epochs * 2) // 3:
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)

    # Final report uses the best checkpoint when one was written; otherwise
    # the parameters from the last epoch are evaluated as they are.
    if os.path.exists(best_params_path):
        model.load_parameters(best_params_path)
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
        .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
        .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    dataprocessor.write_sentences(valid_translation_out,
                                  os.path.join(args.save_dir, 'best_valid_out.txt'))
    dataprocessor.write_sentences(test_translation_out,
                                  os.path.join(args.save_dir, 'best_test_out.txt'))
def train():
    """Train with gradient accumulation, parameter averaging and BLEU tracking.

    Configuration and collaborators are read from module-level names
    (``args``, ``model``, ``ctx``, ``loss_function``, ``label_smoothing``,
    ``dataprocessor``, ``evaluate``, ``compute_bleu``, the datasets and
    reference sentences).

    For every epoch the function:

    * iterates over sharded batches, placing one shard on each context in
      ``ctx`` and accumulating gradients over ``args.num_accumulated`` batches
      before each optimizer step;
    * sets the learning rate at the start of each accumulation window from
      ``args.lr``, ``args.num_units`` and ``args.warmup_steps``;
    * maintains a running average of the parameters for steps past
      ``average_start``;
    * evaluates on the validation and test sets, writes the translations,
      saves ``valid_best.params`` on a BLEU improvement and always saves
      ``epoch{N}.params``.

    After training the running average is saved to ``average.params``. The
    final model is then either the average of the last ``args.num_averages``
    epoch checkpoints (``args.average_checkpoint``), the running average
    (``args.average_start > 0``) or the best validation checkpoint, and is
    evaluated once more on both sets.

    Raises:
        NotImplementedError: if ``args.bleu`` is not ``'tweaked'``, ``'13a'``
            or ``'intl'``.
    """
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9})

    train_data_loader, val_data_loader, test_data_loader = \
        dataprocessor.make_dataloader(data_train, data_val, data_test, args,
                                      use_average_length=True, num_shards=len(ctx))

    if args.bleu == 'tweaked':
        bpe = args.dataset not in ('IWSLT2015', 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu in ('13a', 'intl'):
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError
    # Shared by every compute_bleu call below.
    bleu_kwargs = dict(tokenized=tokenized, tokenizer=args.bleu,
                       split_compound_word=split_compound_word, bpe=bpe)

    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    num_batches = len(train_data_loader)
    # Gradients are summed across backward calls until zero_grad() is called.
    model.collect_params().setattr('grad_req', 'add')
    average_start = (num_batches // grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, seqs in enumerate(train_data_loader):
            # True on the last batch of an accumulation window, and on the
            # last batch of the epoch so no gradient is left unapplied.
            is_update_batch = (batch_id % grad_interval == grad_interval - 1
                               or batch_id == num_batches - 1)
            if batch_id % grad_interval == 0:
                step_num += 1
                # Linear warm-up followed by inverse-square-root decay.
                new_lr = args.lr / math.sqrt(args.num_units) \
                    * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            src_wc, tgt_wc, bs = np.sum([(shard[2].sum(), shard[3].sum(), shard[0].shape[0])
                                         for shard in seqs], axis=0)
            src_wc = src_wc.asscalar()
            tgt_wc = tgt_wc.asscalar()
            # Predicted tokens: target lengths minus one token per sentence.
            loss_denom += tgt_wc - bs
            seqs = [[seq.as_in_context(context) for seq in shard]
                    for context, shard in zip(ctx, seqs)]
            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length in seqs:
                    out, _ = model(src_seq, tgt_seq[:, :-1],
                                   src_valid_length, tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label, tgt_valid_length - 1).sum()
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size / 100.0)
            for L in Ls:
                L.backward()
            if is_update_batch:
                if average_param_dict is None:
                    average_param_dict = {k: v.data(ctx[0]).copy()
                                          for k, v in model.collect_params().items()}
                # The losses were pre-scaled by 1 / (batch_size * 100); the
                # token count passed to step() carries the same scaling.
                trainer.step(float(loss_denom) / args.batch_size / 100.0)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    # Incremental running mean of the parameters.
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (param_dict[name].data(ctx[0])
                                                     - average_param)
            step_loss += sum(L.asscalar() for L in Ls)
            if is_update_batch:
                log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, num_batches,
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()

        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out, **bleu_kwargs)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss),
                             valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                   test_translation_out, **bleu_kwargs)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        # Format only the file name: formatting the joined path would also
        # interpret any '{' / '}' contained in args.save_dir.
        dataprocessor.write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir, 'epoch{:d}_valid_out.txt'.format(epoch_id)))
        dataprocessor.write_sentences(
            test_translation_out,
            os.path.join(args.save_dir, 'epoch{:d}_test_out.txt'.format(epoch_id)))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id))
        model.save_parameters(save_path)

    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        # Fold the last num_averages epoch checkpoints into the live
        # parameters as an incremental mean, newest checkpoint first.
        for j in range(args.num_averages):
            params = mx.nd.load(os.path.join(
                args.save_dir, 'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c))
        save_path = os.path.join(args.save_dir,
                                 'average_checkpoint_{}.params'.format(args.num_averages))
        model.save_parameters(save_path)
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
        save_path = os.path.join(args.save_dir, 'average.params')
        model.save_parameters(save_path)
    else:
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'), ctx)

    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out, **bleu_kwargs)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out, **bleu_kwargs)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    dataprocessor.write_sentences(valid_translation_out,
                                  os.path.join(args.save_dir, 'best_valid_out.txt'))
    dataprocessor.write_sentences(test_translation_out,
                                  os.path.join(args.save_dir, 'best_test_out.txt'))
def train():
    """Train with per-context parallel workers, accumulation and averaging.

    Configuration and collaborators are read from module-level names
    (``args``, ``model``, ``ctx``, ``num_ctxs``, ``parallel_model``,
    ``Parallel``, ``dataprocessor``, ``evaluate``, ``compute_bleu``, the
    datasets and reference sentences).

    For every epoch the function:

    * iterates over sharded batches, moves one shard to each context in
      ``ctx`` and hands the shards to a ``Parallel`` executor that returns
      one loss per shard;
    * applies an optimizer step every ``args.num_accumulated`` batches, with
      the learning rate set at the start of each accumulation window from
      ``args.lr``, ``args.num_units`` and ``args.warmup_steps``;
    * maintains a running average of the parameters for steps past
      ``average_start``;
    * evaluates on the validation and test sets, writes the translations,
      saves ``valid_best.params`` on a BLEU improvement and always saves
      ``epoch{N}.params``.

    After training the running average is saved to ``average.params``. The
    final model is then either the average of the last ``args.num_averages``
    epoch checkpoints (``args.average_checkpoint``), the running average
    (``args.average_start > 0``) or the best validation checkpoint, and is
    evaluated once more on both sets.

    Raises:
        NotImplementedError: if ``args.bleu`` is not ``'tweaked'``, ``'13a'``
            or ``'intl'``.
    """
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {
        'learning_rate': args.lr,
        'beta2': 0.98,
        'epsilon': 1e-9,
    })

    train_data_loader, val_data_loader, test_data_loader = \
        dataprocessor.make_dataloader(data_train, data_val, data_test, args,
                                      use_average_length=True, num_shards=len(ctx))

    if args.bleu == 'tweaked':
        bpe = args.dataset not in ('IWSLT2015', 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu in ('13a', 'intl'):
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError
    # Shared by every compute_bleu call below.
    bleu_kwargs = dict(tokenized=tokenized, tokenizer=args.bleu,
                       split_compound_word=split_compound_word, bpe=bpe)

    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    num_batches = len(train_data_loader)
    # Gradients are summed across backward calls until zero_grad() is called.
    model.collect_params().setattr('grad_req', 'add')
    average_start = (num_batches // grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    parallel = Parallel(num_ctxs, parallel_model)
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, seqs in enumerate(train_data_loader):
            # True on the last batch of an accumulation window, and on the
            # last batch of the epoch so no gradient is left unapplied.
            is_update_batch = (batch_id % grad_interval == grad_interval - 1
                               or batch_id == num_batches - 1)
            if batch_id % grad_interval == 0:
                step_num += 1
                # Linear warm-up followed by inverse-square-root decay.
                new_lr = args.lr / math.sqrt(args.num_units) \
                    * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            src_wc, tgt_wc, bs = np.sum(
                [(shard[2].sum(), shard[3].sum(), shard[0].shape[0]) for shard in seqs],
                axis=0)
            seqs = [[seq.as_in_context(context) for seq in shard]
                    for context, shard in zip(ctx, seqs)]
            for seq in seqs:
                parallel.put((seq, args.batch_size))
            # Fetch exactly one result per submitted shard. zip() above may
            # yield fewer shards than contexts, and asking for len(ctx)
            # results would then wait on a result that was never queued.
            # NOTE(review): assumes Parallel produces one get() per put().
            Ls = [parallel.get() for _ in seqs]
            src_wc = src_wc.asscalar()
            tgt_wc = tgt_wc.asscalar()
            # Predicted tokens: target lengths minus one token per sentence.
            loss_denom += tgt_wc - bs
            if is_update_batch:
                if average_param_dict is None:
                    average_param_dict = {
                        k: v.data(ctx[0]).copy()
                        for k, v in model.collect_params().items()
                    }
                trainer.step(float(loss_denom) / args.batch_size / 100.0)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    # Incremental running mean of the parameters.
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (param_dict[name].data(ctx[0])
                                                     - average_param)
            step_loss += sum(L.asscalar() for L in Ls)
            if is_update_batch:
                log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch_id, batch_id + 1, num_batches,
                                 log_avg_loss / args.log_interval,
                                 np.exp(log_avg_loss / args.log_interval),
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()

        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu(
            [val_tgt_sentences], valid_translation_out, **bleu_kwargs)
        logging.info(
            '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
            .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu(
            [test_tgt_sentences], test_translation_out, **bleu_kwargs)
        logging.info(
            '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
            .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        # Format only the file name: formatting the joined path would also
        # interpret any '{' / '}' contained in args.save_dir.
        dataprocessor.write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir, 'epoch{:d}_valid_out.txt'.format(epoch_id)))
        dataprocessor.write_sentences(
            test_translation_out,
            os.path.join(args.save_dir, 'epoch{:d}_test_out.txt'.format(epoch_id)))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id))
        model.save_parameters(save_path)

    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        # Fold the last num_averages epoch checkpoints into the live
        # parameters as an incremental mean, newest checkpoint first.
        for j in range(args.num_averages):
            params = mx.nd.load(os.path.join(
                args.save_dir, 'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c))
        save_path = os.path.join(
            args.save_dir, 'average_checkpoint_{}.params'.format(args.num_averages))
        model.save_parameters(save_path)
    elif args.average_start > 0:
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
        save_path = os.path.join(args.save_dir, 'average.params')
        model.save_parameters(save_path)
    else:
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'), ctx)

    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu(
        [val_tgt_sentences], valid_translation_out, **bleu_kwargs)
    logging.info(
        'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
        .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu(
        [test_tgt_sentences], test_translation_out, **bleu_kwargs)
    logging.info(
        'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
        .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    dataprocessor.write_sentences(valid_translation_out,
                                  os.path.join(args.save_dir, 'best_valid_out.txt'))
    dataprocessor.write_sentences(test_translation_out,
                                  os.path.join(args.save_dir, 'best_test_out.txt'))
def train():
    """Train the model, evaluate every epoch and keep the best checkpoint.

    Configuration and collaborators are read from module-level names
    (``args``, ``model``, ``ctx``, ``loss_function``, ``dataprocessor``,
    ``evaluate``, ``compute_bleu``, the datasets and reference sentences).

    For every epoch the function:

    * runs one pass over the training loader, clipping the global gradient
      norm to ``args.clip`` and logging the averaged loss, perplexity,
      gradient norm and throughput every ``args.log_interval`` batches;
    * evaluates on both the validation and the test set and writes the
      translations to ``args.save_dir``;
    * saves ``valid_best.params`` whenever the validation BLEU improves;
    * multiplies the learning rate by ``args.lr_update_factor`` once
      ``epoch_id + 1`` has reached two thirds of ``args.epochs``.

    After the last epoch ``valid_best.params`` is reloaded if it exists, and
    the validation and test results are logged and written out.
    """
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr})

    train_data_loader, val_data_loader, test_data_loader = \
        dataprocessor.make_dataloader(data_train, data_val, data_test, args)

    best_params_path = os.path.join(args.save_dir, 'valid_best.params')
    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) \
                in enumerate(train_data_loader):
            # logging.info(src_seq.context) Context suddenly becomes GPU.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                # Inputs drop the last target token, labels drop the first.
                out, _ = model(src_seq, tgt_seq[:, :-1],
                               src_valid_length, tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
                loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            step_loss = loss.asscalar()
            log_avg_loss += step_loss
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, len(train_data_loader),
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     log_avg_gnorm / args.log_interval,
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_avg_gnorm = 0
                log_wc = 0

        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                    valid_translation_out)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss),
                             valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader)
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                                   test_translation_out)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        # Format only the file name: formatting the joined path would also
        # interpret any '{' / '}' contained in args.save_dir.
        dataprocessor.write_sentences(
            valid_translation_out,
            os.path.join(args.save_dir, 'epoch{:d}_valid_out.txt'.format(epoch_id)))
        dataprocessor.write_sentences(
            test_translation_out,
            os.path.join(args.save_dir, 'epoch{:d}_test_out.txt'.format(epoch_id)))

        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            logging.info('Save best parameters to {}'.format(best_params_path))
            model.save_parameters(best_params_path)

        if epoch_id + 1 >= (args.epochs * 2) // 3:
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)

    # Final report uses the best checkpoint when one was written; otherwise
    # the parameters from the last epoch are evaluated as they are.
    if os.path.exists(best_params_path):
        model.load_parameters(best_params_path)
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences],
                                                valid_translation_out)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences],
                                               test_translation_out)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    dataprocessor.write_sentences(valid_translation_out,
                                  os.path.join(args.save_dir, 'best_valid_out.txt'))
    dataprocessor.write_sentences(test_translation_out,
                                  os.path.join(args.save_dir, 'best_test_out.txt'))