def test(test_loader, modelID, showAttn=True):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP,
                      FLIP).cuda()
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).cuda()
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).cuda()
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    pretrain_dict = torch.load(model_file)
    seq2seq_dict = seq2seq.state_dict()
    # Partial load: keep only pretrained weights whose names match the
    # current model, then merge them into its state dict.
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items() if k in seq2seq_dict
    }
    seq2seq_dict.update(pretrain_dict)
    seq2seq.load_state_dict(seq2seq_dict)
    print('Loading ' + model_file)

    seq2seq.eval()
    total_loss_t = 0
    start_t = time.time()
    # torch.no_grad() replaces the deprecated Variable(..., volatile=True)
    # inference API from PyTorch <= 0.3.
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len, test_out,
                  test_domain) in enumerate(test_loader):
            lambd = LAMBD
            test_in, test_out = test_in.cuda(), test_out.cuda()
            test_domain = test_domain.cuda()
            output_t, attn_weights_t, out_domain_t = seq2seq(test_in,
                                                             test_out,
                                                             test_in_len,
                                                             lambd,
                                                             teacher_rate=False,
                                                             train=False)
            batch_count_n = writePredict(modelID, test_index, output_t, 'test')
            # Skip the first (start-of-sequence) step and flatten the labels.
            test_label = test_out.permute(1, 0)[1:].contiguous().view(-1)
            if LABEL_SMOOTH:
                loss_t = crit(log_softmax(output_t.view(-1, vocab_size)),
                              test_label)
            else:
                loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
                                         test_label,
                                         ignore_index=tokens['PAD_TOKEN'])

            # .item() replaces the deprecated loss_t.data[0] accessor.
            total_loss_t += loss_t.item()
            if showAttn:
                for global_index_t, (t_idx, t_in) in enumerate(
                        zip(test_index, test_in)):
                    visualizeAttn(t_in.detach()[0], test_in_len[0],
                                  [j[global_index_t] for j in attn_weights_t],
                                  modelID, batch_count_n[global_index_t],
                                  'test_' + t_idx.split(',')[0])

    total_loss_t /= (num + 1)
    writeLoss(total_loss_t, 'test')
    print('       TEST loss=%.3f, time=%.3f' %
          (total_loss_t, time.time() - start_t))
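
# Hedged usage sketch for test() above; `all_data_loader` and the model id
# are assumptions about the surrounding project, not its actual API:
#
#     _, _, test_loader = all_data_loader()
#     test(test_loader, modelID=100, showAttn=False)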

# Example 2
def test(test_loader, modelID, showAttn=True):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP,
                      FLIP).to(device)
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).to(device)
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).to(device)
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    print('Loading ' + model_file)
    seq2seq.load_state_dict(torch.load(model_file))  #load

    seq2seq.eval()
    total_loss_t = 0
    start_t = time.time()
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len,
                  test_out) in enumerate(test_loader):
            #test_in = test_in.unsqueeze(1)
            test_in, test_out = test_in.to(device), test_out.to(device)
            if test_in.requires_grad or test_out.requires_grad:
                print(
                    'ERROR! test_in, test_out should have requires_grad=False')
            output_t, attn_weights_t = seq2seq(test_in,
                                               test_out,
                                               test_in_len,
                                               teacher_rate=False,
                                               train=False)
            batch_count_n = writePredict(modelID, test_index, output_t, 'test')
            test_label = test_out.permute(1, 0)[1:].reshape(-1)
            #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
            #                        test_label, ignore_index=tokens['PAD_TOKEN'])
            #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
            if LABEL_SMOOTH:
                loss_t = crit(log_softmax(output_t.reshape(-1, vocab_size)),
                              test_label)
            else:
                loss_t = F.cross_entropy(output_t.reshape(-1, vocab_size),
                                         test_label,
                                         ignore_index=tokens['PAD_TOKEN'])

            total_loss_t += loss_t.item()

            if showAttn:
                global_index_t = 0
                for t_idx, t_in in zip(test_index, test_in):
                    visualizeAttn(t_in.detach()[0], test_in_len[0],
                                  [j[global_index_t] for j in attn_weights_t],
                                  modelID, batch_count_n[global_index_t],
                                  'test_' + t_idx.split(',')[0])
                    global_index_t += 1

        total_loss_t /= (num + 1)
        writeLoss(total_loss_t, 'test')
        print('    TEST loss=%.3f, time=%.3f' %
              (total_loss_t, time.time() - start_t))
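
# Hedged sketch: the `crit` / `log_softmax` pair used in the LABEL_SMOOTH
# branches above is not defined in this listing. A common choice (an
# assumption, not this project's verbatim code) is KL divergence against a
# smoothed one-hot distribution that ignores PAD positions:
def make_label_smoothing_criterion(n_classes, pad_idx, smoothing=0.1):
    kl = torch.nn.KLDivLoss(reduction='sum')

    def crit(log_probs, target):
        # log_probs: (N, n_classes) log-probabilities; target: (N,) ids.
        true_dist = torch.full_like(log_probs, smoothing / (n_classes - 2))
        true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
        true_dist[:, pad_idx] = 0          # no probability mass on PAD
        keep = target != pad_idx
        true_dist[~keep] = 0               # ignore rows whose target is PAD
        return kl(log_probs, true_dist) / keep.sum().clamp(min=1)

    return crit

# Usage matching the call sites above (assumed wiring):
#     log_softmax = torch.nn.LogSoftmax(dim=-1)
#     crit = make_label_smoothing_criterion(vocab_size, tokens['PAD_TOKEN'])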

# Example 3
def main(train_loader, valid_loader, test_loader):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP,
                      FLIP).cuda()
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).cuda()
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).cuda()
    if CurriculumModelID > 0:
        model_file = 'save_weights/seq2seq-' + str(
            CurriculumModelID) + '.model'
        #model_file = 'save_weights/words/seq2seq-' + str(CurriculumModelID) +'.model'
        print('Loading ' + model_file)
        seq2seq.load_state_dict(torch.load(model_file))  #load
    opt = optim.Adam(seq2seq.parameters(), lr=learning_rate)
    #opt = optim.SGD(seq2seq.parameters(), lr=learning_rate, momentum=0.9)
    #opt = optim.RMSprop(seq2seq.parameters(), lr=learning_rate, momentum=0.9)

    #scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=1)
    scheduler = optim.lr_scheduler.MultiStepLR(opt,
                                               milestones=lr_milestone,
                                               gamma=lr_gamma)
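    # With hypothetical values lr_milestone=[20, 40] and lr_gamma=0.5, the
    # learning rate would be halved at epochs 20 and 40 (actual values come
    # from the project's config).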
    epochs = 5000000
    if EARLY_STOP_EPOCH is not None:
        min_loss = 1e3
        min_loss_index = 0
        min_loss_count = 0

    if CurriculumModelID > 0 and WORD_LEVEL:
        start_epoch = CurriculumModelID + 1
        # Fast-forward the scheduler so the resumed run continues with the
        # learning rate it would have reached at this epoch.
        for i in range(start_epoch):
            scheduler.step()
    else:
        start_epoch = 0

    for epoch in range(start_epoch, epochs):
        # Pre-PyTorch-1.1 convention: step the scheduler at the start of
        # each epoch, before the optimizer updates.
        scheduler.step()
        lr = scheduler.get_lr()[0]
        teacher_rate = teacher_force_func(epoch) if TEACHER_FORCING else False
        start = time.time()
        loss = train(train_loader, seq2seq, opt, teacher_rate, epoch)
        writeLoss(loss, 'train')
        print('epoch %d/%d, loss=%.3f, lr=%.8f, teacher_rate=%.3f, time=%.3f' %
              (epoch, epochs, loss, lr, teacher_rate, time.time() - start))

        if epoch % MODEL_SAVE_EPOCH == 0:
            folder_weights = 'save_weights'
            if not os.path.exists(folder_weights):
                os.makedirs(folder_weights)
            torch.save(seq2seq.state_dict(),
                       folder_weights + '/seq2seq-%d.model' % epoch)

        start_v = time.time()
        loss_v = valid(valid_loader, seq2seq, epoch)
        writeLoss(loss_v, 'valid')
        print('  Valid loss=%.3f, time=%.3f' % (loss_v, time.time() - start_v))

        if EARLY_STOP_EPOCH is not None:
            # Early stopping tracks validation CER computed by an external
            # scoring script, not the cross-entropy loss above.
            gt = 'RWTH_partition/RWTH.iam_word_gt_final.valid.thresh'
            decoded = 'pred_logs/valid_predict_seq.' + str(epoch) + '.log'
            res_cer = sub.Popen(['./tasas_cer.sh', gt, decoded],
                                stdout=sub.PIPE)
            res_cer = res_cer.stdout.read().decode('utf8')
            loss_v = float(res_cer) / 100
            if loss_v < min_loss:
                min_loss = loss_v
                min_loss_index = epoch
                min_loss_count = 0
            else:
                min_loss_count += 1
            if min_loss_count >= EARLY_STOP_EPOCH:
                print('Early Stopping at: %d. Best epoch is: %d' %
                      (epoch, min_loss_index))
                return min_loss_index
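
# Hedged sketch: `teacher_force_func` above maps the epoch to a teacher-
# forcing ratio; its definition is not shown in this listing. A typical
# exponential decay (an assumption, with a hypothetical name) would be:
def teacher_force_func_example(epoch, decay=0.99):
    # Probability of feeding the ground-truth token rather than the
    # model's previous prediction at each decoding step.
    return decay ** epoch
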
def main(all_data_loader_func):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP,
                      FLIP).cuda()
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).cuda()
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).cuda()
    if CurriculumModelID > 0:
        model_file = 'save_weights/seq2seq-' + str(
            CurriculumModelID) + '.model'
        print('Loading ' + model_file)
        pretrain_dict = torch.load(model_file)
        seq2seq_dict = seq2seq.state_dict()
        # Partial load: keep only matching weight names, as in test() above.
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items() if k in seq2seq_dict
        }
        seq2seq_dict.update(pretrain_dict)
        seq2seq.load_state_dict(seq2seq_dict)
    opt = optim.Adam(seq2seq.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.MultiStepLR(opt,
                                               milestones=lr_milestone,
                                               gamma=lr_gamma)
    epochs = 5000
    if EARLY_STOP_EPOCH is not None:
        min_loss = 1e3
        min_loss_index = 0
        min_loss_count = 0

    if CurriculumModelID > 0:
        start_epoch = CurriculumModelID + 1
        # Fast-forward the scheduler to the resumed epoch's learning rate.
        for i in range(start_epoch):
            scheduler.step()
    else:
        start_epoch = 0

    for epoch in range(start_epoch, epochs):
        # Each epoch, randomly re-sample the training set so it stays
        # balanced with the unlabeled test set.
        train_loader, valid_loader, test_loader = all_data_loader_func()
        scheduler.step()  # pre-PyTorch-1.1 convention: step at epoch start
        lr = scheduler.get_lr()[0]
        teacher_rate = teacher_force_func(epoch) if TEACHER_FORCING else False
        start = time.time()

        lambd = return_lambda(epoch)

        loss, loss_d = train(train_loader, seq2seq, opt, teacher_rate, epoch,
                             lambd)
        writeLoss(loss, 'train')
        writeLoss(loss_d, 'domain_train')
        print(
            'epoch %d/%d, loss=%.3f, domain_loss=%.3f, lr=%.6f, teacher_rate=%.3f, lambda_pau=%.3f, time=%.3f'
            % (epoch, epochs, loss, loss_d, lr, teacher_rate, lambd,
               time.time() - start))

        if epoch % MODEL_SAVE_EPOCH == 0:
            folder_weights = 'save_weights'
            if not os.path.exists(folder_weights):
                os.makedirs(folder_weights)
            torch.save(seq2seq.state_dict(),
                       folder_weights + '/seq2seq-%d.model' % epoch)

        start_v = time.time()
        loss_v, loss_v_d = valid(valid_loader, seq2seq, epoch)
        writeLoss(loss_v, 'valid')
        writeLoss(loss_v_d, 'domain_valid')
        print('      Valid loss=%.3f, domain_loss=%.3f, time=%.3f' %
              (loss_v, loss_v_d, time.time() - start_v))

        test(test_loader, epoch, False)  # evaluate each epoch, no attention plots

        if EARLY_STOP_EPOCH is not None:
            # As above, early stopping tracks validation CER from the
            # external scoring script rather than the loss.
            gt = loadData.GT_TE
            decoded = 'pred_logs/valid_predict_seq.' + str(epoch) + '.log'
            res_cer = sub.Popen(['./tasas_cer.sh', gt, decoded],
                                stdout=sub.PIPE)
            res_cer = res_cer.stdout.read().decode('utf8')
            loss_v = float(res_cer) / 100
            if loss_v < min_loss:
                min_loss = loss_v
                min_loss_index = epoch
                min_loss_count = 0
            else:
                min_loss_count += 1
            if min_loss_count >= EARLY_STOP_EPOCH:
                print('Early Stopping at: %d. Best epoch is: %d' %
                      (epoch, min_loss_index))
                return min_loss_index
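
# Hedged sketch: `return_lambda` above schedules the gradient-reversal
# weight lambda for the domain-adversarial branch. Many DANN-style
# implementations follow Ganin & Lempitsky (2015); the project's actual
# schedule may differ (hypothetical name used here):
import math

def return_lambda_example(epoch, total_epochs=5000):
    p = epoch / float(total_epochs)
    return 2. / (1. + math.exp(-10. * p)) - 1.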