Example #1
def valid(valid_loader, seq2seq, epoch):
    seq2seq.eval()
    total_loss_t = 0
    for num, (test_index, test_in, test_in_len,
              test_out) in enumerate(valid_loader):
        #test_in = test_in.unsqueeze(1)
        test_in, test_out = Variable(test_in, volatile=True).cuda(), Variable(
            test_out, volatile=True).cuda()
        output_t, attn_weights_t = seq2seq(test_in,
                                           test_out,
                                           test_in_len,
                                           teacher_rate=False,
                                           train=False)
        batch_count_n = writePredict(epoch, test_index, output_t, 'valid')
        test_label = test_out.permute(1, 0)[1:].contiguous().view(-1)
        #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
        #                         test_label, ignore_index=tokens['PAD_TOKEN'])
        #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
        if LABEL_SMOOTH:
            loss_t = crit(log_softmax(output_t.view(-1, vocab_size)),
                          test_label)
        else:
            loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
                                     test_label,
                                     ignore_index=tokens['PAD_TOKEN'])

        total_loss_t += loss_t.item()

        if 'n04-015-00-01,171' in test_index:
            b = test_index.tolist().index('n04-015-00-01,171')
            visualizeAttn(test_in.data[b, 0], test_in_len[0],
                          [j[b] for j in attn_weights_t], epoch,
                          batch_count_n[b], 'valid_n04-015-00-01')
    total_loss_t /= (num + 1)
    return total_loss_t
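Both loss branches in valid() rely on helpers that are defined elsewhere in the repository: crit, log_softmax, vocab_size and tokens are referenced but never shown. As a rough sketch of what the label-smoothing path might look like, the class below implements a criterion compatible with the crit(log_softmax(logits), labels) calling convention; it is an assumption about the shape of crit, not the repository's actual definition.

# Hypothetical label-smoothing criterion matching crit(log_probs, target) above;
# the real crit/log_softmax in this code base may be defined differently.
import torch.nn as nn

class LabelSmoothing(nn.Module):
    def __init__(self, size, padding_idx, smoothing=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size                  # vocabulary size
        self.padding_idx = padding_idx    # e.g. tokens['PAD_TOKEN']
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, log_probs, target):
        # Target distribution: `confidence` on the gold class, the remaining
        # probability mass spread uniformly over the other non-PAD classes.
        true_dist = log_probs.new_full(log_probs.size(),
                                       self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        true_dist[target == self.padding_idx] = 0  # PAD positions contribute nothing
        return self.criterion(log_probs, true_dist)

With this definition, crit = LabelSmoothing(vocab_size, tokens['PAD_TOKEN']) and log_softmax = nn.LogSoftmax(dim=-1) would match the call sites above.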
Example #2
def test(test_loader, modelID, showAttn=True):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP,
                      FLIP).cuda()
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).cuda()
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).cuda()
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    pretrain_dict = torch.load(model_file)
    seq2seq_dict = seq2seq.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items() if k in seq2seq_dict
    }
    seq2seq_dict.update(pretrain_dict)
    seq2seq.load_state_dict(seq2seq_dict)  #load
    print('Loading ' + model_file)

    seq2seq.eval()
    total_loss_t = 0
    start_t = time.time()
    for num, (test_index, test_in, test_in_len, test_out,
              test_domain) in enumerate(test_loader):
        lambd = LAMBD
        test_in, test_out = Variable(test_in, volatile=True).cuda(), Variable(
            test_out, volatile=True).cuda()
        test_domain = Variable(test_domain, volatile=True).cuda()
        output_t, attn_weights_t, out_domain_t = seq2seq(test_in,
                                                         test_out,
                                                         test_in_len,
                                                         lambd,
                                                         teacher_rate=False,
                                                         train=False)
        batch_count_n = writePredict(modelID, test_index, output_t, 'test')
        test_label = test_out.permute(1, 0)[1:].contiguous().view(-1)
        if LABEL_SMOOTH:
            loss_t = crit(log_softmax(output_t.view(-1, vocab_size)),
                          test_label)
        else:
            loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
                                     test_label,
                                     ignore_index=tokens['PAD_TOKEN'])

        total_loss_t += loss_t.item()
        if showAttn:
            global_index_t = 0
            for t_idx, t_in in zip(test_index, test_in):
                visualizeAttn(t_in.data[0], test_in_len[0],
                              [j[global_index_t] for j in attn_weights_t],
                              modelID, batch_count_n[global_index_t],
                              'test_' + t_idx.split(',')[0])
                global_index_t += 1

    total_loss_t /= (num + 1)
    writeLoss(total_loss_t, 'test')
    print('       TEST loss=%.3f, time=%.3f' %
          (total_loss_t, time.time() - start_t))
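The checkpoint handling above deliberately loads only the overlapping part of the state dict, so that a checkpoint whose keys only partly match the freshly built model can still be used; entries missing from the checkpoint simply keep their fresh initialisation. The same idiom, pulled out into a small standalone helper with a hypothetical name:

# Hypothetical helper for the partial-loading idiom used above; the function
# name is illustrative and does not exist in this repository.
import torch

def load_compatible_weights(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model_dict = model.state_dict()
    compatible = {k: v for k, v in checkpoint.items() if k in model_dict}
    model_dict.update(compatible)      # overwrite matching entries, keep the rest
    model.load_state_dict(model_dict)  # keys now match exactly, so a strict load is safe
    return sorted(set(checkpoint) - set(compatible))  # names that were skipped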
Example #3
def test(test_loader, modelID, showAttn=True):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP,
                      FLIP).to(device)
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).to(device)
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).to(device)
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    print('Loading ' + model_file)
    seq2seq.load_state_dict(torch.load(model_file))  #load

    seq2seq.eval()
    total_loss_t = 0
    start_t = time.time()
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len,
                  test_out) in enumerate(test_loader):
            #test_in = test_in.unsqueeze(1)
            test_in, test_out = test_in.to(device), test_out.to(device)
            if test_in.requires_grad or test_out.requires_grad:
                print(
                    'ERROR! test_in, test_out should have requires_grad=False')
            output_t, attn_weights_t = seq2seq(test_in,
                                               test_out,
                                               test_in_len,
                                               teacher_rate=False,
                                               train=False)
            batch_count_n = writePredict(modelID, test_index, output_t, 'test')
            test_label = test_out.permute(1, 0)[1:].reshape(-1)
            #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
            #                        test_label, ignore_index=tokens['PAD_TOKEN'])
            #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
            if LABEL_SMOOTH:
                loss_t = crit(log_softmax(output_t.reshape(-1, vocab_size)),
                              test_label)
            else:
                loss_t = F.cross_entropy(output_t.reshape(-1, vocab_size),
                                         test_label,
                                         ignore_index=tokens['PAD_TOKEN'])

            total_loss_t += loss_t.item()

            if showAttn:
                global_index_t = 0
                for t_idx, t_in in zip(test_index, test_in):
                    visualizeAttn(t_in.detach()[0], test_in_len[0],
                                  [j[global_index_t] for j in attn_weights_t],
                                  modelID, batch_count_n[global_index_t],
                                  'test_' + t_idx.split(',')[0])
                    global_index_t += 1

        total_loss_t /= (num + 1)
        writeLoss(total_loss_t, 'test')
        print('    TEST loss=%.3f, time=%.3f' %
              (total_loss_t, time.time() - start_t))
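Example #3 loads the checkpoint with a bare torch.load, which assumes the storages in the file can be materialised on the current machine. If the weights were saved from a GPU run and evaluation happens on a CPU-only box, the load needs a map_location; a small variant of the loading step, added here as a suggestion rather than part of the original snippet:

# Hypothetical device-agnostic variant of the load in test(); map_location is a
# standard torch.load argument, its use here is an addition.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
seq2seq.load_state_dict(torch.load(model_file, map_location=device))
seq2seq = seq2seq.to(device).eval()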
Example #4
def train(train_loader, seq2seq, opt, teacher_rate, epoch, lambd):
    seq2seq.train()
    total_loss = 0
    total_loss_d = 0
    for num, (train_index, train_in, train_in_len, train_out,
              train_domain) in enumerate(train_loader):
        train_in, train_out = Variable(train_in).cuda(), Variable(
            train_out).cuda()
        train_domain = Variable(train_domain).cuda()
        output, attn_weights, out_domain = seq2seq(
            train_in,
            train_out,
            train_in_len,
            lambd,
            teacher_rate=teacher_rate,
            train=True)  # (100-1, 32, 62+1)
        batch_count_n = writePredict(epoch, train_index, output, 'train')
        train_label = train_out.permute(1, 0)[1:].contiguous().view(
            -1)  # remove <GO>
        output_l = output.view(-1, vocab_size)  # remove last <EOS>

        if VISUALIZE_TRAIN:
            if 'e02-074-03-00,191' in train_index:
                b = train_index.tolist().index('e02-074-03-00,191')
                visualizeAttn(train_in.data[b, 0], train_in_len[0],
                              [j[b] for j in attn_weights], epoch,
                              batch_count_n[b], 'train_e02-074-03-00')

        if LABEL_SMOOTH:
            loss = crit(log_softmax(output_l.view(-1, vocab_size)),
                        train_label)
        else:
            loss = F.cross_entropy(output_l.view(-1, vocab_size),
                                   train_label,
                                   ignore_index=tokens['PAD_TOKEN'])

        loss2 = F.cross_entropy(out_domain, train_domain)
        loss2 = ALPHA * loss2
        loss_total = loss + loss2
        opt.zero_grad()
        loss_total.backward()
        opt.step()
        total_loss += loss.item()
        total_loss_d += loss2.item()

    total_loss /= (num + 1)
    total_loss_d /= (num + 1)
    return total_loss, total_loss_d
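The lambd argument is forwarded to the model, where it presumably scales a gradient-reversal layer for the domain branch whose output is penalised by the ALPHA-weighted cross-entropy above. The snippet does not show how lambd is chosen per epoch; the driver below uses the ramp-up schedule popularised by the DANN paper, and the schedule, EPOCHS and the fixed teacher_rate are all assumptions:

# Hypothetical caller for the domain-adversarial train(); the 2/(1+exp(-10p)) - 1
# ramp for lambd follows the common DANN recipe and is not taken from this repo.
import math

EPOCHS = 100        # placeholder
teacher_rate = 0.5  # placeholder

for epoch in range(EPOCHS):
    p = epoch / EPOCHS
    lambd = 2. / (1. + math.exp(-10. * p)) - 1.  # ramps from 0 towards 1
    loss, loss_domain = train(train_loader, seq2seq, opt, teacher_rate, epoch,
                              lambd)
    print('epoch %d  recog=%.3f  domain=%.3f  lambd=%.2f' %
          (epoch, loss, loss_domain, lambd))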
Example #5
def train(train_loader, seq2seq, opt, teacher_rate, epoch):
    seq2seq.train()
    total_loss = 0
    for num, (train_index, train_in, train_in_len,
              train_out) in enumerate(train_loader):
        #train_in = train_in.unsqueeze(1)
        train_in, train_out = Variable(train_in).cuda(), Variable(
            train_out).cuda()
        output, attn_weights = seq2seq(train_in,
                                       train_out,
                                       train_in_len,
                                       teacher_rate=teacher_rate,
                                       train=True)  # (100-1, 32, 62+1)
        batch_count_n = writePredict(epoch, train_index, output, 'train')
        train_label = train_out.permute(1, 0)[1:].contiguous().view(
            -1)  # remove <GO>
        output_l = output.view(-1, vocab_size)  # remove last <EOS>

        if VISUALIZE_TRAIN:
            if 'e02-074-03-00,191' in train_index:
                b = train_index.tolist().index('e02-074-03-00,191')
                visualizeAttn(train_in.data[b, 0], train_in_len[0],
                              [j[b] for j in attn_weights], epoch,
                              batch_count_n[b], 'train_e02-074-03-00')

        #loss = F.cross_entropy(output_l.view(-1, vocab_size),
        #                       train_label, ignore_index=tokens['PAD_TOKEN'])
        #loss = loss_label_smoothing(output_l.view(-1, vocab_size), train_label)
        if LABEL_SMOOTH:
            loss = crit(log_softmax(output_l.view(-1, vocab_size)),
                        train_label)
        else:
            loss = F.cross_entropy(output_l.view(-1, vocab_size),
                                   train_label,
                                   ignore_index=tokens['PAD_TOKEN'])
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
        print(f'Batch {num} loss: {loss.item()}')

    total_loss /= (num + 1)
    return total_loss
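A sketch of an epoch loop around this train()/valid() pair, annealing teacher forcing and writing the best checkpoint to the path pattern that test() reads from; EPOCHS, the decay factor, and the checkpoint policy are assumptions:

# Hypothetical training driver; EPOCHS, the 0.95 decay and the best-checkpoint
# policy are placeholders, not taken from this repository.
import torch

EPOCHS = 100
teacher_rate = 1.0
best_valid = float('inf')

for epoch in range(EPOCHS):
    train_loss = train(train_loader, seq2seq, opt, teacher_rate, epoch)
    valid_loss = valid(valid_loader, seq2seq, epoch)
    teacher_rate *= 0.95  # gradually move from teacher forcing to free running
    if valid_loss < best_valid:
        best_valid = valid_loss
        # same 'save_weights/seq2seq-<id>.model' pattern that test() expects
        torch.save(seq2seq.state_dict(), 'save_weights/seq2seq-%d.model' % epoch)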
Example #6
def valid(valid_loader, seq2seq, epoch):
    seq2seq.eval()
    total_loss_t = 0
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len,
                  test_out) in enumerate(valid_loader):
            #test_in = test_in.unsqueeze(1)
            test_in, test_out = test_in.to(device), test_out.to(device)
            if test_in.requires_grad or test_out.requires_grad:
                print(
                    'ERROR! test_in, test_out should have requires_grad=False')
            output_t, attn_weights_t = seq2seq(test_in,
                                               test_out,
                                               test_in_len,
                                               teacher_rate=False,
                                               train=False)
            batch_count_n = writePredict(epoch, test_index, output_t, 'valid')
            test_label = test_out.permute(1, 0)[1:].reshape(-1)
            #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
            #                         test_label, ignore_index=tokens['PAD_TOKEN'])
            #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
            if LABEL_SMOOTH:
                loss_t = crit(log_softmax(output_t.reshape(-1, vocab_size)),
                              test_label)
            else:
                loss_t = F.cross_entropy(output_t.reshape(-1, vocab_size),
                                         test_label,
                                         ignore_index=tokens['PAD_TOKEN'])

            total_loss_t += loss_t.item()

            if 'n04-015-00-01,171' in test_index:
                b = test_index.tolist().index('n04-015-00-01,171')
                visualizeAttn(test_in.detach()[b, 0], test_in_len[0],
                              [j[b] for j in attn_weights_t], epoch,
                              batch_count_n[b], 'valid_n04-015-00-01')
        total_loss_t /= (num + 1)
    return total_loss_t