def eval_reconstruction_with_rouge(autoencoder, idx2word, text_criterion,
                                   imgseq_criterion, data_iter, device):
    """Evaluate a joint text / image-sequence autoencoder on ``data_iter``.

    The model is switched to eval mode for the duration of the call and is
    always switched back to train mode before returning, even if evaluation
    raises. The whole loop runs under ``torch.no_grad()`` so that no autograd
    graph is built for the forward passes or the loss computations.

    Args:
        autoencoder: Model invoked as
            ``autoencoder(text_feature, imgseq_feature)``; it must return
            ``(text_prob, imgseq_feature_hat)``.
        idx2word: Index-to-word mapping forwarded to
            ``util.transform_idx2word``.
        text_criterion: Loss called as
            ``text_criterion(text_prob.transpose(1, 2), text_feature)``.
        imgseq_criterion: Loss called as
            ``imgseq_criterion(imgseq_feature_hat, imgseq_feature)``.
        data_iter: Iterable yielding ``(text_batch, imgseq_batch)`` pairs.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_text_loss, avg_imgseq_loss, avg_loss, rouge_1, rouge_2)``
        where every value is averaged over the number of batches. If
        ``data_iter`` yields no batches, every element is ``nan``.
    """
    print("=================Eval======================")
    autoencoder.eval()
    step = 0
    avg_loss = 0.
    avg_text_loss = 0.
    avg_imgseq_loss = 0.
    rouge_1 = 0.
    rouge_2 = 0.
    try:
        # Gradients are never needed here; keeping the forward pass inside
        # no_grad avoids holding activations for a backward that never runs.
        with torch.no_grad():
            for text_batch, imgseq_batch in tqdm(data_iter):
                torch.cuda.empty_cache()
                text_feature = Variable(text_batch).to(device)
                imgseq_feature = Variable(imgseq_batch).to(device)
                text_prob, imgseq_feature_hat = autoencoder(
                    text_feature, imgseq_feature)

                # Highest-scoring index along dim 2 is taken as the
                # predicted token (presumably the vocabulary axis -- confirm).
                predict_index = torch.max(text_prob, 2)[1]
                original_sentences = [
                    util.transform_idx2word(sentence, idx2word=idx2word)
                    for sentence in text_feature.cpu().numpy()
                ]
                predict_sentences = [
                    util.transform_idx2word(sentence, idx2word=idx2word)
                    for sentence in predict_index.cpu().numpy()
                ]
                r1, r2 = calc_rouge(original_sentences, predict_sentences)
                # NOTE(review): dividing by the batch size assumes calc_rouge
                # returns per-batch sums -- confirm against calc_rouge.
                rouge_1 += r1 / len(text_batch)
                rouge_2 += r2 / len(text_batch)

                text_loss = text_criterion(text_prob.transpose(1, 2),
                                           text_feature)
                imgseq_loss = imgseq_criterion(imgseq_feature_hat,
                                               imgseq_feature)
                avg_text_loss += text_loss.item()
                avg_imgseq_loss += imgseq_loss.item()
                avg_loss += (text_loss + imgseq_loss).item()
                step += 1

                # Drop references now so the tensors can be released before
                # the next batch is allocated.
                del (text_feature, imgseq_feature, text_prob,
                     imgseq_feature_hat, predict_index, text_loss,
                     imgseq_loss)

        if step == 0:
            # No batches were produced: report NaN instead of dividing by 0.
            print("Warning: evaluation iterator produced no batches")
            nan = float('nan')
            avg_text_loss = avg_imgseq_loss = avg_loss = nan
            rouge_1 = rouge_2 = nan
        else:
            avg_text_loss = avg_text_loss / step
            avg_imgseq_loss = avg_imgseq_loss / step
            avg_loss = avg_loss / step
            rouge_1 = rouge_1 / step
            rouge_2 = rouge_2 / step
        print("===============================================================")
    finally:
        # Always hand the model back in training mode, even on failure.
        autoencoder.train()

    return avg_text_loss, avg_imgseq_loss, avg_loss, rouge_1, rouge_2
Beispiel #2
0
def eval_reconstruction_with_rouge(autoencoder, idx2word, criterion, data_iter,
                                   device):
    """Evaluate a text autoencoder on ``data_iter``.

    The model is switched to eval mode for the duration of the call and is
    always switched back to train mode before returning, even if evaluation
    raises. The whole loop runs under ``torch.no_grad()`` so that no autograd
    graph is built for the forward passes or the loss computation.

    Args:
        autoencoder: Model invoked as ``autoencoder(feature)``; its output is
            reduced with ``torch.max(prob, 2)`` to obtain predicted indices.
        idx2word: Index-to-word mapping forwarded to
            ``util.transform_idx2word``.
        criterion: Loss called as ``criterion(prob.transpose(1, 2), feature)``.
        data_iter: Iterable yielding input batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, rouge_1, rouge_2)`` where every value is averaged
        over the number of batches. If ``data_iter`` yields no batches, every
        element is ``nan``.
    """
    print("=================Eval======================")
    autoencoder.eval()
    step = 0
    avg_loss = 0.
    rouge_1 = 0.
    rouge_2 = 0.
    try:
        # Gradients are never needed here; keeping the forward pass inside
        # no_grad avoids holding activations for a backward that never runs.
        with torch.no_grad():
            for batch in tqdm(data_iter):
                torch.cuda.empty_cache()
                feature = Variable(batch).to(device)
                prob = autoencoder(feature)

                # Highest-scoring index along dim 2 is taken as the
                # predicted token (presumably the vocabulary axis -- confirm).
                predict_index = torch.max(prob, 2)[1]
                original_sentences = [
                    util.transform_idx2word(sentence, idx2word=idx2word)
                    for sentence in feature.cpu().numpy()
                ]
                predict_sentences = [
                    util.transform_idx2word(sentence, idx2word=idx2word)
                    for sentence in predict_index.cpu().numpy()
                ]
                r1, r2 = calc_rouge(original_sentences, predict_sentences)
                # NOTE(review): dividing by the batch size assumes calc_rouge
                # returns per-batch sums -- confirm against calc_rouge.
                rouge_1 += r1 / len(batch)
                rouge_2 += r2 / len(batch)

                loss = criterion(prob.transpose(1, 2), feature)
                avg_loss += loss.item()
                step += 1

                # Drop references now so the tensors can be released before
                # the next batch is allocated.
                del feature, prob, loss, predict_index

        if step == 0:
            # No batches were produced: report NaN instead of dividing by 0.
            print("Warning: evaluation iterator produced no batches")
            avg_loss = rouge_1 = rouge_2 = float('nan')
        else:
            avg_loss = avg_loss / step
            rouge_1 = rouge_1 / step
            rouge_2 = rouge_2 / step
        print("===============================================================")
    finally:
        # Always hand the model back in training mode, even on failure.
        autoencoder.train()

    return avg_loss, rouge_1, rouge_2
Beispiel #3
0
def train_reconstruction(args):
    """Train the convolutional text autoencoder and checkpoint every epoch.

    Reads ``word_embedding.p`` and ``word_idx.json`` from
    ``CONFIG.DATASET_PATH/<args.target_dataset>``, builds the encoder /
    decoder pair, optionally restores state from ``args.resume`` and then
    trains for epochs ``start_epoch .. args.epochs - 1``. After each epoch
    the validation set is scored with ``eval_reconstruction_with_rouge`` and
    a checkpoint is written through ``util.save_models``.

    Args:
        args: Namespace-like object. Attributes read here: ``gpu``,
            ``target_dataset``, ``batch_size``, ``shuffle``, ``filter_shape``,
            ``filter_size``, ``latent_size``, ``tau``, ``resume``,
            ``weight_decay``, ``half_cycle_interval``, ``lr``, ``lr_factor``,
            ``epochs`` and ``log_interval``. ``args.t3`` is set as a side
            effect before the arguments are recorded on the experiment.
    """
    device = torch.device(args.gpu)

    print("Loading embedding model...")
    dataset_dir = os.path.join(CONFIG.DATASET_PATH, args.target_dataset)
    # NOTE: pickle-based load; only point this at trusted dataset files.
    with open(os.path.join(dataset_dir, 'word_embedding.p'), "rb") as emb_file:
        embedding_model = cPickle.load(emb_file)
    with open(os.path.join(dataset_dir, 'word_idx.json'), "r",
              encoding='utf-8') as idx_file:
        word_idx = json.load(idx_file)
    print("Loading embedding model completed")

    # word_idx[0] is used as the idx -> word table, word_idx[1] as word -> idx.
    print("Loading dataset...")
    train_dataset, val_dataset = load_text_data(args, CONFIG,
                                                word2idx=word_idx[1])
    print("Loading dataset completed")
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=args.shuffle)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False)

    def _strided_length(length):
        # Output length of one layer with kernel ``filter_shape``; the "2"
        # is the stride size.
        return int(math.floor((length - args.filter_shape) / 2) + 1)

    # t1 = max_sentence_len + 2 * (args.filter_shape - 1)
    t1 = CONFIG.MAX_SENTENCE_LEN
    t2 = _strided_length(t1)
    t3 = _strided_length(t2)
    args.t3 = t3

    embedding = nn.Embedding.from_pretrained(
        torch.FloatTensor(embedding_model))
    text_encoder = text_model.ConvolutionEncoder(
        embedding, t3, args.filter_size, args.filter_shape, args.latent_size)
    text_decoder = text_model.DeconvolutionDecoder(
        embedding, args.tau, t3, args.filter_size, args.filter_shape,
        args.latent_size, device)

    if not args.resume:
        print("Start from initial")
        start_epoch = 0
    else:
        print("Restart from checkpoint")
        checkpoint_file = os.path.join(CONFIG.CHECKPOINT_PATH, args.resume)
        # Keep every restored tensor on the storage it was deserialized to.
        checkpoint = torch.load(checkpoint_file,
                                map_location=lambda storage, loc: storage)
        start_epoch = checkpoint['epoch']
        text_encoder.load_state_dict(checkpoint['text_encoder'])
        text_decoder.load_state_dict(checkpoint['text_decoder'])

    text_autoencoder = text_model.TextAutoencoder(text_encoder, text_decoder)
    criterion = nn.NLLLoss().to(device)
    text_autoencoder.to(device)

    # Base lr is 1. so that the LambdaLR factor is the effective rate.
    optimizer = AdamW(text_autoencoder.parameters(),
                      lr=1.,
                      weight_decay=args.weight_decay,
                      amsgrad=True)
    step_size = args.half_cycle_interval * len(train_loader)
    clr = cyclical_lr(step_size,
                      min_lr=args.lr,
                      max_lr=args.lr * args.lr_factor)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, [clr])
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    exp = Experiment("Text autoencoder " + str(args.latent_size),
                     capture_io=False)
    for arg_name, arg_value in vars(args).items():
        exp.param(arg_name, arg_value)

    try:
        text_autoencoder.train()

        for epoch in range(start_epoch, args.epochs):
            print("Epoch: {}".format(epoch))

            for batch_idx, batch in enumerate(train_loader):
                torch.cuda.empty_cache()
                inputs = Variable(batch).to(device)

                optimizer.zero_grad()
                outputs = text_autoencoder(inputs)
                loss = criterion(outputs.transpose(1, 2), inputs)
                loss.backward()
                optimizer.step()
                scheduler.step()

                samples_seen = batch_idx * args.batch_size
                if samples_seen % args.log_interval == 0:
                    # Show the first sample of the batch next to its
                    # reconstruction.
                    first_input = inputs[0]
                    first_output = outputs[0]
                    predicted = torch.max(first_output, 1)[1]
                    input_sentence = util.transform_idx2word(
                        first_input.detach().cpu().numpy(),
                        idx2word=word_idx[0])
                    predict_sentence = util.transform_idx2word(
                        predicted.detach().cpu().numpy(),
                        idx2word=word_idx[0])

                    timestamp = str(datetime.datetime.now())
                    current_lr = str(scheduler.get_lr())
                    print("Epoch: {} at {} lr: {}".format(
                        epoch, timestamp, current_lr))
                    print("Steps: {}".format(batch_idx))
                    print("Loss: {}".format(loss.detach().item()))
                    print("Input Sentence:")
                    print(input_sentence)
                    print("Output Sentence:")
                    print(predict_sentence)
                    del first_input, first_output, predicted

                # Release this step's tensors before the next batch.
                del inputs, outputs, loss

            epoch_stamp = str(datetime.datetime.now())
            epoch_lr = str(scheduler.get_lr())
            exp.log("\nEpoch: {} at {} lr: {}".format(
                epoch, epoch_stamp, epoch_lr))

            _avg_loss, _rouge_1, _rouge_2 = eval_reconstruction_with_rouge(
                text_autoencoder, word_idx[0], criterion, val_loader, device)
            exp.log("\nEvaluation - loss: {}  Rouge1: {} Rouge2: {}".format(
                _avg_loss, _rouge_1, _rouge_2))

            # NOTE(review): the 'Rouge1:' key carries a trailing colon; it is
            # kept verbatim so the checkpoint layout does not change.
            state = {
                'epoch': epoch + 1,
                'text_encoder': text_encoder.state_dict(),
                'text_decoder': text_decoder.state_dict(),
                'avg_loss': _avg_loss,
                'Rouge1:': _rouge_1,
                'Rouge2': _rouge_2,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }
            util.save_models(state, CONFIG.CHECKPOINT_PATH,
                             "text_autoencoder_" + str(args.latent_size))

        print("Finish!!!")

    finally:
        # Close the experiment whether training finished or raised.
        exp.end()