Exemplo n.º 1
0
def eval(model, criterion, valid_data):
    """Accumulate loss statistics for ``model`` over ``valid_data``.

    The model is switched to evaluation mode for the duration of the pass
    and is always put back into training mode before returning, including
    when a batch raises.

    NOTE: this function shadows the builtin ``eval``; the name is kept
    because callers in this file invoke it by that name.

    Args:
        model: Network to evaluate. It is called as
            ``model(src, tgt[:-1], src_lengths)`` and must expose a
            ``generator`` attribute.
        criterion: Criterion handed to ``Loss.LossCompute``.
        valid_data: Iterable yielding ``(src, tgt)`` pairs that are passed
            to ``prepare_data``.

    Returns:
        A ``Loss.Statistics`` updated with the statistics of every batch.
    """
    stats = Loss.Statistics()
    loss = Loss.LossCompute(model.generator, criterion)
    model.eval()
    try:
        # No backward pass happens here, so skip autograd bookkeeping
        # instead of building a graph for every validation batch.
        with torch.no_grad():
            for src, tgt in valid_data:
                src, tgt, src_lengths = prepare_data(src, tgt, True)
                # The decoder input drops the last target step; the loss
                # is computed against the target shifted by one step.
                outputs = model(src, tgt[:-1], src_lengths)
                gen_state = loss.make_loss_batch(outputs, tgt[1:])
                _, batch_stats = loss.compute_loss(**gen_state)
                stats.update(batch_stats)
    finally:
        # Callers resume training right after validation.
        model.train()
    return stats
Exemplo n.º 2
0
def _strip_bpe_markers(src_path, dst_path):
    """Run ``sed 's/@@ //g'`` on ``src_path`` and write the result to ``dst_path``."""
    # An argument list plus an explicit stdout handle replaces the formatted
    # shell string, so paths containing spaces or shell metacharacters are
    # passed to sed verbatim.
    with open(dst_path, 'wb') as dst:
        subprocess.call(['sed', 's/@@ //g', src_path], stdout=dst)


def _validation_bleu(opt, model):
    """Translate the validation source with ``Beam`` and score it with multi-bleu.

    Writes ``opt.save_model + '.bpe'`` (raw translation) and
    ``opt.save_model + '.txt'`` (translation with ``@@ `` removed), and a
    reference file named ``opt.valid_datasets[1]`` minus its last four
    characters. Puts ``model`` back into training mode after translating.

    Returns:
        The first decimal number printed by ``data/multi-bleu.perl``,
        as a float.

    Raises:
        RuntimeError: If the scorer output contains no decimal number.
    """
    output_bpe = opt.save_model + '.bpe'
    output_txt = opt.save_model + '.txt'
    Beam(opt, model).translate(opt.valid_datasets[0], output_bpe)
    # Presumably translate() leaves the model in eval mode; go back to
    # training mode before the caller continues updating it.
    model.train()

    _strip_bpe_markers(output_bpe, output_txt)
    ref = opt.valid_datasets[1][:-4]
    _strip_bpe_markers(opt.valid_datasets[1], ref)

    # subprocess.run waits for the scorer and closes the pipe, unlike a
    # bare Popen(...).stdout.read().
    with open(output_txt, 'rb') as hypothesis:
        scored = subprocess.run(['perl', 'data/multi-bleu.perl', ref],
                                stdin=hypothesis,
                                stdout=subprocess.PIPE)
    report = scored.stdout.decode('utf-8')

    # Raw string with an escaped dot: match a literal decimal number only.
    match = re.search(r"\d+\.\d+", report)
    if match is None:
        raise RuntimeError(
            'Could not parse a BLEU score from multi-bleu output: %r'
            % report)
    return float(match.group())


def train(opt):
    """Train a model, validating periodically and keeping the best-BLEU checkpoint.

    Every ``opt.eval_every`` updates the model is validated with ``eval``
    and the learning rate scheduler is stepped on the validation
    perplexity. From epoch ``opt.start_checkpoint_at`` onwards the
    validation set is also translated and scored with BLEU, and the model
    is saved to ``'<opt.save_model>_best.pt'`` whenever BLEU improves.

    Training stops when any of these happens: all ``opt.epochs`` epochs
    are done, ``opt.max_updates`` updates are reached, or the learning
    rate drops below ``opt.learning_rate * opt.learning_rate_decay ** 5``.

    Args:
        opt: Options object. Attributes read here: ``datasets``,
            ``valid_datasets``, ``dicts``, ``src_vocab_size``,
            ``tgt_vocab_size``, ``batch_size``, ``max_seq_length``,
            ``train_from``, ``encoder_type``, ``learning_rate``,
            ``learning_rate_decay``, ``epochs``, ``max_generator_batches``,
            ``max_grad_norm``, ``report_every``, ``eval_every``,
            ``max_updates``, ``start_checkpoint_at`` and ``save_model``.
            Negative ``src_vocab_size`` / ``tgt_vocab_size`` are
            overwritten in place with the dictionary sizes.

    Raises:
        RuntimeError: If the BLEU scorer output cannot be parsed.
    """
    print('| build data iterators')
    train_iter = TextIterator(*opt.datasets, *opt.dicts,
                              src_vocab_size=opt.src_vocab_size,
                              tgt_vocab_size=opt.tgt_vocab_size,
                              batch_size=opt.batch_size,
                              max_seq_length=opt.max_seq_length)

    valid_iter = TextIterator(*opt.valid_datasets, *opt.dicts,
                              src_vocab_size=opt.src_vocab_size,
                              tgt_vocab_size=opt.tgt_vocab_size,
                              batch_size=opt.batch_size,
                              max_seq_length=opt.max_seq_length)

    # A negative size means "use the full dictionary".
    if opt.src_vocab_size < 0:
        opt.src_vocab_size = len(train_iter.source_dict)
    if opt.tgt_vocab_size < 0:
        opt.tgt_vocab_size = len(train_iter.target_dict)

    print('| vocabulary size. source = %d; target = %d' %
          (opt.src_vocab_size, opt.tgt_vocab_size))
    dicts = [train_iter.source_dict, train_iter.target_dict]

    crit = Loss.nmt_criterion(opt.tgt_vocab_size, 0).to(device)
    if opt.train_from != '':
        print('| Load trained model!')
        checkpoint = torch.load(opt.train_from)
        model = models.make_base_model(opt, checkpoint)
    else:
        model = models.make_base_model(opt)
        init_uniform(model)
    model.to(device)
    if opt.encoder_type in ["sabrnn", "fabrnn"]:
        print('Add punctuation constrain!')
        model.encoder.punct(train_iter.src_punct)
    print(model)
    model.dicts = dicts
    check_model_path()
    tally_parameters(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, 'min',
                                  factor=opt.learning_rate_decay,
                                  patience=0)
    uidx = 0  # number of updates
    estop = False
    # Stop once the scheduler has decayed the learning rate below five
    # successive decays of the initial value.
    min_lr = opt.learning_rate * math.pow(opt.learning_rate_decay, 5)
    best_bleu = -1
    shard_size = opt.max_generator_batches
    for eidx in range(1, opt.epochs + 1):
        closs = Loss.LossCompute(model.generator, crit)
        total_stats = Loss.Statistics()
        report_stats = Loss.Statistics()
        for x, y in train_iter:
            model.zero_grad()
            src, tgt, lengths_x = prepare_data(x, y)
            out = model(src, tgt[:-1], lengths_x)
            gen_state = closs.make_loss_batch(out, tgt[1:])
            batch_size = len(lengths_x)
            batch_stats = Loss.Statistics()
            # Backpropagate shard by shard; gradients accumulate until
            # the optimizer step below.
            for shard in Loss.shards(gen_state, shard_size):
                loss, stats = closs.compute_loss(**shard)
                loss.div(batch_size).backward()
                batch_stats.update(stats)
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           opt.max_grad_norm)
            optimizer.step()
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)
            uidx += 1
            if uidx % opt.report_every == 0:
                report_stats.output(eidx, uidx, opt.max_updates,
                                    total_stats.start_time)
                report_stats = Loss.Statistics()

            if uidx % opt.eval_every == 0:
                valid_stats = eval(model, crit, valid_iter)
                valid_ppl = valid_stats.ppl()
                # maybe adjust learning rate
                scheduler.step(valid_ppl)
                cur_lr = optimizer.param_groups[0]['lr']
                print('Validation perplexity %d: %g' % (uidx, valid_ppl))
                print('Learning rate: %g' % cur_lr)
                if cur_lr < min_lr:
                    print('Reaching minimum learning rate. Stop training!')
                    estop = True
                    break
                if eidx >= opt.start_checkpoint_at:
                    # evaluate with BLEU score
                    bleu = _validation_bleu(opt, model)
                    print('Validation BLEU %d: %g' % (uidx, bleu))
                    if bleu > best_bleu:
                        best_bleu = bleu
                        checkpoint = {
                            'model': model.state_dict(),
                            'opt': opt,
                            'dicts': dicts
                        }
                        torch.save(checkpoint, '%s_best.pt' % opt.save_model)
                        print('Saved model: %d | BLEU %.2f' % (uidx, bleu))

            if uidx >= opt.max_updates:
                print('Finishing after {:d} iterations!'.format(uidx))
                estop = True
                break
        if estop:
            break