Example #1
def evaluate(model, valid_loader, criterion, epoch):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        for num, (valid_index, valid_in, valid_in_len,
                  valid_out) in enumerate(valid_loader):
            valid_in, valid_out = Variable(valid_in).cuda(), Variable(
                valid_out).cuda()
            output_e = model(valid_in,
                             valid_out,
                             valid_in_len,
                             teacher_rate=False,
                             train=False)

            batch_count_n = writePredict(epoch, valid_index, output_e, 'valid')

            valid_label = valid_out.permute(1, 0)[1:].contiguous().view(
                -1)  # remove <GO>
            output = output_e.view(-1, vocab_size)
            if LABEL_SMOOTH:
                loss = crit(log_softmax(output), valid_label)
            else:
                loss = criterion(output, valid_label)
            #print("valid batch loss", loss.item())

            epoch_loss += loss.item()

        epoch_loss /= (num + 1)
    return epoch_loss
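
The snippet above leans on module-level globals (`crit`, `log_softmax`, `LABEL_SMOOTH`, `vocab_size`). As a rough sketch of how such a smoothed criterion could be wired up so that `crit(log_softmax(output), label)` accepts hard integer labels, assuming a KL-divergence-style label smoothing (the class below is illustrative, not the project's own definition, and it ignores padding for brevity):

import torch
import torch.nn as nn

# Hypothetical stand-ins for the globals used above; the real project defines its own.
vocab_size = 83
log_softmax = nn.LogSoftmax(dim=-1)

class LabelSmoothingLoss(nn.Module):
    """KL divergence against a smoothed one-hot target, called as crit(log_probs, target)."""
    def __init__(self, n_classes, smoothing=0.1):
        super().__init__()
        self.n_classes = n_classes
        self.smoothing = smoothing
        self.kl = nn.KLDivLoss(reduction='batchmean')

    def forward(self, log_probs, target):
        with torch.no_grad():
            true_dist = torch.full_like(log_probs, self.smoothing / (self.n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return self.kl(log_probs, true_dist)

crit = LabelSmoothingLoss(vocab_size)

# Mirrors the call pattern in evaluate(): crit(log_softmax(output), valid_label)
output = torch.randn(8, vocab_size)                 # flattened (seq*batch, vocab) logits
valid_label = torch.randint(0, vocab_size, (8,))
loss = crit(log_softmax(output), valid_label)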
Example #2
def valid(valid_loader, seq2seq, epoch):
    seq2seq.eval()
    total_loss_t = 0
    for num, (test_index, test_in, test_in_len,
              test_out) in enumerate(valid_loader):
        #test_in = test_in.unsqueeze(1)
        # Variable(..., volatile=True) is deprecated; move the tensors to the GPU directly
        # (wrapping the loop in torch.no_grad() would also avoid building the graph)
        test_in, test_out = test_in.cuda(), test_out.cuda()
        output_t, attn_weights_t = seq2seq(test_in,
                                           test_out,
                                           test_in_len,
                                           teacher_rate=False,
                                           train=False)
        batch_count_n = writePredict(epoch, test_index, output_t, 'valid')
        test_label = test_out.permute(1, 0)[1:].contiguous().view(-1)
        #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
        #                         test_label, ignore_index=tokens['PAD_TOKEN'])
        #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
        if LABEL_SMOOTH:
            loss_t = crit(log_softmax(output_t.view(-1, vocab_size)),
                          test_label)
        else:
            loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
                                     test_label,
                                     ignore_index=tokens['PAD_TOKEN'])

        total_loss_t += loss_t.item()  # .data[0] is no longer valid on a 0-dim loss tensor

        if 'n04-015-00-01,171' in test_index:
            b = test_index.tolist().index('n04-015-00-01,171')
            visualizeAttn(test_in.data[b, 0], test_in_len[0],
                          [j[b] for j in attn_weights_t], epoch,
                          batch_count_n[b], 'valid_n04-015-00-01')
    total_loss_t /= (num + 1)
    return total_loss_t
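
The label reshaping `test_out.permute(1, 0)[1:].contiguous().view(-1)` drops the leading `<GO>` token before flattening. A tiny shape check, assuming the targets are stored as (batch, seq_len) with `<GO>` first:

import torch

batch, seq_len = 2, 5
test_out = torch.arange(batch * seq_len).view(batch, seq_len)

# permute -> (seq_len, batch); [1:] drops the <GO> row; view(-1) flattens to the
# ((seq_len - 1) * batch,) vector that matches the flattened decoder logits.
test_label = test_out.permute(1, 0)[1:].contiguous().view(-1)
print(test_label.shape)   # torch.Size([8]) == (seq_len - 1) * batch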
Example #3
def train(model, train_loader, optimizer, criterion, clip, epoch,
          teacher_rate):
    model.train()

    epoch_loss = 0
    #with autograd.detect_anomaly():

    for num, (train_index, train_in, train_in_len,
              train_out) in enumerate(train_loader):

        train_in, train_out = Variable(train_in).cuda(), Variable(
            train_out).cuda()
        optimizer.zero_grad()
        output_t = model(train_in,
                         train_out,
                         train_in_len,
                         teacher_rate=teacher_rate,
                         train=True)
        batch_count_n = writePredict(epoch, train_index, output_t, 'train')
        #trg = [trg len, batch size]
        #output = [trg len, batch size, output dim]

        train_label = train_out.permute(1, 0)[1:].contiguous().view(
            -1)  # remove <GO>
        output = output_t.view(-1, vocab_size)

        #trg = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]

        loss = criterion(output, train_label)

        loss.backward()

        # log the gradient norms of the decoder output layer and selected encoder weights
        tracked_params = (
            'decoder.fc_out.weight',
            'encoder.rnn.weight_hh_l2',
            'encoder.rnn.weight_hh_l1',
            'encoder.rnn.weight_hh_l0',
            'encoder.layer2.2.weight',
            'encoder.layer1.2.weight',
            'encoder.layer0.2.weight',
        )
        for name, param in model.named_parameters():
            if name in tracked_params:
                writeGradient(name, torch.linalg.norm(param.grad, 2))

        # rescale gradients so that their global norm is at most `clip`
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss /= (num + 1)

    return epoch_loss
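
`writeGradient` is project-specific logging; the surrounding pattern (inspect per-parameter gradient norms after `backward()`, then clip the global norm before `optimizer.step()`) uses only standard PyTorch calls. A minimal self-contained sketch, with `print` standing in for `writeGradient` and a toy model in place of the real encoder/decoder:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))
x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()

tracked = {'0.weight', '1.weight'}            # illustrative; the project tracks specific weights
for name, param in model.named_parameters():
    if name in tracked and param.grad is not None:
        # 2-norm of this parameter's gradient (stand-in for writeGradient)
        print(name, torch.linalg.norm(param.grad, 2).item())

# rescale all gradients so their global norm is at most max_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)

Example #4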
def test(test_loader, modelID, showAttn=True):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP,
                      FLIP).cuda()
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).cuda()
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).cuda()
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    pretrain_dict = torch.load(model_file)
    seq2seq_dict = seq2seq.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items() if k in seq2seq_dict
    }
    seq2seq_dict.update(pretrain_dict)
    seq2seq.load_state_dict(seq2seq_dict)  #load
    print('Loading ' + model_file)

    seq2seq.eval()
    total_loss_t = 0
    start_t = time.time()
    for num, (test_index, test_in, test_in_len, test_out,
              test_domain) in enumerate(test_loader):
        lambd = LAMBD
        # Variable(..., volatile=True) is deprecated; move the tensors to the GPU directly
        test_in, test_out = test_in.cuda(), test_out.cuda()
        test_domain = test_domain.cuda()
        output_t, attn_weights_t, out_domain_t = seq2seq(test_in,
                                                         test_out,
                                                         test_in_len,
                                                         lambd,
                                                         teacher_rate=False,
                                                         train=False)
        batch_count_n = writePredict(modelID, test_index, output_t, 'test')
        test_label = test_out.permute(1, 0)[1:].contiguous().view(-1)
        if LABEL_SMOOTH:
            loss_t = crit(log_softmax(output_t.view(-1, vocab_size)),
                          test_label)
        else:
            loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
                                     test_label,
                                     ignore_index=tokens['PAD_TOKEN'])

        total_loss_t += loss_t.item()  # replaces the deprecated loss_t.data[0]
        if showAttn:
            global_index_t = 0
            for t_idx, t_in in zip(test_index, test_in):
                visualizeAttn(t_in.data[0], test_in_len[0],
                              [j[global_index_t] for j in attn_weights_t],
                              modelID, batch_count_n[global_index_t],
                              'test_' + t_idx.split(',')[0])
                global_index_t += 1

    total_loss_t /= (num + 1)
    writeLoss(total_loss_t, 'test')
    print('       TEST loss=%.3f, time=%.3f' %
          (total_loss_t, time.time() - start_t))
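
This variant loads the checkpoint partially: the saved state dict is filtered to the keys that exist in the freshly built network before `load_state_dict`. The same pattern in isolation, with throwaway models and an in-memory buffer instead of a file (plus an extra shape check so the toy example stays loadable):

import io
import torch
import torch.nn as nn

saved = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
buffer = io.BytesIO()
torch.save(saved.state_dict(), buffer)
buffer.seek(0)

new_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 3))   # last layer differs
pretrain_dict = torch.load(buffer)
model_dict = new_model.state_dict()

# keep only the entries the new model can accept, then load the merged dict
pretrain_dict = {k: v for k, v in pretrain_dict.items()
                 if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(pretrain_dict)
new_model.load_state_dict(model_dict)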
Example #5
def test(test_loader, modelID, showAttn=True):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP,
                      FLIP).to(device)
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).to(device)
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).to(device)
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    print('Loading ' + model_file)
    seq2seq.load_state_dict(torch.load(model_file))  #load

    seq2seq.eval()
    total_loss_t = 0
    start_t = time.time()
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len,
                  test_out) in enumerate(test_loader):
            #test_in = test_in.unsqueeze(1)
            test_in, test_out = test_in.to(device), test_out.to(device)
            if test_in.requires_grad or test_out.requires_grad:
                print(
                    'ERROR! test_in, test_out should have requires_grad=False')
            output_t, attn_weights_t = seq2seq(test_in,
                                               test_out,
                                               test_in_len,
                                               teacher_rate=False,
                                               train=False)
            batch_count_n = writePredict(modelID, test_index, output_t, 'test')
            test_label = test_out.permute(1, 0)[1:].reshape(-1)
            #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
            #                        test_label, ignore_index=tokens['PAD_TOKEN'])
            #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
            if LABEL_SMOOTH:
                loss_t = crit(log_softmax(output_t.reshape(-1, vocab_size)),
                              test_label)
            else:
                loss_t = F.cross_entropy(output_t.reshape(-1, vocab_size),
                                         test_label,
                                         ignore_index=tokens['PAD_TOKEN'])

            total_loss_t += loss_t.item()

            if showAttn:
                global_index_t = 0
                for t_idx, t_in in zip(test_index, test_in):
                    visualizeAttn(t_in.detach()[0], test_in_len[0],
                                  [j[global_index_t] for j in attn_weights_t],
                                  modelID, batch_count_n[global_index_t],
                                  'test_' + t_idx.split(',')[0])
                    global_index_t += 1

        total_loss_t /= (num + 1)
        writeLoss(total_loss_t, 'test')
        print('    TEST loss=%.3f, time=%.3f' %
              (total_loss_t, time.time() - start_t))
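
Unlike the earlier examples, this one uses `.reshape(-1)` instead of `.contiguous().view(-1)`. The difference matters because `permute` returns a non-contiguous tensor, on which `.view` raises; `.reshape` falls back to a copy when needed:

import torch

t = torch.arange(6).view(2, 3).permute(1, 0)
print(t.is_contiguous())         # False
# t.view(-1)                     # would raise: view is not possible on a non-contiguous tensor
print(t.reshape(-1))             # works: reshape copies when necessary
print(t.contiguous().view(-1))   # the equivalent explicit form used in the other examples

Example #6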
def train(train_loader, seq2seq, opt, teacher_rate, epoch, lambd):
    seq2seq.train()
    total_loss = 0
    total_loss_d = 0
    for num, (train_index, train_in, train_in_len, train_out,
              train_domain) in enumerate(train_loader):
        train_in, train_out = Variable(train_in).cuda(), Variable(
            train_out).cuda()
        train_domain = Variable(train_domain).cuda()
        output, attn_weights, out_domain = seq2seq(
            train_in,
            train_out,
            train_in_len,
            lambd,
            teacher_rate=teacher_rate,
            train=True)  # (100-1, 32, 62+1)
        batch_count_n = writePredict(epoch, train_index, output, 'train')
        train_label = train_out.permute(1, 0)[1:].contiguous().view(
            -1)  # remove <GO>
        output_l = output.view(-1, vocab_size)  # remove last <EOS>

        if VISUALIZE_TRAIN:
            if 'e02-074-03-00,191' in train_index:
                b = train_index.tolist().index('e02-074-03-00,191')
                visualizeAttn(train_in.data[b, 0], train_in_len[0],
                              [j[b] for j in attn_weights], epoch,
                              batch_count_n[b], 'train_e02-074-03-00')

        if LABEL_SMOOTH:
            loss = crit(log_softmax(output_l.view(-1, vocab_size)),
                        train_label)
        else:
            loss = F.cross_entropy(output_l.view(-1, vocab_size),
                                   train_label,
                                   ignore_index=tokens['PAD_TOKEN'])

        loss2 = F.cross_entropy(out_domain, train_domain)
        loss2 = ALPHA * loss2
        loss_total = loss + loss2
        opt.zero_grad()
        loss_total.backward()
        opt.step()
        total_loss += loss.item()    # .data[0] is deprecated; .item() returns the Python float
        total_loss_d += loss2.item()

    total_loss /= (num + 1)
    total_loss_d /= (num + 1)
    return total_loss, total_loss_d
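
Here `lambd` is forwarded into the model and a domain classification loss (`loss2`, scaled by `ALPHA`) is added to the recognition loss, which is the usual domain-adversarial setup. Assuming the model applies a gradient reversal layer controlled by `lambd` (an assumption; the layer below is illustrative, not the project's code), the mechanism looks like this:

import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass, sign-flipped and scaled gradient in the backward pass."""
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)

feat = torch.randn(4, 8, requires_grad=True)
out = grad_reverse(feat, lambd=0.5).sum()
out.backward()
print(feat.grad[0, 0])   # -0.5 instead of +1.0: the domain loss pushes the encoder the other way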
Example #7
def train(train_loader, seq2seq, opt, teacher_rate, epoch):
    seq2seq.train()
    total_loss = 0
    for num, (train_index, train_in, train_in_len,
              train_out) in enumerate(train_loader):
        #train_in = train_in.unsqueeze(1)
        train_in, train_out = Variable(train_in).cuda(), Variable(
            train_out).cuda()
        output, attn_weights = seq2seq(train_in,
                                       train_out,
                                       train_in_len,
                                       teacher_rate=teacher_rate,
                                       train=True)  # (100-1, 32, 62+1)
        batch_count_n = writePredict(epoch, train_index, output, 'train')
        train_label = train_out.permute(1, 0)[1:].contiguous().view(
            -1)  # remove <GO>
        output_l = output.view(-1, vocab_size)  # remove last <EOS>

        if VISUALIZE_TRAIN:
            if 'e02-074-03-00,191' in train_index:
                b = train_index.tolist().index('e02-074-03-00,191')
                visualizeAttn(train_in.data[b, 0], train_in_len[0],
                              [j[b] for j in attn_weights], epoch,
                              batch_count_n[b], 'train_e02-074-03-00')

        #loss = F.cross_entropy(output_l.view(-1, vocab_size),
        #                       train_label, ignore_index=tokens['PAD_TOKEN'])
        #loss = loss_label_smoothing(output_l.view(-1, vocab_size), train_label)
        if LABEL_SMOOTH:
            loss = crit(log_softmax(output_l.view(-1, vocab_size)),
                        train_label)
        else:
            loss = F.cross_entropy(output_l.view(-1, vocab_size),
                                   train_label,
                                   ignore_index=tokens['PAD_TOKEN'])
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
        print(f'Batch {num} loss: {loss.item()}')

    total_loss /= (num + 1)
    return total_loss
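
`teacher_rate` is only consumed inside the model's decoding loop. A sketch of how scheduled teacher forcing is typically applied per step, assuming the project's `Seq2Seq` does something equivalent (this is not its actual code):

import random
import torch

def next_decoder_input(prev_target, prev_prediction, teacher_rate, train):
    """With probability teacher_rate (training only) feed the ground-truth token,
    otherwise feed the model's own previous prediction."""
    use_teacher = train and random.random() < teacher_rate
    return prev_target if use_teacher else prev_prediction

prev_target = torch.tensor([5])       # ground-truth token id at step t-1
prev_prediction = torch.tensor([7])   # argmax of the decoder output at step t-1
print(next_decoder_input(prev_target, prev_prediction, teacher_rate=0.5, train=True))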
Example #8
def test(test_loader, modelID, showAttn=True):
    attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
    enc = Encoder(HEIGHT, WIDTH, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT).cuda()
    dec = Decoder(vocab_size, EMBEDDING_SIZE, ENC_HID_DIM, DEC_HID_DIM,
                  DEC_DROPOUT, attn).cuda()
    model = Seq2Seq(enc, dec, output_max_len, vocab_size).cuda()
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    print('Loading ' + model_file)
    model.load_state_dict(torch.load(model_file))  #load

    model.eval()
    total_loss_t = 0
    start_t = time.time()
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len,
                  test_out) in enumerate(test_loader):
            #test_in = test_in.unsqueeze(1)
            test_in, test_out = Variable(test_in).cuda(), Variable(
                test_out).cuda()
            output_t = model(test_in,
                             test_out,
                             test_in_len,
                             teacher_rate=False,
                             train=False)
            batch_count_n = writePredict(modelID, test_index, output_t, 'test')
            #writePredict_beam2(modelID, test_index, output_t, 'test')

            test_label = test_out.permute(1, 0)[1:].contiguous().view(
                -1)  # remove <GO>
            # output_t has shape torch.Size([batch, 94, 83]): 94 output steps, each a
            # distribution over the 83-symbol vocabulary; flatten to (-1, vocab_size) for the loss
            output_t = output_t.view(-1, vocab_size)

            loss = F.cross_entropy(output_t,
                                   test_label,
                                   ignore_index=tokens['PAD_TOKEN'])
            total_loss_t += loss.item()

        total_loss_t /= (num + 1)
        #writeLoss(total_loss_t, 'test')
        print('    TEST loss=%.3f' % (total_loss_t))
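
The `ignore_index=tokens['PAD_TOKEN']` argument makes padded positions contribute nothing to the loss, so variable-length sequences can share one flattened target vector. A minimal check with made-up values (the padding id 0 is an assumption):

import torch
import torch.nn.functional as F

PAD = 0                                    # assumed padding token id
logits = torch.randn(4, 83)                # 4 flattened time steps, 83 classes
labels = torch.tensor([12, 7, PAD, PAD])   # the last two positions are padding
print(F.cross_entropy(logits, labels, ignore_index=PAD))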
Example #9
def valid(valid_loader, seq2seq, epoch):
    seq2seq.eval()
    total_loss_t = 0
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len,
                  test_out) in enumerate(valid_loader):
            #test_in = test_in.unsqueeze(1)
            test_in, test_out = test_in.to(device), test_out.to(device)
            if test_in.requires_grad or test_out.requires_grad:
                print(
                    'ERROR! test_in, test_out should have requires_grad=False')
            output_t, attn_weights_t = seq2seq(test_in,
                                               test_out,
                                               test_in_len,
                                               teacher_rate=False,
                                               train=False)
            batch_count_n = writePredict(epoch, test_index, output_t, 'valid')
            test_label = test_out.permute(1, 0)[1:].reshape(-1)
            #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
            #                         test_label, ignore_index=tokens['PAD_TOKEN'])
            #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
            if LABEL_SMOOTH:
                loss_t = crit(log_softmax(output_t.reshape(-1, vocab_size)),
                              test_label)
            else:
                loss_t = F.cross_entropy(output_t.reshape(-1, vocab_size),
                                         test_label,
                                         ignore_index=tokens['PAD_TOKEN'])

            total_loss_t += loss_t.item()

            if 'n04-015-00-01,171' in test_index:
                b = test_index.tolist().index('n04-015-00-01,171')
                visualizeAttn(test_in.detach()[b, 0], test_in_len[0],
                              [j[b] for j in attn_weights_t], epoch,
                              batch_count_n[b], 'valid_n04-015-00-01')
        total_loss_t /= (num + 1)
    return total_loss_t
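
`visualizeAttn` is a project helper that dumps the per-step attention over the input image for one tracked sample. A rough illustrative stand-in (not the project's implementation) that renders the stacked attention weights as a heatmap below the input:

import matplotlib.pyplot as plt
import numpy as np

def visualize_attention(image, attn_weights, path='attn.png'):
    """image: (H, W) array; attn_weights: list of 1-D arrays, one per decoded step."""
    attn = np.stack([np.asarray(a) for a in attn_weights])   # (steps, enc_positions)
    fig, (ax_img, ax_attn) = plt.subplots(2, 1, figsize=(8, 4))
    ax_img.imshow(image, cmap='gray')
    ax_img.set_axis_off()
    ax_attn.imshow(attn, aspect='auto', cmap='viridis')
    ax_attn.set_xlabel('encoder position')
    ax_attn.set_ylabel('decoder step')
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)

visualize_attention(np.random.rand(64, 256), [np.random.rand(50) for _ in range(10)])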