Example #1
def train(pretrain=False, kld_annealing=True):
    DEBUG = False
    print('Fine-tuning VNMT for GTE...')

    # Turn on training mode which enables dropout.
    model.train()
    iteration = 0
    total_loss = 0
    total_acc = 0
    # for plotting
    train_losses = []
    val_losses = []
    kld_values = [] # unweighted values
    kld_weights = []
    nlls = []

    ntokens = len(inputs.vocab)
    best_val_loss = float('inf')
    sents = [
        'People are celebrating a victory on the square.',
        'Two women who just had lunch hugging and saying goodbye.',
    ]

    if kld_annealing:
        kld_weight = kld_coef(iteration, batch_size)
    else:
        kld_weight = 1.0
    val_loss = evaluate(val_iter, model, ntokens, opt.batch_size, kld_weight=kld_weight)

    print('kld_annealing:')
    print(kld_annealing)
    print('Evaluating...')
    print(val_loss)
    example0 = create_example(inputs, sents[0], max_seq_len)
    example1 = create_example(inputs, sents[1], max_seq_len)
    print(model.generate(inputs, ntokens, example0, max_seq_len))
    print(model.generate(inputs, ntokens, example1, max_seq_len))

    # plot / dump check before proceeding with training
    kld_stats = { 'nll': nlls, 'kld_values': kld_values, 'kld_weights': kld_weights }
    with open('kld_stats.pkl', 'wb') as f:
        pickle.dump(kld_stats, f, protocol=pickle.HIGHEST_PROTOCOL)
    ##plot_losses([0, 1, 2, 3, 4], 'train', 'train_loss.pdf')


    start_time = time.time()
    sys.stdout.flush()
    for epoch in range(epochs):
        train_iter.init_epoch()
        n_correct, n_total = 0, 0
        total_loss = 0
        train_loss = 0

        for batch_idx, batch in enumerate(train_iter):
            # Turn on training mode which enables dropout.
            model.train()
            optimizer.zero_grad()
            s, s_lengths = batch.premise
            t, t_lengths = batch.hypothesis
            s = s.to(device)
            t = t.to(device)

            _nll, _kld = model.batchNLLLoss(s, s_lengths, t, t_lengths, device, train=True)

            # KLD Cost Annealing
            # ref: https://arxiv.org/pdf/1511.06349.pdf

            if kld_annealing:
                kld_weight = kld_coef(iteration, batch_size)
            else:
                kld_weight = 1.0
            _loss = _nll + kld_weight * _kld

            nlls.append(_nll.item())
            kld_values.append(_kld.item())
            kld_weights.append(kld_weight)

            _loss.backward()

            torch.nn.utils.clip_grad_norm_(model.encoder_prior.parameters(), clip)
            torch.nn.utils.clip_grad_norm_(model.encoder_post.parameters(), clip)
            torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), clip)
            #torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()


            batch_loss = _loss.item()
            total_loss += batch_loss
            train_loss += batch_loss

            if batch_idx % log_interval == 0 and batch_idx > 0:
                print('iteration: %d' % iteration)
                print('kld_weight: %.16f' % kld_weight)
                print('nll: %.16f' % _nll.item())
                print('kld_value: %.16f' % _kld.item())
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:03.3f} | ms/batch {:5.2f} | '
                        'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch_idx, len(train_iter), lr,
                    elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()
            iteration += 1

        print('Evaluating...')
        val_loss = evaluate(val_iter, model, ntokens, opt.batch_size, kld_weight=kld_weight)
        val_losses.append(val_loss)
        print(val_loss)

        print(model.generate(inputs, ntokens, example0, max_seq_len))
        print(model.generate(inputs, ntokens, example1, max_seq_len))

        train_loss = train_loss / float(len(train_iter))
        print('Epoch train loss:')
        print(train_loss)
        train_losses.append(train_loss)

        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open('%s_%s_gte_best.pkl'%(model_name, rnn_type.lower()), 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            #lr /= 4.0
            #print('lr annealed: %f'%lr)
            pass
        if epoch % save_interval == 0:
            with open('%s_%s_gte_e%d.pkl'%(model_name, rnn_type.lower(), epoch), 'wb') as f:
                torch.save(model, f)

        # save train/val loss lists
        with open('train_losses.pkl', 'wb') as f:
            pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open('val_losses.pkl', 'wb') as f:
            pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        kld_stats = { 'nll': nlls, 'kld_values': kld_values, 'kld_weights': kld_weights }
        with open('kld_stats.pkl', 'wb') as f:
            pickle.dump(kld_stats, f, protocol=pickle.HIGHEST_PROTOCOL)

        sys.stdout.flush()
        ##plot_losses(train_losses, 'train', 'train_loss.pdf')
        ##plot_losses(val_losses, 'validation', 'val_loss.pdf')
        ##show_plot(train_losses, val_losses, 'train-val_loss.pdf')
    print(train_losses)
    print(val_losses)

    # save train/val loss lists
    with open('train_losses.pickle', 'wb') as f:
        pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    with open('val_losses.pickle', 'wb') as f:
        pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
model.embeddings.weight.data = inputs.vocab.vectors
model.embeddings.weight.requires_grad = False

model = torch.load('vnmt_gru_gte_best.pkl')
##model = torch.load('vnmt_pretrain_gru_gte_best.pkl')
print(type(model))

sents = [
    'people are celebrating a victory on the square .',
    'two women who just had lunch hugging and saying goodbye .',
    'a man selling donuts to a customer during a world exhibition event .',
    'two men and a woman finishing a meal and drinks .',
    'people are running away from the bear .',
    'a boy is jumping on skateboard in the middle of a red bridge .',
    'a big brown dog swims towards the camera .',
    'a small group of church-goers watch a choir practice .',
]

i = 0
example0 = create_example(inputs, sents[0], max_seq_len)
example1 = create_example(inputs, sents[1], max_seq_len)
for i, sent in enumerate(sents):
    sent = sent + ' <pad>'
    example = create_example(inputs, sent, max_seq_len)
    print(example)
    output, attns = model.generate(inputs, ntokens, example, max_seq_len)
    show_attention('attn_vis%d.pdf' % i, sent, output, attns)
    show_attention('attn_vis%d.png' % i, sent, output, attns)
    ##output, attns = model.generate(inputs, ntokens, example, max_seq_len, device)
    ##show_attention('attn_pretrain_vis%d'%i, sent, output, attns)
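
The KL cost annealing above relies on a kld_coef helper that is not shown in these examples. Below is a minimal sketch of a sigmoid schedule in the spirit of Bowman et al. (https://arxiv.org/pdf/1511.06349.pdf); the constants and the unused batch_size parameter are assumptions, not the original helper.

import math

def kld_coef(iteration, batch_size, midpoint=10000.0, scale=3000.0):
    # Hypothetical annealing schedule: the KL weight rises smoothly from
    # ~0 toward 1.0 as training progresses, so the decoder is trained
    # mostly on reconstruction early on and the KL term is phased in.
    # `midpoint` and `scale` are illustrative values; `batch_size` is
    # accepted only to match the call sites above.
    return 1.0 / (1.0 + math.exp(-(iteration - midpoint) / scale))
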
Example #3
def train(reverse=False, pretrain=False):
    print('Pretraining VNMT for GTE...')
    model.train()
    total_loss = 0
    total_acc = 0
    train_losses = []  # for plotting
    val_losses = []
    attn_weights = [[], []]

    ntokens = len(inputs.vocab)
    best_val_loss = float('inf')
    if reverse:
        sents = [
            'People are celebrating a birthday.',
            'There are two woman in this picture.'  #'Two women who just had lunch hugging and saying goodbye.',
        ]
    else:
        sents = [
            'People are celebrating a victory on the square.',
            'Two women who just had lunch hugging and saying goodbye.',
        ]
    val_loss = evaluate(val_iter, model, ntokens, opt.batch_size)

    example0 = create_example(inputs, sents[0], max_seq_len)
    example1 = create_example(inputs, sents[1], max_seq_len)
    for i, sent in enumerate(sents):
        sent = '<sos> ' + sent + ' <pad>'
        example = create_example(inputs, sent, max_seq_len)
        output, attns = model.generate(inputs, ntokens, example, max_seq_len,
                                       device)
        ##show_attention('attn_vis%d'%i, sent, output, attns)
        attn_weights[i].append((output, attns))

    start_time = time.time()
    iteration = 0
    sys.stdout.flush()
    for epoch in range(epochs):
        train_iter.init_epoch()
        n_correct, n_total = 0, 0
        total_loss = 0
        train_loss = 0

        for batch_idx, batch in enumerate(train_iter):
            # Turn on training mode which enables dropout.
            model.train()
            optimizer.zero_grad()
            s, s_lengths = batch.premise
            t, t_lengths = batch.hypothesis
            s = s.to(device)
            t = t.to(device)

            if reverse:
                _loss = model.batchNLLLoss(t,
                                           t_lengths,
                                           s,
                                           s_lengths,
                                           device,
                                           train=False)
            else:
                _loss = model.batchNLLLoss(s,
                                           s_lengths,
                                           t,
                                           t_lengths,
                                           device,
                                           train=False)

            _loss.backward()
            optimizer.step()
            loss = _loss.item()
            total_loss += loss
            train_loss += loss
            iteration += 1

            if batch_idx % log_interval == 0 and batch_idx > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | lr {:03.3f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                        epoch, batch_idx, len(train_iter), lr,
                        elapsed * 1000 / log_interval, cur_loss,
                        math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()

        print('Evaluating...')
        val_loss = evaluate(val_iter, model, ntokens, opt.batch_size)
        val_losses.append(val_loss)
        for i, sent in enumerate(sents):
            sent = '<sos> ' + sent + ' <pad>'
            example = create_example(inputs, sent, max_seq_len)
            output, attns = model.generate(inputs, ntokens, example,
                                           max_seq_len, device)
            ##show_attention('attn_vis%d_%d'%(epoch,i), sent, output, attns)
            attn_weights[i].append((output, attns))

        train_loss = train_loss / float(len(train_iter))
        print('Epoch train loss:')
        print(train_loss)
        train_losses.append(train_loss)

        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open('%s_%s_gte_best.pkl' % (model_name, rnn_type.lower()),
                      'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            #lr /= 4.0
            #print('lr annealed: %f'%lr)
            pass
        if epoch % save_interval == 0:
            with open(
                    '%s_%s_gte_e%d.pkl' %
                (model_name, rnn_type.lower(), epoch), 'wb') as f:
                torch.save(model, f)

        # save train/val loss lists
        with open('train_losses.pkl', 'wb') as f:
            pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open('val_losses.pkl', 'wb') as f:
            pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open('attn_weights.pkl', 'wb') as f:
            pickle.dump(attn_weights, f, protocol=pickle.HIGHEST_PROTOCOL)
        ##plot_losses(train_losses, 'train', 'train_loss.pdf')
        ##plot_losses(val_losses, 'validation', 'val_loss.pdf')
        ##show_plot(train_losses, val_losses, 'train-val_loss.pdf')
        sys.stdout.flush()

    # Print the loss history just in case
    print(train_losses)
    print(val_losses)

    # save train/val loss lists
    with open('train_losses.pickle', 'wb') as f:
        pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    with open('val_losses.pickle', 'wb') as f:
        pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    with open('attn_weights.pkl', 'wb') as f:
        pickle.dump(attn_weights, f, protocol=pickle.HIGHEST_PROTOCOL)
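
The evaluate helper called in these loops is not shown. Here is a minimal sketch of what it might look like for this pretraining variant, assuming the batchNLLLoss signature used above and the module-level device; the VAE examples additionally pass a kld_weight and get back separate NLL and KLD terms, which this sketch does not model.

import torch

def evaluate(data_iter, model, ntokens, batch_size):
    # Hypothetical validation loop: switch off dropout and gradients and
    # average the per-batch loss over the whole iterator. Uses the same
    # module-level `device` as the training code above.
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for batch in data_iter:
            s, s_lengths = batch.premise
            t, t_lengths = batch.hypothesis
            s, t = s.to(device), t.to(device)
            loss = model.batchNLLLoss(s, s_lengths, t, t_lengths, device,
                                      train=False)
            total += loss.item()
            count += 1
    return total / max(count, 1)
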
Example #4
def train(pretrain=False, kld_annealing=True):
    DEBUG = False
    print('gte_vae.train')
    print('lr=%f' % lr)

    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    total_acc = 0
    # for plotting
    train_losses = []
    val_losses = []
    kld_values = []  # unweighted values
    kld_weights = []
    nlls = []

    ntokens = len(inputs.vocab)
    best_val_loss = float('inf')

    sents = [
        'People are celebrating a victory on the square.',
        'Two women who just had lunch hugging and saying goodbye.',
    ]

    iteration = 0
    if kld_annealing:

        kld_weight = kld_coef(iteration, batch_size)
    else:
        kld_weight = 1.0
    val_loss = evaluate(val_iter,
                        model,
                        ntokens,
                        opt.batch_size,
                        kld_weight=kld_weight)
    val_loss = val_loss.data[0]

    print('kld_annealing:')
    print(kld_annealing)
    print('Evaluating...')
    print(val_loss)
    example0 = create_example(inputs, sents[0], max_seq_len)
    example1 = create_example(inputs, sents[1], max_seq_len)
    print(model.generate(inputs, ntokens, example0, max_seq_len))
    print(model.generate(inputs, ntokens, example1, max_seq_len))

    start_time = time.time()

    # plot / dump check before proceeding with training
    kld_stats = {
        'nll': nlls,
        'kld_values': kld_values,
        'kld_weights': kld_weights
    }
    with open('kld_stats.pkl', 'wb') as f:
        pickle.dump(kld_stats, f, protocol=pickle.HIGHEST_PROTOCOL)
    plot_losses([0, 1, 2, 3, 4], 'train', 'train_loss.eps')

    for epoch in range(epochs):
        train_iter.init_epoch()
        n_correct, n_total = 0, 0
        total_loss = 0
        train_loss = 0

        for batch_idx, batch in enumerate(train_iter):
            # Turn on training mode which enables dropout.
            model.train()
            model.encoder_prior.train()
            model.encoder_post.train()
            model.decoder.train()
            optimizer.zero_grad()

            #print(batch.text.data.shape) # 35 x 64
            #batch.text.data = batch.text.data.view(-1, max_seq_len) # -1 instead of opt.batch_size to avoid reshaping err at the end of the epoch
            batch.premise.data = batch.premise.data.transpose(
                1, 0)  # should be 64x35 [batch_size x seq_len]
            batch.hypothesis.data = batch.hypothesis.data.transpose(
                1, 0)  # should be 64x35 [batch_size x seq_len]
            #nll, kld = model.batchNLLLoss(batch.premise, batch.hypothesis)
            nll, kld = model.batchNLLLoss(batch.premise,
                                          batch.hypothesis,
                                          train=True)

            # KLD Cost Annealing
            # ref: https://arxiv.org/pdf/1511.06349.pdf
            iteration += 1
            if kld_annealing:
                kld_weight = kld_coef(iteration, batch_size)
            else:
                kld_weight = 1.0
            loss = nll + kld_weight * kld

            nlls.append(nll.data)
            kld_values.append(kld.data)
            kld_weights.append(kld_weight)

            loss.backward()
            torch.nn.utils.clip_grad_norm(model.encoder_prior.parameters(),
                                          clip)
            torch.nn.utils.clip_grad_norm(model.encoder_post.parameters(),
                                          clip)
            torch.nn.utils.clip_grad_norm(model.decoder.parameters(), clip)
            #torch.nn.utils.clip_grad_norm(model.parameters(), clip)
            optimizer.step()

            batch_loss = loss.data
            total_loss += batch_loss
            train_loss += batch_loss

            if batch_idx % log_interval == 0 and batch_idx > 0:
                print('iteration: %d' % iteration)
                print('kld_weight: %.16f' % kld_weight)
                print('nll: %.16f' % nll.data[0])
                print('kld_value: %.16f' % kld.data[0])
                cur_loss = total_loss[0] / log_interval
                elapsed = time.time() - start_time
                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                        epoch, batch_idx,
                        len(train_iter) // max_seq_len, lr,
                        elapsed * 1000 / log_interval, cur_loss,
                        0))  #math.exp(cur_loss)
                total_loss = 0
                start_time = time.time()

        print('Evaluating...')
        val_loss = evaluate(val_iter,
                            model,
                            ntokens,
                            opt.batch_size,
                            kld_weight=kld_weight)
        print(val_loss.data[0])
        print(model.generate(inputs, ntokens, example0, max_seq_len))
        print(model.generate(inputs, ntokens, example1, max_seq_len))

        print(nlls[-1])
        print(kld_values[-1])
        print(kld_weights[-1])
        print('Epoch train loss:')
        print(train_loss[0])
        train_loss = train_loss / float(len(train_iter))
        print(train_loss[0])
        train_losses.append(train_loss[0])

        val_loss = evaluate(val_iter,
                            model,
                            ntokens,
                            opt.batch_size,
                            kld_weight=kld_weight)
        val_loss = val_loss.data[0]
        val_losses.append(val_loss)
        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open('%s_%s_gte_best.pkl' % (model_name, rnn_type.lower()),
                      'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            #lr /= 4.0
            #print('lr annealed: %f'%lr)
            pass
        if epoch % save_interval == 0:
            with open(
                    '%s_%s_gte_e%d.pkl' %
                (model_name, rnn_type.lower(), epoch), 'wb') as f:
                torch.save(model, f)

        # save train/val loss lists
        with open('train_losses.pkl', 'wb') as f:
            pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open('val_losses.pkl', 'wb') as f:
            pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        kld_stats = {
            'nll': nlls,
            'kld_values': kld_values,
            'kld_weights': kld_weights
        }
        with open('kld_stats.pkl', 'wb') as f:
            pickle.dump(kld_stats, f, protocol=pickle.HIGHEST_PROTOCOL)

        plot_losses(train_losses, 'train', 'train_loss.eps')
        plot_losses(val_losses, 'validation', 'val_loss.eps')
        show_plot(train_losses, val_losses, 'train-val_loss.eps')

    print(train_losses)
    print(val_losses)

    # save train/val loss lists
    with open('train_losses.pickle', 'wb') as f:
        pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    with open('val_losses.pickle', 'wb') as f:
        pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    show_plot(train_losses, val_losses, 'train-val_loss.eps')
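
The plot_losses and show_plot helpers used above are also not shown. A minimal matplotlib sketch of what they might do, assuming they simply save line plots of the recorded per-epoch losses:

import matplotlib
matplotlib.use('Agg')  # write figures to disk without a display
import matplotlib.pyplot as plt

def plot_losses(losses, label, filename):
    # Hypothetical helper: plot one loss curve against the epoch index.
    plt.figure()
    plt.plot(losses, label=label)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig(filename)
    plt.close()

def show_plot(train_losses, val_losses, filename):
    # Hypothetical helper: overlay the train and validation curves.
    plt.figure()
    plt.plot(train_losses, label='train')
    plt.plot(val_losses, label='validation')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig(filename)
    plt.close()
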
Example #5
def train(reverse=False, pretrain=False):
    DEBUG = False
    print('gte_vae_pretrain.train')

    model.train()
    total_loss = 0
    total_acc = 0
    # for plotting
    train_losses = []
    val_losses = []

    ntokens = len(inputs.vocab)
    best_val_loss = float('inf')
    if reverse:
        sents = [
            'People are celebrating a birthday.',
            'There are two woman in this picture.'  #'Two women who just had lunch hugging and saying goodbye.',
        ]
    else:
        sents = [
            'People are celebrating a victory on the square.',
            'Two women who just had lunch hugging and saying goodbye.',
        ]
    val_loss = evaluate(val_iter, model, ntokens, opt.batch_size)
    val_loss = val_loss.data[0]
    print(val_loss)
    example0 = create_example(inputs, sents[0], max_seq_len)
    example1 = create_example(inputs, sents[1], max_seq_len)
    for i, sent in enumerate(sents):
        sent = '<sos> ' + sent + ' <pad>'
        example = create_example(inputs, sent, max_seq_len)
        output, attns = model.generate(inputs, ntokens, example, max_seq_len)
        show_attention('attn_vis%d' % i, sent, output, attns)

    start_time = time.time()
    iteration = 0

    for epoch in range(epochs):
        train_iter.init_epoch()
        n_correct, n_total = 0, 0
        total_loss = 0
        train_loss = 0

        for batch_idx, batch in enumerate(train_iter):
            # Turn on training mode which enables dropout.
            model.train()
            optimizer.zero_grad()

            #print(batch.text.data.shape) # 35 x 64
            #batch.text.data = batch.text.data.view(-1, max_seq_len) # -1 instead of opt.batch_size to avoid reshaping err at the end of the epoch
            batch.premise.data = batch.premise.data.transpose(
                1, 0)  # should be 64x35 [batch_size x seq_len]
            batch.hypothesis.data = batch.hypothesis.data.transpose(
                1, 0)  # should be 64x35 [batch_size x seq_len]
            if reverse:
                loss = model.batchNLLLoss(batch.hypothesis,
                                          batch.premise,
                                          train=True)
            else:
                loss = model.batchNLLLoss(batch.premise,
                                          batch.hypothesis,
                                          train=True)

            iteration += 1
            loss.backward()
            optimizer.step()
            batch_loss = loss.data
            total_loss += batch_loss
            train_loss += batch_loss

            if batch_idx % log_interval == 0 and batch_idx > 0:
                cur_loss = total_loss[0] / log_interval
                elapsed = time.time() - start_time
                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                        epoch, batch_idx,
                        len(train_iter) // max_seq_len, lr,
                        elapsed * 1000 / log_interval, cur_loss,
                        0))  #math.exp(cur_loss)
                total_loss = 0
                start_time = time.time()

        print('Evaluating...')
        val_loss = evaluate(val_iter, model, ntokens, opt.batch_size)
        print(val_loss.data[0])
        for i, sent in enumerate(sents):
            sent = '<sos> ' + sent + ' <pad>'
            example = create_example(inputs, sent, max_seq_len)
            output, attns = model.generate(inputs, ntokens, example,
                                           max_seq_len)
            show_attention('attn_vis%d_%d' % (epoch, i), sent, output, attns)

        print('Epoch train loss:')
        print(train_loss[0])
        train_loss = train_loss / float(len(train_iter))
        print(train_loss[0])
        train_losses.append(train_loss[0])

        val_loss = evaluate(val_iter, model, ntokens, opt.batch_size)
        val_loss = val_loss.data[0]
        val_losses.append(val_loss)
        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open('%s_%s_gte_best.pkl' % (model_name, rnn_type.lower()),
                      'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            #lr /= 4.0
            #print('lr annealed: %f'%lr)
            pass
        if epoch % save_interval == 0:
            with open(
                    '%s_%s_gte_e%d.pkl' %
                (model_name, rnn_type.lower(), epoch), 'wb') as f:
                torch.save(model, f)

        # save train/val loss lists
        with open('train_losses.pkl', 'wb') as f:
            pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open('val_losses.pkl', 'wb') as f:
            pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        plot_losses(train_losses, 'train', 'train_loss.eps')
        plot_losses(val_losses, 'validation', 'val_loss.eps')
        show_plot(train_losses, val_losses, 'train-val_loss.eps')

    print(train_losses)
    print(val_losses)

    # save train/val loss lists
    with open('train_losses.pickle', 'wb') as f:
        pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    with open('val_losses.pickle', 'wb') as f:
        pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    show_plot(train_losses, val_losses, 'train-val_loss.eps')
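
The show_attention helper used in the generation snippets above is likewise not defined here. Below is a minimal sketch following the common matplotlib attention-heatmap pattern; the argument handling is an assumption based on how it is called above (filename, source sentence, generated tokens, attention weights).

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

def show_attention(filename, input_sentence, output_words, attentions):
    # Hypothetical helper: save the attention matrix as a heatmap with the
    # source tokens along the x-axis and the generated tokens along the
    # y-axis (one row per decoding step).
    attn = attentions.cpu().numpy() if hasattr(attentions, 'cpu') else attentions
    fig, ax = plt.subplots()
    cax = ax.matshow(attn, cmap='bone')
    fig.colorbar(cax)
    ax.set_xticklabels([''] + input_sentence.split(), rotation=90)
    ax.set_yticklabels([''] + list(output_words))
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    fig.savefig(filename)
    plt.close(fig)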