def eval(model, criterion, valid_data):
    """Run one full pass over *valid_data* and return accumulated loss stats.

    Puts the model into eval mode for the pass and restores train mode
    before returning.

    NOTE(review): this shadows the builtin ``eval``; the name is kept
    because callers in this file invoke it as ``eval(model, crit, valid)``.

    Args:
        model: seq2seq model with a ``generator`` attribute (project type).
        criterion: loss criterion passed to ``Loss.LossCompute``.
        valid_data: iterable yielding ``(src, tgt)`` batches.

    Returns:
        Loss.Statistics accumulated over the whole validation set.
    """
    stats = Loss.Statistics()
    model.eval()
    loss = Loss.LossCompute(model.generator, criterion)
    # BUGFIX: validation previously ran with autograd enabled, building the
    # graph for every batch and wasting memory; no gradients are needed here.
    with torch.no_grad():
        for src, tgt in valid_data:
            src, tgt, src_lengths = prepare_data(src, tgt, True)
            # Teacher forcing: feed tgt[:-1], score predictions against tgt[1:].
            outputs = model(src, tgt[:-1], src_lengths)
            gen_state = loss.make_loss_batch(outputs, tgt[1:])
            _, batch_stats = loss.compute_loss(**gen_state)
            stats.update(batch_stats)
    model.train()
    return stats
def train(opt):
    """Training driver: build data iterators and model, then run the
    epoch/update loop with periodic validation, LR decay on validation
    perplexity, and BLEU-based best-checkpoint saving.

    Args:
        opt: parsed option namespace (datasets, vocab sizes, batch size,
             learning-rate schedule, checkpoint paths, ...).
    """
    print('| build data iterators')
    train = TextIterator(*opt.datasets, *opt.dicts,
                         src_vocab_size=opt.src_vocab_size,
                         tgt_vocab_size=opt.tgt_vocab_size,
                         batch_size=opt.batch_size,
                         max_seq_length=opt.max_seq_length)
    valid = TextIterator(*opt.valid_datasets, *opt.dicts,
                         src_vocab_size=opt.src_vocab_size,
                         tgt_vocab_size=opt.tgt_vocab_size,
                         batch_size=opt.batch_size,
                         max_seq_length=opt.max_seq_length)
    # A negative vocab size means "use the full dictionary".
    if opt.src_vocab_size < 0:
        opt.src_vocab_size = len(train.source_dict)
    if opt.tgt_vocab_size < 0:
        opt.tgt_vocab_size = len(train.target_dict)
    print('| vocabulary size. source = %d; target = %d' %
          (opt.src_vocab_size, opt.tgt_vocab_size))
    dicts = [train.source_dict, train.target_dict]
    # Padding index 0 is ignored by the criterion.
    crit = Loss.nmt_criterion(opt.tgt_vocab_size, 0).to(device)
    if opt.train_from != '':
        print('| Load trained model!')
        checkpoint = torch.load(opt.train_from)
        model = models.make_base_model(opt, checkpoint)
    else:
        model = models.make_base_model(opt)
        init_uniform(model)
    model.to(device)
    if opt.encoder_type in ["sabrnn", "fabrnn"]:
        print('Add punctuation constrain!')
        model.encoder.punct(train.src_punct)
    print(model)
    model.dicts = dicts
    check_model_path()
    tally_parameters(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)
    # Decay LR whenever validation perplexity stops improving (patience=0).
    scheduler = ReduceLROnPlateau(optimizer, 'min',
                                  factor=opt.learning_rate_decay, patience=0)
    uidx = 0  # number of updates
    estop = False  # early-stop flag
    # Stop once the LR has been decayed 5 times from its initial value.
    min_lr = opt.learning_rate * math.pow(opt.learning_rate_decay, 5)
    best_bleu = -1
    # BUGFIX: the original pattern "[\d]+.[\d]+" used an unescaped dot
    # (matches ANY character) inside a non-raw string; compile once, raw,
    # with the dot escaped so only a decimal number matches.
    bleu_pattern = re.compile(r'\d+\.\d+')
    for eidx in range(1, opt.epochs + 1):
        closs = Loss.LossCompute(model.generator, crit)
        tot_loss = 0
        total_stats = Loss.Statistics()
        report_stats = Loss.Statistics()
        for x, y in train:
            model.zero_grad()
            src, tgt, lengths_x = prepare_data(x, y)
            # Teacher forcing: feed tgt[:-1], predict tgt[1:].
            out = model(src, tgt[:-1], lengths_x)
            gen_state = closs.make_loss_batch(out, tgt[1:])
            shard_size = opt.max_generator_batches
            batch_size = len(lengths_x)
            batch_stats = Loss.Statistics()
            # Shard the generator pass to bound memory; gradients accumulate
            # across shards before the single optimizer step below.
            for shard in Loss.shards(gen_state, shard_size):
                loss, stats = closs.compute_loss(**shard)
                loss.div(batch_size).backward()
                batch_stats.update(stats)
                tot_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           opt.max_grad_norm)
            optimizer.step()
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)
            uidx += 1
            if uidx % opt.report_every == 0:
                report_stats.output(eidx, uidx, opt.max_updates,
                                    total_stats.start_time)
                report_stats = Loss.Statistics()
            if uidx % opt.eval_every == 0:
                valid_stats = eval(model, crit, valid)
                # maybe adjust learning rate
                scheduler.step(valid_stats.ppl())
                cur_lr = optimizer.param_groups[0]['lr']
                print('Validation perplexity %d: %g'
                      % (uidx, valid_stats.ppl()))
                print('Learning rate: %g' % cur_lr)
                if cur_lr < min_lr:
                    print('Reaching minimum learning rate. Stop training!')
                    estop = True
                    break
                model_state_dict = model.state_dict()
                if eidx >= opt.start_checkpoint_at:
                    checkpoint = {
                        'model': model_state_dict,
                        'opt': opt,
                        'dicts': dicts
                    }
                    # evaluate with BLEU score
                    inference = Beam(opt, model)
                    output_bpe = opt.save_model + '.bpe'
                    output_txt = opt.save_model + '.txt'
                    inference.translate(opt.valid_datasets[0], output_bpe)
                    model.train()
                    # NOTE(review): shell=True with paths interpolated into
                    # the command string — assumes trusted local paths; do not
                    # route untrusted input through opt.save_model or
                    # opt.valid_datasets.
                    subprocess.call("sed 's/@@ //g' {:s} > {:s}"
                                    .format(output_bpe, output_txt),
                                    shell=True)
                    # Strip the '.bpe' suffix to name the de-BPE'd reference.
                    ref = opt.valid_datasets[1][:-4]
                    subprocess.call("sed 's/@@ //g' {:s} > {:s}"
                                    .format(opt.valid_datasets[1], ref),
                                    shell=True)
                    cmd = "perl data/multi-bleu.perl {} < {}" \
                        .format(ref, output_txt)
                    p = subprocess.Popen(cmd, shell=True,
                                         stdout=subprocess.PIPE) \
                        .stdout.read().decode('utf-8')
                    bleu = bleu_pattern.search(p)
                    bleu = float(bleu.group())
                    print('Validation BLEU %d: %g' % (uidx, bleu))
                    if bleu > best_bleu:
                        best_bleu = bleu
                        torch.save(checkpoint, '%s_best.pt' % opt.save_model)
                        print('Saved model: %d | BLEU %.2f' % (uidx, bleu))
            if uidx >= opt.max_updates:
                print('Finishing after {:d} iterations!'.format(uidx))
                estop = True
                break
        if estop:
            break