def get_dataloader(dataset):
    """Create a bucketed data loader for one dataset chunk.

    Relies on names from the enclosing scope: ``args``, ``batch_size``,
    ``is_train`` and ``num_ctxes``.

    Parameters
    ----------
    dataset
        Dataset chunk; must provide ``get_field('valid_lengths')``, which
        supplies the per-sample lengths used for bucketing.

    Returns
    -------
    The data loader: a ``ShardedDataLoader`` when ``args.by_token`` is set,
    otherwise a plain ``DataLoader``.
    """
    t0 = time.time()
    lengths = dataset.get_field('valid_lengths')
    logging.debug('Num samples = %d', len(lengths))
    # A batch includes: input_id, masked_id, masked_position, masked_weight,
    #                   next_sentence_label, segment_id, valid_length
    batchify_fn = Tuple(Pad(), Pad(), Pad(), Pad(), Stack(), Pad(), Stack())
    if args.by_token:
        # Sharded data loader: one shard per context, so batch_size is the
        # size of a single shard here.
        sampler = nlp.data.FixedBucketSampler(
            lengths=lengths,
            batch_size=batch_size,
            num_buckets=args.num_buckets,
            shuffle=is_train,
            use_average_length=True,
            num_shards=num_ctxes)
        dataloader = nlp.data.ShardedDataLoader(dataset,
                                                batch_sampler=sampler,
                                                batchify_fn=batchify_fn,
                                                num_workers=num_ctxes)
    else:
        # Single loader: one batch holds the samples for all contexts.
        sampler = FixedBucketSampler(lengths,
                                     batch_size=batch_size * num_ctxes,
                                     num_buckets=args.num_buckets,
                                     ratio=0,
                                     shuffle=is_train)
        dataloader = DataLoader(dataset=dataset,
                                batch_sampler=sampler,
                                batchify_fn=batchify_fn,
                                num_workers=1)
    # Log the bucket statistics exactly once, whichever branch ran.
    logging.debug('Batch Sampler:\n%s', sampler.stats())
    t1 = time.time()
    logging.debug('Dataloader creation cost = %.2f s', t1 - t0)
    return dataloader
def train():
    """Training function.

    Trains the module-level ``model`` with gradient accumulation over
    sharded batches, evaluates on the validation and test sets after every
    epoch, and finally evaluates the averaged or best parameters.

    All inputs come from module-level names: ``args``, ``model``, ``ctx``
    (list of contexts), the datasets and their length lists,
    ``loss_function``, ``label_smoothing``, ``evaluate``, ``compute_bleu``
    and ``write_sentences``.

    Side effects: writes per-epoch translations and parameter files,
    ``valid_best.params``, ``average.params`` and the final
    ``best_*_out.txt`` files under ``args.save_dir``.

    Raises
    ------
    NotImplementedError
        If ``args.bucket_scheme`` or ``args.bleu`` has an unsupported value.
    """
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9})

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
                                 btf.Stack())
    # Validation/test buckets are built from the last entry of each length
    # record (presumably the target length -- confirm against the data prep).
    target_val_lengths = [lengths[-1] for lengths in data_val_lengths]
    target_test_lengths = [lengths[-1] for lengths in data_test_lengths]
    if args.bucket_scheme == 'constant':
        bucket_scheme = ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError(
            'Unsupported bucket scheme: {}'.format(args.bucket_scheme))
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             use_average_length=True,
                                             num_shards=len(ctx),
                                             bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n{}'.format(train_batch_sampler.stats()))
    train_data_loader = ShardedDataLoader(data_train,
                                          batch_sampler=train_batch_sampler,
                                          batchify_fn=train_batchify_fn,
                                          num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False,
                                           use_average_length=True,
                                           bucket_scheme=bucket_scheme)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False,
                                            use_average_length=True,
                                            bucket_scheme=bucket_scheme)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)

    # BLEU options depend on the requested BLEU flavour.
    if args.bleu == 'tweaked':
        bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError(
            'Unsupported BLEU variant: {}'.format(args.bleu))

    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    # Gradients are accumulated over `grad_interval` batches, hence 'add'
    # plus an explicit zero_grad() after every optimizer step.
    model.collect_params().setattr('grad_req', 'add')
    # Optimizer step after which the running parameter average is updated.
    average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, seqs in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
                # First batch of an accumulation group: advance the step
                # counter and apply the warm-up / inverse-sqrt schedule.
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            # Sum source lengths, target lengths and batch sizes over shards.
            src_wc, tgt_wc, bs = np.sum([(shard[2].sum(), shard[3].sum(), shard[0].shape[0])
                                         for shard in seqs], axis=0)
            src_wc = src_wc.asscalar()
            tgt_wc = tgt_wc.asscalar()
            # One token per sentence is dropped because the model predicts
            # tgt_seq[:, 1:] from tgt_seq[:, :-1].
            loss_denom += tgt_wc - bs
            seqs = [[seq.as_in_context(context) for seq in shard]
                    for context, shard in zip(ctx, seqs)]
            Ls = []
            with mx.autograd.record():
                for src_seq, tgt_seq, src_valid_length, tgt_valid_length in seqs:
                    out, _ = model(src_seq, tgt_seq[:, :-1],
                                   src_valid_length, tgt_valid_length - 1)
                    smoothed_label = label_smoothing(tgt_seq[:, 1:])
                    ls = loss_function(out, smoothed_label, tgt_valid_length - 1).sum()
                    # The same 1 / batch_size / 100 factor is applied to the
                    # trainer.step() argument below.
                    Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size / 100.0)
            for L in Ls:
                L.backward()
            # True on the last batch of an accumulation group, or of the epoch.
            is_update_step = (batch_id % grad_interval == grad_interval - 1
                              or batch_id == len(train_data_loader) - 1)
            if is_update_step:
                if average_param_dict is None:
                    average_param_dict = {k: v.data(ctx[0]).copy() for k, v in
                                          model.collect_params().items()}
                trainer.step(float(loss_denom) / args.batch_size / 100.0)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    # Incremental running mean of the parameters.
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if is_update_step:
                log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, len(train_data_loader),
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out,
                                                    tokenized=tokenized, tokenizer=args.bleu,
                                                    split_compound_word=split_compound_word,
                                                    bpe=bpe)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out,
                                                   tokenized=tokenized, tokenizer=args.bleu,
                                                   split_compound_word=split_compound_word,
                                                   bpe=bpe)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        # Format only the file name: formatting the joined path would
        # misinterpret any '{' or '}' contained in args.save_dir.
        write_sentences(valid_translation_out,
                        os.path.join(args.save_dir,
                                     'epoch{:d}_valid_out.txt'.format(epoch_id)))
        write_sentences(test_translation_out,
                        os.path.join(args.save_dir,
                                     'epoch{:d}_test_out.txt'.format(epoch_id)))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_parameters(save_path)
        save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id))
        model.save_parameters(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        # Running mean over the last `num_averages` per-epoch checkpoints.
        for j in range(args.num_averages):
            params = mx.nd.load(os.path.join(args.save_dir,
                                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c))
        save_path = os.path.join(args.save_dir,
                                 'average_checkpoint_{}.params'.format(args.num_averages))
        model.save_parameters(save_path)
    elif args.average_start > 0:
        # Use the running average accumulated during training.
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
        save_path = os.path.join(args.save_dir, 'average.params')
        model.save_parameters(save_path)
    else:
        model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'), ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out,
                                                tokenized=tokenized, tokenizer=args.bleu,
                                                bpe=bpe,
                                                split_compound_word=split_compound_word)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out,
                                               tokenized=tokenized, tokenizer=args.bleu,
                                               bpe=bpe,
                                               split_compound_word=split_compound_word)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
def train():
    """Training function.

    Trains the module-level ``model`` on a single context ``ctx`` with
    global-norm gradient clipping, evaluates on the validation and test
    sets after every epoch, and finally reloads and evaluates the
    checkpoint with the best validation BLEU.

    All inputs come from module-level names: ``args``, ``model``, ``ctx``,
    the datasets and their length lists, ``loss_function``, ``evaluate``,
    ``compute_bleu`` and ``write_sentences``.

    Side effects: writes per-epoch translations, ``valid_best.params`` and
    the final ``best_*_out.txt`` files under ``args.save_dir``.

    Raises
    ------
    NotImplementedError
        If ``args.bucket_scheme`` has an unsupported value.
    """
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {'learning_rate': args.lr})

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
                                 btf.Stack())
    if args.bucket_scheme == 'constant':
        bucket_scheme = ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError(
            'Unsupported bucket scheme: {}'.format(args.bucket_scheme))
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n{}'.format(train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    # Note: the evaluation samplers do not receive `bucket_scheme`.
    val_batch_sampler = FixedBucketSampler(lengths=data_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=data_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
                in enumerate(train_data_loader):
            # Move the batch onto the training context.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                # The model predicts tgt_seq[:, 1:] from tgt_seq[:, :-1], so
                # every target length is reduced by one.
                out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
                # Rescale by (padded length) / (mean valid length).
                loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            step_loss = loss.asscalar()
            log_avg_loss += step_loss
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, len(train_data_loader),
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     log_avg_gnorm / args.log_interval,
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_avg_gnorm = 0
                log_wc = 0
        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader)
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        # Format only the file name: formatting the joined path would
        # misinterpret any '{' or '}' contained in args.save_dir.
        write_sentences(valid_translation_out,
                        os.path.join(args.save_dir,
                                     'epoch{:d}_valid_out.txt'.format(epoch_id)))
        write_sentences(test_translation_out,
                        os.path.join(args.save_dir,
                                     'epoch{:d}_test_out.txt'.format(epoch_id)))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        # Decay the learning rate after every epoch in the final third.
        if epoch_id + 1 >= (args.epochs * 2) // 3:
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)
    model.load_params(os.path.join(args.save_dir, 'valid_best.params'))
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
def train():
    """Training function.

    Trains the module-level ``model`` with gradient accumulation, splitting
    each batch across the contexts in ``ctx``. Evaluates on the validation
    and test sets after every epoch, and finally evaluates the averaged or
    best parameters.

    All inputs come from module-level names: ``args``, ``model``, ``ctx``
    (list of contexts), the datasets and their length lists,
    ``loss_function``, ``label_smoothing``, ``evaluate``, ``compute_bleu``
    and ``write_sentences``.

    Side effects: writes per-epoch translations and parameter files,
    ``valid_best.params``, ``average.params`` and the final
    ``best_*_out.txt`` files under ``args.save_dir``.
    """
    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9})

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(), btf.Stack())
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(), btf.Stack(), btf.Stack())
    # Validation/test buckets are built from the last entry of each length
    # record (presumably the target length -- confirm against the data prep).
    target_val_lengths = [lengths[-1] for lengths in data_val_lengths]
    target_test_lengths = [lengths[-1] for lengths in data_test_lengths]
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True,
                                             use_average_length=True)
    logging.info('Train Batch Sampler:\n{}'.format(train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False,
                                           use_average_length=True)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False,
                                            use_average_length=True)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    step_num = 0
    warmup_steps = args.warmup_steps
    grad_interval = args.num_accumulated
    # Gradients are accumulated over `grad_interval` batches, hence 'add'
    # plus an explicit zero_grad() after every optimizer step.
    model.collect_params().setattr('grad_req', 'add')
    # Optimizer step after which the running parameter average is updated.
    average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start)
    average_param_dict = None
    model.collect_params().zero_grad()
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_wc = 0
        loss_denom = 0
        step_loss = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) \
                in enumerate(train_data_loader):
            if batch_id % grad_interval == 0:
                # First batch of an accumulation group: advance the step
                # counter and apply the warm-up / inverse-sqrt schedule.
                step_num += 1
                new_lr = args.lr / math.sqrt(args.num_units) \
                         * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5))
                trainer.set_learning_rate(new_lr)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = tgt_valid_length.sum().asscalar()
            # One token per sentence is dropped because the model predicts
            # tgt_seq[:, 1:] from tgt_seq[:, :-1].
            loss_denom += tgt_wc - tgt_valid_length.shape[0]
            if src_seq.shape[0] > len(ctx):
                src_seq_list, tgt_seq_list, src_valid_length_list, tgt_valid_length_list \
                    = [gluon.utils.split_and_load(seq, ctx, batch_axis=0, even_split=False)
                       for seq in [src_seq, tgt_seq, src_valid_length, tgt_valid_length]]
            else:
                # Batch not larger than the number of contexts: keep it
                # whole on the first context.
                src_seq_list = [src_seq.as_in_context(ctx[0])]
                tgt_seq_list = [tgt_seq.as_in_context(ctx[0])]
                src_valid_length_list = [src_valid_length.as_in_context(ctx[0])]
                tgt_valid_length_list = [tgt_valid_length.as_in_context(ctx[0])]

            Ls = []
            with mx.autograd.record():
                # Distinct names so the per-device slices do not rebind the
                # outer loop variables.
                for src_part, tgt_part, src_len_part, tgt_len_part \
                        in zip(src_seq_list, tgt_seq_list,
                               src_valid_length_list, tgt_valid_length_list):
                    out, _ = model(src_part, tgt_part[:, :-1],
                                   src_len_part, tgt_len_part - 1)
                    smoothed_label = label_smoothing(tgt_part[:, 1:])
                    ls = loss_function(out, smoothed_label, tgt_len_part - 1).sum()
                    Ls.append((ls * (tgt_part.shape[1] - 1)) / args.batch_size)
            for L in Ls:
                L.backward()
            # True on the last batch of an accumulation group, or of the epoch.
            is_update_step = (batch_id % grad_interval == grad_interval - 1
                              or batch_id == len(train_data_loader) - 1)
            if is_update_step:
                if average_param_dict is None:
                    average_param_dict = {k: v.data(ctx[0]).copy() for k, v in
                                          model.collect_params().items()}
                trainer.step(float(loss_denom) / args.batch_size)
                param_dict = model.collect_params()
                param_dict.zero_grad()
                if step_num > average_start:
                    # Incremental running mean of the parameters.
                    alpha = 1. / max(1, step_num - average_start)
                    for name, average_param in average_param_dict.items():
                        average_param[:] += alpha * (param_dict[name].data(ctx[0]) - average_param)
            step_loss += sum([L.asscalar() for L in Ls])
            if is_update_step:
                log_avg_loss += step_loss / loss_denom * args.batch_size
                loss_denom = 0
                step_loss = 0
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % (args.log_interval * grad_interval) == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, len(train_data_loader),
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0
        mx.nd.waitall()
        valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out,
                                                    bpe=True, split_compound_word=True)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out,
                                                   bpe=True, split_compound_word=True)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        # Format only the file name: formatting the joined path would
        # misinterpret any '{' or '}' contained in args.save_dir.
        write_sentences(valid_translation_out,
                        os.path.join(args.save_dir,
                                     'epoch{:d}_valid_out.txt'.format(epoch_id)))
        write_sentences(test_translation_out,
                        os.path.join(args.save_dir,
                                     'epoch{:d}_test_out.txt'.format(epoch_id)))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id))
        model.save_params(save_path)
    save_path = os.path.join(args.save_dir, 'average.params')
    mx.nd.save(save_path, average_param_dict)
    if args.average_checkpoint:
        # Running mean over the last `num_averages` per-epoch checkpoints.
        for j in range(args.num_averages):
            params = mx.nd.load(os.path.join(args.save_dir,
                                             'epoch{:d}.params'.format(args.epochs - j - 1)))
            alpha = 1. / (j + 1)
            for k, v in model._collect_params_with_prefix().items():
                for c in ctx:
                    v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c))
    elif args.average_start > 0:
        # Use the running average accumulated during training.
        for k, v in model.collect_params().items():
            v.set_data(average_param_dict[k])
    else:
        model.load_params(os.path.join(args.save_dir, 'valid_best.params'), ctx)
    valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0])
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out,
                                                bpe=True, split_compound_word=True)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader, ctx[0])
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out,
                                               bpe=True, split_compound_word=True)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))
def train():
    """Training function.

    Trains the module-level ``model`` on a single context ``ctx`` with
    global-norm gradient clipping and evaluates on the validation and test
    sets after every epoch. The learning rate is multiplied by
    ``args.lr_update_factor`` after any epoch whose validation BLEU does
    not beat the best so far. Finally the best checkpoint is reloaded and
    evaluated.

    All inputs come from module-level names: ``args``, ``model``, ``ctx``,
    the datasets and their length lists, ``loss_function``, ``evaluate``,
    ``compute_bleu`` and ``write_sentences``.

    Side effects: writes per-epoch translations, ``valid_best.params`` and
    the final ``best_*_out.txt`` files under ``args.save_dir``.
    """
    trainer = gluon.Trainer(model.collect_params(), args.optimizer, {'learning_rate': args.lr})

    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(), btf.Stack())
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(), btf.Stack(), btf.Stack())
    train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths,
                                             batch_size=args.batch_size,
                                             num_buckets=args.num_buckets,
                                             ratio=args.bucket_ratio,
                                             shuffle=True)
    logging.info('Train Batch Sampler:\n{}'.format(train_batch_sampler.stats()))
    train_data_loader = DataLoader(data_train,
                                   batch_sampler=train_batch_sampler,
                                   batchify_fn=train_batchify_fn,
                                   num_workers=8)

    val_batch_sampler = FixedBucketSampler(lengths=data_val_lengths,
                                           batch_size=args.test_batch_size,
                                           num_buckets=args.num_buckets,
                                           ratio=args.bucket_ratio,
                                           shuffle=False)
    logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
    val_data_loader = DataLoader(data_val,
                                 batch_sampler=val_batch_sampler,
                                 batchify_fn=test_batchify_fn,
                                 num_workers=8)
    test_batch_sampler = FixedBucketSampler(lengths=data_test_lengths,
                                            batch_size=args.test_batch_size,
                                            num_buckets=args.num_buckets,
                                            ratio=args.bucket_ratio,
                                            shuffle=False)
    logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
    test_data_loader = DataLoader(data_test,
                                  batch_sampler=test_batch_sampler,
                                  batchify_fn=test_batchify_fn,
                                  num_workers=8)
    best_valid_bleu = 0.0
    for epoch_id in range(args.epochs):
        log_avg_loss = 0
        log_avg_gnorm = 0
        log_wc = 0
        log_start_time = time.time()
        for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
                in enumerate(train_data_loader):
            # Move the batch onto the training context.
            src_seq = src_seq.as_in_context(ctx)
            tgt_seq = tgt_seq.as_in_context(ctx)
            src_valid_length = src_valid_length.as_in_context(ctx)
            tgt_valid_length = tgt_valid_length.as_in_context(ctx)
            with mx.autograd.record():
                # The model predicts tgt_seq[:, 1:] from tgt_seq[:, :-1], so
                # every target length is reduced by one.
                out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
                loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
                # Rescale by (padded length) / (mean valid length).
                loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
                loss.backward()
            grads = [p.grad(ctx) for p in model.collect_params().values()]
            gnorm = gluon.utils.clip_global_norm(grads, args.clip)
            trainer.step(1)
            src_wc = src_valid_length.sum().asscalar()
            tgt_wc = (tgt_valid_length - 1).sum().asscalar()
            step_loss = loss.asscalar()
            log_avg_loss += step_loss
            log_avg_gnorm += gnorm
            log_wc += src_wc + tgt_wc
            if (batch_id + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'
                             .format(epoch_id, batch_id + 1, len(train_data_loader),
                                     log_avg_loss / args.log_interval,
                                     np.exp(log_avg_loss / args.log_interval),
                                     log_avg_gnorm / args.log_interval,
                                     wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_avg_gnorm = 0
                log_wc = 0
        valid_loss, valid_translation_out = evaluate(val_data_loader)
        valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
        logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                     .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
        test_loss, test_translation_out = evaluate(test_data_loader)
        test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
        logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                     .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
        # Format only the file name: formatting the joined path would
        # misinterpret any '{' or '}' contained in args.save_dir.
        write_sentences(valid_translation_out,
                        os.path.join(args.save_dir,
                                     'epoch{:d}_valid_out.txt'.format(epoch_id)))
        write_sentences(test_translation_out,
                        os.path.join(args.save_dir,
                                     'epoch{:d}_test_out.txt'.format(epoch_id)))
        if valid_bleu_score > best_valid_bleu:
            best_valid_bleu = valid_bleu_score
            save_path = os.path.join(args.save_dir, 'valid_best.params')
            logging.info('Save best parameters to {}'.format(save_path))
            model.save_params(save_path)
        else:
            # No improvement on validation BLEU: decay the learning rate.
            new_lr = trainer.learning_rate * args.lr_update_factor
            logging.info('Learning rate change to {}'.format(new_lr))
            trainer.set_learning_rate(new_lr)
    model.load_params(os.path.join(args.save_dir, 'valid_best.params'))
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out)
    logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out)
    logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(args.save_dir, 'best_valid_out.txt'))
    write_sentences(test_translation_out,
                    os.path.join(args.save_dir, 'best_test_out.txt'))