Example #1
def evaluate():

    # DataLoader
    loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TEST',
        transform=transforms.Compose([normalize])),
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=num_workers,
                                         pin_memory=True)

    wrong_tot = 0
    references = list()
    hypotheses = list()

    for i, (imgs, caps, caplens, allcaps) in enumerate(
            tqdm(loader,
                 desc="EVALUATING AT BEAM SIZE " + str(beam_size),
                 ascii=True)):

        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)
        batch_predictions, batch_references, wrong = evaluate_batch(
            encoder, decoder, imgs, allcaps)
        wrong_tot += wrong
        assert len(batch_references) == len(batch_predictions)
        references.extend(batch_references)
        hypotheses.extend(batch_predictions)

    bleu4 = corpus_bleu(references, hypotheses)
    print("\nBLEU-4 score @ beam size of %d is %.4f." % (beam_size, bleu4))
Example #2
if __name__ == "__main__":
    # dataset and dataloader config
    dataset_guide_path = './data_config/ValDatasetGuide.json'
    word_map_path = './data_config/SelectGloveWordMap.json'

    # net checkpoint
    checkpoint = './checkpoints/checkpoint_best.pth'

    device = torch.device("cuda:0")
    num_workers = 4
    batch_size = 16
    pin_memory = False

    # Build dataset and dataloader
    dataset = CaptionDataset(dataset_guide_path,
                             word_map_path,
                             remove_invalid=True,
                             test_only_n_samples=None)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            collate_fn=dataset.collate_fn,
                            pin_memory=pin_memory)

    checkpoint = torch.load(checkpoint)
    net = checkpoint['model']

    with open(word_map_path, 'r') as j:
        word_map = json.load(j)
    beam_searcher = models.BeamSearcher(word_map['<bos>'],
                                        word_map['<eos>'],
Example #3
def main():
    """Training and validation."""

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(
            attention_dim=attention_dim,
            embed_dim=emb_dim,
            decoder_dim=decoder_dim,
            vocab_size=len(word_map),
            dropout=dropout
        )
        decoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, decoder.parameters()),
            lr=decoder_lr
        )
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr
        ) if fine_tune_encoder else None
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=encoder_lr
            )

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
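    # (These are the standard ImageNet channel statistics used by torchvision's pretrained CNNs.)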
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True
    )

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(
            train_loader=train_loader,
            encoder=encoder,
            decoder=decoder,
            criterion=criterion,
            encoder_optimizer=encoder_optimizer,
            decoder_optimizer=decoder_optimizer,
            epoch=epoch
        )

        # One epoch's validation
        recent_bleu4 = validate(
            val_loader=val_loader,
            encoder=encoder,
            decoder=decoder,
            criterion=criterion
        )

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(
            data_name, epoch, epochs_since_improvement, encoder, decoder,
            encoder_optimizer, decoder_optimizer, recent_bleu4, is_best
        )
Example #4
from datasets import CaptionDataset
from tqdm import tqdm
import os
import pickle
import json
import torch
from torchvision import transforms
"""train = CaptionDataset('dataset/output',
                         'coco_5_cap_per_img_5_min_word_freq',
                         'TRAIN',
                         None)

val = CaptionDataset('dataset/output',
                         'coco_5_cap_per_img_5_min_word_freq',
                         'VAL',
                         None)"""

test = CaptionDataset('dataset/output', 'coco_5_cap_per_img_5_min_word_freq',
                      'TEST', None)

idx2id = {}
with open(os.path.join('dataset', 'output', 'TEST_ids.txt'), 'r') as f:
    for i, line in enumerate(f):
        values = line.rstrip().split()
        idx2id[i] = int(values[0])
"""id2idx = {value : key for (key, value) in idx2id.items()}"""

captions = pickle.load(open('captionsOriginal.pkl', 'rb'))

word_map_file = 'dataset/output/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json'  # word map, ensure it's the same the data was encoded with and the model was trained with
# Load word map (word2ix)
with open(word_map_file, 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}
workers = 0  # Workers for loading the dataset. Needs to be 0 on Windows; change to a suitable value on other operating systems.
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 1800  # dimension of decoder RNN
dropout = 0.5  # Dropout rate
decoder_lr = 2 * 1e-3  # Decoder learning rate
numepochs = 100  # Number of epochs
load = False  # Set this to False when you don't want to load a checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

batch_size = 32  # assumed value; not defined elsewhere in this snippet
train_loader = torch.utils.data.DataLoader(
    CaptionDataset('dataset/output', 'coco_5_cap_per_img_5_min_word_freq', 'TRAIN',
                   transform=transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
# Note that the resize is already done in the encoder, so no need to do it here again
if load:
    # Load the model from checkpoints
    checkpoints = torch.load('checkpoint_d')
    encoder = checkpoints['encoder']
    decoder = checkpoints['decoder']
    decoder_optimizer = checkpoints['decoder_optimizer']
    epoch = checkpoints['epoch']
    decoder_lr = decoder_lr * pow(0.8, epoch // 5)
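    # e.g. resuming at epoch 12 gives decoder_lr * 0.8 ** (12 // 5) = decoder_lr * 0.64,
    # i.e. the rate is decayed by a factor of 0.8 once per 5 completed epochs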
    for param_group in decoder_optimizer.param_groups:
        param_group['lr'] = decoder_lr
else:
    epoch = 0
    encoder = Encoder()
Example #6
    # dataset and dataloader config
    train_dataset_guide_path = './data_config/TrainDatasetGuide.json'
    valid_dataset_guide_path = './data_config/ValDatasetGuide.json'
    word_map_path            = './data_config/SelectGloveWordMap.json'

    # whether to use glove
    glove_tensor_path        = './data_config/SelectGloveTensor.pth'
    use_glove                = True

    train_num_workers = 4
    valid_num_workers = 4
    train_batch_size  = 32
    valid_batch_size  = 32

    # Build dataset and dataloader
    train_dataset = CaptionDataset(train_dataset_guide_path, word_map_path, remove_invalid=True)
    valid_dataset = CaptionDataset(valid_dataset_guide_path, word_map_path, remove_invalid=True)
    vocab_size    = train_dataset.vocab_size

    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=train_num_workers, collate_fn=train_dataset.collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=valid_batch_size, shuffle=False, num_workers=valid_num_workers, collate_fn=valid_dataset.collate_fn)

    # ############################# Debug ########################################
    # # Build dataset and dataloader
    # train_dataset = CaptionDataset(train_dataset_guide_path, word_map_path, remove_invalid=True, test_only_n_samples=512)
    # valid_dataset = CaptionDataset(train_dataset_guide_path, word_map_path, remove_invalid=True, test_only_n_samples=512)
    # vocab_size    = train_dataset.vocab_size

    # train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=False, num_workers=train_num_workers, collate_fn=train_dataset.collate_fn)
    # valid_dataloader = DataLoader(valid_dataset, batch_size=valid_batch_size, shuffle=False, num_workers=valid_num_workers, collate_fn=valid_dataset.collate_fn)
    # ############################################################################
Example #7
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize model / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(hidden_size=hidden_size,
                                       vocab_size=len(word_map),
                                       attention_dim=attention_dim,
                                       embed_size=emb_dim,
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=decoder.parameters(),
                                             lr=decoder_lr,
                                             betas=(0.8, 0.999))
        encoder = Encoder(hidden_size=hidden_size,
                          embed_size=emb_dim,
                          dropout=dropout)
        # Whether to fine-tune the encoder
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr,
            betas=(0.8, 0.999)) if fine_tune_encoder else None

    else:
        # Load checkpoint
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            # If fine-tuning is being enabled now, an encoder optimizer needs to be created
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr,
                                                 betas=(0.8, 0.999))

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])  #ImageNet
    # pin_memory=True keeps batches in page-locked (pinned) memory instead of letting them be swapped out
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        if epoch > 15:
            adjust_learning_rate(decoder_optimizer, epoch)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, epoch)

        # Early stopping if the validation score does not improve for 6 consecutive epochs
        if epochs_since_improvement == 6:
            break

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              vocab_size=len(word_map))

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
Example #8
## Read model
model_ckpt = torch.load(checkpoint_path)
decoder = model_ckpt["decoder"].to(device)
encoder = model_ckpt["encoder"].to(device)
decoder.eval()
encoder.eval()
## Read word_map
with open(word_map_path, 'r') as j:
    word_map = json.load(j)
inv_word_map = {v: k for k, v in word_map.items()}
vocab_size = len(word_map)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
val_set = CaptionDataset(data_folder=data_folder,
                         data_name=dataset_name,
                         split="VAL",
                         transform=transforms.Compose([normalize]))
val_loader = DataLoader(val_set,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=workers,
                        pin_memory=True)


# Define validation process
def validation(beam_size):
    """
    Evaluation Process
    
    :param beam_size: beam size at which to generate captions for evaluation
    :return: BLEU-4 score
Example #9
from datasets import CaptionDataset
from tqdm import tqdm
import os
import pickle
import json
from utils import get_word_synonyms
from nltk import stem
from pycocotools.coco import COCO
import numpy as np
import matplotlib.pyplot as plt
import re

train = CaptionDataset('dataset/output',
                       'coco_5_cap_per_img_5_min_word_freq',
                       'TRAIN',
                       None,
                       minimal=True)

train_annotations = COCO(
    os.path.join('dataset', 'annotations', 'instances_train2014.json'))
val_annotations = COCO(
    os.path.join('dataset', 'annotations', 'instances_val2014.json'))
idx2dataset = {}
with open(os.path.join('dataset', 'output', 'TRAIN_ids.txt'), 'r') as f:
    for i, line in enumerate(f):
        values = line.rstrip().split()
        idx2dataset[i] = values[1]

synonyms = get_word_synonyms()
stemmer = stem.snowball.PorterStemmer()
Example #10
def evaluate(beam_size):
    """Evaluation

    Args:
        beam_size: beam size at which to generate captions for evaluation
    
    Returns:
        BLEU-4 score
    """
    # DataLoader
    loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TEST', transform=transforms.Compose([normalize])),
        batch_size=1, shuffle=True, num_workers=1, pin_memory=True
    )

    # TODO: Batched Beam Search
    # The beam search below is not batched, so do not use a batch_size greater than 1 - IMPORTANT!

    # Lists to store references (true captions) and hypotheses (predictions) for each image
    # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
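    # Concretely (with made-up word indices), for two images with two and one reference captions:
    #   references = [[[5, 8, 9], [5, 12, 9]], [[7, 3, 4]]]
    #   hypotheses = [[5, 8, 9], [7, 3, 4]]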
    references = list()
    hypotheses = list()

    # For each image
    for i, (image, caps, caplens, allcaps) in enumerate(
        tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))
    ):

        k = beam_size

        # Move to GPU device, if available
        image = image.to(device)  # (1, 3, 256, 256)

        # Encode
        encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(3)

        # Flatten encoding
        encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # We'll treat the problem as having a batch size of k
        encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

        # Tensor to store top k previous words at each step; now they're just <start>
        k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)

        # Tensor to store top k sequences; now they're just <start>
        seqs = k_prev_words  # (k, 1)

        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

        # Lists to store completed sequences and scores
        complete_seqs = list()
        complete_seqs_scores = list()

        # Start decoding
        step = 1
        h, c = decoder.init_hidden_state(encoder_out)

        # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
        while True:

            embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

            awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)

            gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
            awe = gate * awe

            h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

            scores = decoder.fc(h)  # (s, vocab_size)
            scores = F.log_softmax(scores, dim=1)

            # Add
            scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)
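            # (top_k_scores holds the cumulative log-probability of each partial sequence,
            # so the sum above scores every possible one-word extension of every beam)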

            # For the first step, all k points will have the same scores (since same k previous words, h, c)
            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
            else:
                # Unroll and find top scores, and their unrolled indices
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

            # Convert unrolled indices to actual indices of scores
            prev_word_inds = top_k_words // vocab_size  # (s); integer division to recover beam indices
            next_word_inds = top_k_words % vocab_size  # (s)
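            # e.g. if vocab_size were 1000, a flattened index of 2345 would map to
            # beam 2345 // 1000 = 2, continuing with word index 2345 % 1000 = 345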

            # Add new words to sequences
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

            # Which sequences are incomplete (didn't reach <end>)?
            incomplete_inds = [
                ind for ind, next_word in enumerate(next_word_inds)
                if next_word != word_map['<end>']
            ]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            # Set aside complete sequences
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds])
            k -= len(complete_inds)  # reduce beam length accordingly

            # Proceed with incomplete sequences
            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > 50:
                break
            step += 1

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[i]

        # References
        img_caps = allcaps[0].tolist()
        img_captions = list(map(
            lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}], img_caps
        ))  # remove <start>, <end> and pads
        references.append(img_captions)

        # Hypotheses
        hypotheses.append([w for w in seq if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}])

        assert len(references) == len(hypotheses)

    # Calculate BLEU-4 scores
    bleu4 = corpus_bleu(references, hypotheses)

    return bleu4
def evaluate(args):
    r"""
    Evaluation

    :param args: parsed arguments, including beam_size (beam size at which to generate captions)
    :return: dict of metric scores computed by NLGEval
    """

    # DataLoader
    loader = DataLoader(CaptionDataset(args.data_folder,
                                       args.data_name,
                                       'TEST',
                                       transform=transforms.Compose(
                                           [normalize])),
                        batch_size=1,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

    need_tag = args.type in scn_based_model

    # Load word map (word2ix)
    with open(args.word_map, 'r') as j:
        word_map = json.load(j)
    rev_word_map = {v: k for k, v in word_map.items()}

    # Load tag map (word2ix)
    with open(args.tag_map, 'r') as j:
        tag_map = json.load(j)

    vocab_size = len(word_map)

    if need_tag:
        print('Load tagger checkpoint..')
        from models.encoders.tagger import EncoderTagger
        tagger_checkpoint = torch.load(
            args.model_tagger, map_location=lambda storage, loc: storage)

        print('Load tagger encoder...')
        encoder_tagger = EncoderTagger()
        encoder_tagger.load_state_dict(tagger_checkpoint['model_state_dict'])
        encoder_tagger = encoder_tagger.to(device)
        encoder_tagger.eval()

    print('Load caption checkpoint')
    caption_checkpoint = torch.load(args.model_caption,
                                    map_location=lambda storage, loc: storage)

    print('Load caption encoder..')
    from models.encoders.caption import EncoderCaption
    encoder_caption = EncoderCaption()

    encoder_caption.load_state_dict(
        caption_checkpoint['encoder_model_state_dict'])
    encoder_caption = encoder_caption.to(device)
    encoder_caption.eval()

    print('Load caption decoder..')
    decoder_caption = load_decoder(
        model_type=args.type,
        checkpoint=caption_checkpoint['decoder_model_state_dict'],
        vocab_size=vocab_size)
    decoder_caption.eval()

    print('=========================')

    # Preparing result
    references_temp = list()
    hypotheses = list()

    # For each image
    for i, (image, _, _, allcaps) in enumerate(
            tqdm(loader,
                 desc="EVALUATING AT BEAM SIZE " + str(args.beam_size))):

        k = args.beam_size

        # Move to GPU device, if available
        image = image.to(device)  # (1, 3, 256, 256)

        # Encode (1, enc_image_size, enc_image_size, encoder_dim)
        encoder_out = encoder_caption(image)

        if need_tag:
            # Tag (1, semantic_dim); the tagger encoder is only built for scn-based models
            tag_out = encoder_tagger(image)
            result = decoder_caption.sample(args.beam_size, word_map,
                                            encoder_out,
                                            tag_out)  # for scn-based model
        else:
            result = decoder_caption.sample(args.beam_size, word_map,
                                            encoder_out)

        try:
            seq, _ = result  # for attention-based model
        except:
            seq = result  # for scn only-based model

        # References
        img_caps = allcaps[0].tolist()
        img_captions = list(
            map(
                lambda c: ' '.join([
                    rev_word_map[w] for w in c if w not in {
                        word_map[start_token], word_map[end_token], word_map[
                            padding_token]
                    }
                ]), img_caps))  # remove <start>, <end> and pads
        references_temp.append(img_captions)

        # Hypotheses
        hypotheses.append(' '.join([
            rev_word_map[w] for w in seq if w not in {
                word_map[start_token], word_map[end_token],
                word_map[padding_token]
            }
        ]))

        assert len(references_temp) == len(hypotheses)

    # Calculate Metric scores

    # Modify array so NLGEval can read it
    references = [[] for x in range(len(references_temp[0]))]

    for refs in references_temp:
        for i in range(len(refs)):
            references[i].append(refs[i])
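    # e.g. references_temp = [[imgA_ref1, imgA_ref2], [imgB_ref1, imgB_ref2]] becomes
    #      references      = [[imgA_ref1, imgB_ref1], [imgA_ref2, imgB_ref2]]
    # (one list per reference position, the layout NLGEval's ref_list argument expects here)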

    current_time = str(round(time.time()))  # as a string so it can be used in the paths below

    os.makedirs(os.path.join('evaluation', current_time), exist_ok=True)

    # Creating instance of NLGEval
    n = NLGEval(no_skipthoughts=True, no_glove=True)

    with open(
            os.path.join(
                'evaluation', current_time,
                '{}_beam_{}_references.json'.format(args.type,
                                                    args.beam_size)),
            'w') as f:
        json.dump(references, f)
        f.close()

    with open(
            os.path.join(
                'evaluation', current_time,
                '{}_beam_{}_hypotheses.json'.format(args.type,
                                                    args.beam_size)),
            'w') as f:
        json.dump(hypotheses, f)
        f.close()

    scores = n.compute_metrics(ref_list=references, hyp_list=hypotheses)

    with open(
            os.path.join(
                'evaluation', current_time,
                '{}_beam_{}_scores.json'.format(args.type, args.beam_size)),
            'w') as f:
        json.dump(scores, f)
        f.close()

    return scores
Example #12
def main():
    """
    Describe main process including train and validation.
    """

    global start_epoch, checkpoint, fine_tune_encoder, best_bleu4, epochs_since_improvement, word_map

    # Read word map
    word_map_path = os.path.join(data_folder,
                                 'WORDMAP_' + dataset_name + ".json")
    with open(word_map_path, 'r') as j:
        word_map = json.load(j)

    # Set checkpoint or read from checkpoint
    if checkpoint is None:  # No pretrained model, set model from beginning
        decoder = Decoder(embed_dim=embed_dim,
                          decoder_dim=decoder_dim,
                          vocab_size=len(word_map),
                          dropout=dropout_rate)
        decoder_param = list(filter(lambda p: p.requires_grad, decoder.parameters()))  # materialize: iterated below and then passed to Adam
        for param in decoder_param:
            tensor0 = param.data
            dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
            param.data = tensor0 / np.sqrt(float(num_nodes))
        decoder_optimizer = optim.Adam(params=decoder_param, lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_param = list(filter(lambda p: p.requires_grad, encoder.parameters()))  # materialize for the same reason as decoder_param
        if fine_tune_encoder:
            for param in encoder_param:
                tensor0 = param.data
                dist.all_reduce(tensor0, op=dist.ReduceOp.SUM)
                param.data = tensor0 / np.sqrt(float(num_nodes))
        encoder_optimizer = optim.Adam(
            params=encoder_param, lr=encoder_lr) if fine_tune_encoder else None
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        # Optimizer state is not restored from the checkpoint; rebuild the decoder optimizer for the loaded weights
        decoder_optimizer = optim.Adam(params=filter(lambda p: p.requires_grad,
                                                     decoder.parameters()),
                                       lr=decoder_lr)
        encoder = checkpoint['encoder']
        encoder_optimizer = None  # recreated below if the encoder is being fine-tuned
        if fine_tune_encoder and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    decoder = decoder.to(device)
    encoder = encoder.to(device)
    criterion = nn.CrossEntropyLoss()

    # Data loader
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_set = CaptionDataset(data_folder=h5data_folder,
                               data_name=dataset_name,
                               split="TRAIN",
                               transform=transforms.Compose([normalize]))
    val_set = CaptionDataset(data_folder=h5data_folder,
                             data_name=dataset_name,
                             split="VAL",
                             transform=transforms.Compose([normalize]))
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=workers,
                            pin_memory=True)

    total_start_time = datetime.datetime.now()
    print("Start the 1st epoch at: ", total_start_time)

    # Epoch
    for epoch in range(start_epoch, num_epochs):
        # Pre-check by epochs_since_improvement
        if epochs_since_improvement == 20:  # If there are 20 epochs that no improvements are achieved
            break
        if epochs_since_improvement % 8 == 0 and epochs_since_improvement > 0:
            adjust_learning_rate(decoder_optimizer)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer)

        # For every batch
        batch_time = AverageMeter()  # forward prop. + back prop. time
        data_time = AverageMeter()  # data loading time
        losses = AverageMeter()  # loss (per word decoded)
        top5accs = AverageMeter()  # top5 accuracy
        decoder.train()
        encoder.train()

        start = time.time()
        start_time = datetime.datetime.now(
        )  # Initialize start time for this epoch

        # TRAIN
        for j, (images, captions, caplens) in enumerate(train_loader):
            if fine_tune_encoder and (epoch - start_epoch > 0 or j > 10):
                for group in encoder_optimizer.param_groups:
                    for p in group['params']:
                        state = encoder_optimizer.state[p]
                        if (state['step'] >= 1024):
                            state['step'] = 1000

            if (epoch - start_epoch > 0 or j > 10):
                for group in decoder_optimizer.param_groups:
                    for p in group['params']:
                        state = decoder_optimizer.state[p]
                        if (state['step'] >= 1024):
                            state['step'] = 1000

            data_time.update(time.time() - start)

            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)
            # Forward
            enc_images = encoder(images)
            predictions, enc_captions, dec_lengths, sort_ind = decoder(
                enc_images, captions, caplens)

            # Define target as original captions excluding <start>
            target = enc_captions[:, 1:]  # (batch_size, max_caption_length-1)
            target = pack_padded_sequence(
                target, dec_lengths, batch_first=True
            ).data  # Drop all padding and concatenate the remaining timesteps
            predictions = pack_padded_sequence(
                predictions, dec_lengths,
                batch_first=True).data  # (sum(dec_lengths), vocab_size)

            loss = criterion(predictions, target)

            # Backward
            decoder_optimizer.zero_grad()
            if encoder_optimizer is not None:
                encoder_optimizer.zero_grad()
            loss.backward()
            ## Clip gradients
            if grad_clip is not None:
                clip_gradient(decoder_optimizer, grad_clip)
                if encoder_optimizer is not None:
                    clip_gradient(encoder_optimizer, grad_clip)
            ## Update
            decoder_optimizer.step()
            if encoder_optimizer is not None:
                encoder_optimizer.step()

            # Update metrics (AverageMeter)
            acc_top5 = compute_accuracy(predictions, target, k=5)
            top5accs.update(acc_top5, sum(dec_lengths))
            losses.update(loss.item(), sum(dec_lengths))
            batch_time.update(time.time() - start)

            # Print current status
            if (j + 1) % print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Current Batch Time: {batch_time.val:.3f} (Average: {batch_time.avg:.3f})\t'
                    'Current Data Load Time: {data_time.val:.3f} (Average: {data_time.avg:.3f})\t'
                    'Current Loss: {loss.val:.4f} (Average: {loss.avg:.4f})\t'
                    'Current Top-5 Accuracy: {top5.val:.3f} (Average: {top5.avg:.3f})'
                    .format(epoch + 1,
                            j + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            top5=top5accs))
                now_time = datetime.datetime.now()
                print("Epoch Training Time: ", now_time - start_time)
                print("Total Time: ", now_time - total_start_time)

            start = time.time()

        # VALIDATION
        decoder.eval()
        encoder.eval()

        batch_time = AverageMeter()  # forward prop. + back prop. time
        losses = AverageMeter()  # loss (per word decoded)
        top5accs = AverageMeter()  # top5 accuracy
        references = list(
        )  # references (true captions) for calculating BLEU-4 score
        hypotheses = list()  # hypotheses (predictions)

        start_time = datetime.datetime.now()

        for j, (images, captions, caplens, all_caps) in enumerate(val_loader):
            start = time.time()

            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)

            # Forward
            enc_images = encoder(images)
            predictions, enc_captions, dec_lengths, sort_ind = decoder(
                enc_images, captions, caplens)

            # Define target as original captions excluding <start>
            predictions_copy = predictions.clone()
            target = enc_captions[:, 1:]  # (batch_size, max_caption_length-1)
            target = pack_padded_sequence(
                target, dec_lengths, batch_first=True
            ).data  # Drop all padding and concatenate the remaining timesteps
            predictions = pack_padded_sequence(
                predictions, dec_lengths,
                batch_first=True).data  # (sum(dec_lengths), vocab_size)

            loss = criterion(predictions, target)

            # Update metrics (AverageMeter)
            acc_top5 = compute_accuracy(predictions, target, k=5)
            top5accs.update(acc_top5, sum(dec_lengths))
            losses.update(loss.item(), sum(dec_lengths))
            batch_time.update(time.time() - start)

            # Print current status
            if (j + 1) % print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch + 1,
                        j,
                        len(val_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top5=top5accs))
                now_time = datetime.datetime.now()
                print("Epoch Validation Time: ", now_time - start_time)
                print("Total Time: ", now_time - total_start_time)

            ## Store references (true captions) and hypotheses (predictions) for each image
            ## If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            ## references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # references
            all_caps = all_caps[sort_ind]
            for k in range(all_caps.shape[0]):
                img_caps = all_caps[k].tolist()
                img_captions = list(
                    map(
                        lambda c: [
                            w for w in c if w not in
                            {word_map["<start>"], word_map["<pad>"]}
                        ], img_caps))
                references.append(img_captions)

            # hypotheses
            _, preds = torch.max(predictions_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for i, p in enumerate(preds):
                temp_preds.append(preds[i][:dec_lengths[i]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)

            assert len(references) == len(hypotheses)

        ## Compute BLEU-4 Scores
        #recent_bleu4 = corpus_bleu(references, hypotheses, emulate_multibleu=True)
        recent_bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'
            .format(loss=losses, top5=top5accs, bleu=recent_bleu4))

        # CHECK IMPROVEMENT
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement))
        else:
            epochs_since_improvement = 0

        # SAVE CHECKPOINT
        save_checkpoint(dataset_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
        print("Epoch {}, cost time: {}\n".format(epoch + 1,
                                                 now_time - total_start_time))
Example #13
def beam_evaluate_trans(data_name, checkpoint_file, data_folder, beam_size,
                        outdir):
    """
    Evaluation
    :param data_name: name of the data files
    :param checkpoint_file: which checkpoint file to use
    :param data_folder: folder where data is stored
    :param beam_size: beam size at which to generate captions for evaluation
    :param outdir: place where the outputs are stored, as well as the checkpoint file
    :return: Official MSCOCO evaluator scores - bleu4, cider, rouge, meteor
    """
    global word_map
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model():
        # Load model using checkpoint file provided
        torch.nn.Module.dump_patches = True
        checkpoint = torch.load(os.path.join(outdir, checkpoint_file),
                                map_location=device)
        decoder = checkpoint['decoder']
        decoder = decoder.to(device)
        decoder.eval()
        return decoder

    def load_dictionary():
        # Load word map (word2ix) using data folder provided
        word_map_file = os.path.join(data_folder,
                                     'WORDMAP_' + data_name + '.json')
        with open(word_map_file, 'r') as j:
            word_map = json.load(j)
        rev_word_map = {v: k for k, v in word_map.items()}
        vocab_size = len(word_map)
        return word_map, rev_word_map, vocab_size

    decoder = load_model()
    word_map, rev_word_map, vocab_size = load_dictionary()

    # DataLoader
    loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder, data_name, 'TEST'),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=1,
                                         collate_fn=collate_fn,
                                         pin_memory=torch.cuda.is_available())

    # Lists to store references (true captions) and hypotheses (predictions) for each image
    # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
    references = list()
    hypotheses = list()

    # For each image
    for caption_idx, (image_features, caps, caplens, orig_caps) in enumerate(
            tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))):

        if caption_idx % 5 != 0:
            continue

        k = beam_size

        # Move to GPU device, if available
        image_features = image_features.to(device)  # (1, 36, 2048)
        image_features_mean = image_features.mean(1)
        image_features_mean = image_features_mean.expand(k, 2048)

        # Tensor to store top k previous words at each step; now they're just <start>
        k_prev_words = torch.tensor([[word_map['<start>']]] * k,
                                    dtype=torch.long).to(device)  # (k, 1)

        # Tensor to store top k sequences; now they're just <start>
        seqs = k_prev_words  # (k, 1)

        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

        # Lists to store completed sequences and scores
        complete_seqs = list()
        complete_seqs_scores = list()

        # Start decoding
        step = 1
        h1, c1 = decoder.init_hidden_state(k)  # (batch_size, decoder_dim)
        h2, c2 = decoder.init_hidden_state(k)

        # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
        while True:

            embeddings = decoder.embedding(k_prev_words).squeeze(
                1)  # (s, embed_dim)
            h1, c1 = decoder.top_down_attention(
                torch.cat([h2, image_features_mean, embeddings], dim=1),
                (h1, c1))  # (batch_size_t, decoder_dim)
            trans_obj = decoder.transformer_encoder(
                image_features.transpose(0, 1)).transpose(0, 1)
            attention_weighted_encoding = decoder.attention(trans_obj, h1)
            h2, c2 = decoder.language_model(
                torch.cat([attention_weighted_encoding, h1], dim=1), (h2, c2))
            scores = decoder.fc(h2)  # (s, vocab_size)
            scores = F.log_softmax(scores, dim=1)

            # Add
            scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

            # For the first step, all k points will have the same scores (since same k previous words, h, c)
            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True,
                                                           True)  # (s)
            else:
                # Unroll and find top scores, and their unrolled indices
                top_k_scores, top_k_words = scores.view(-1).topk(
                    k, 0, True, True)  # (s)

            # Convert unrolled indices to actual indices of scores
            prev_word_inds = top_k_words // vocab_size  # (s); integer division to recover beam indices
            next_word_inds = top_k_words % vocab_size  # (s)

            # Add new words to sequences
            seqs = torch.cat(
                [seqs[prev_word_inds],
                 next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

            # Which sequences are incomplete (didn't reach <end>)?
            incomplete_inds = [
                ind for ind, next_word in enumerate(next_word_inds)
                if next_word != word_map['<end>']
            ]
            complete_inds = list(
                set(range(len(next_word_inds))) - set(incomplete_inds))

            # Set aside complete sequences
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds])
            k -= len(complete_inds)  # reduce beam length accordingly

            # Proceed with incomplete sequences
            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            h1 = h1[prev_word_inds[incomplete_inds]]
            c1 = c1[prev_word_inds[incomplete_inds]]
            h2 = h2[prev_word_inds[incomplete_inds]]
            c2 = c2[prev_word_inds[incomplete_inds]]
            image_features_mean = image_features_mean[
                prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > 50:
                break
            step += 1

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[i]

        # References
        # img_caps = [' '.join(c) for c in orig_caps]
        img_caps = [c for c in orig_caps]
        references.append(img_caps)

        # Hypotheses
        hypothesis = ([
            rev_word_map[w] for w in seq if w not in
            {word_map['<start>'], word_map['<end>'], word_map['<pad>']}
        ])
        # hypothesis = ' '.join(hypothesis)
        hypotheses.append(hypothesis)
        assert len(references) == len(hypotheses)

    # Calculate scores
    # metrics_dict = nlgeval.compute_metrics(references, hypotheses)
    hypotheses_file = os.path.join(outdir, 'hypotheses',
                                   'TEST.Hypotheses.json')
    references_file = os.path.join(outdir, 'references',
                                   'TEST.References.json')
    create_captions_file(range(len(hypotheses)), hypotheses, hypotheses_file)
    create_captions_file(range(len(references)), references, references_file)
    coco = COCO(references_file)
    # add the predicted results to the object
    coco_results = coco.loadRes(hypotheses_file)
    # create the evaluation object with both the ground-truth and the predictions
    coco_eval = COCOEvalCap(coco, coco_results)
    # change to use the image ids in the results object, not those from the ground-truth
    coco_eval.params['image_id'] = coco_results.getImgIds()
    # run the evaluation
    coco_eval.evaluate(verbose=False,
                       metrics=['bleu', 'meteor', 'rouge', 'cider'])
    # Results contains: "Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4", "METEOR", "ROUGE_L", "CIDEr", "SPICE"
    results = coco_eval.eval
    return results
Example #14
def main():
    """
    Training and validation.
    """

    global word_map, word_map_inv, scene_graph

    # Read word map
    word_map_file = os.path.join(args.data_folder, 'WORDMAP_' + args.data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)
    # create inverse word map
    word_map_inv = {v: k for k, v in word_map.items()}

    # Initialize / load checkpoint
    if args.checkpoint is None:
        if args.architecture == 'bottomup_topdown':
            decoder = BUTDDecoder(attention_dim=args.attention_dim,
                                  embed_dim=args.emb_dim,
                                  decoder_dim=args.decoder_dim,
                                  vocab_size=len(word_map),
                                  dropout=args.dropout)
            scene_graph = False
        elif args.architecture == 'io':
            decoder = IODecoder(attention_dim=args.attention_dim,
                                embed_dim=args.emb_dim,
                                decoder_dim=args.decoder_dim,
                                vocab_size=len(word_map),
                                dropout=args.dropout,
                                use_obj_info=args.use_obj_info,
                                use_rel_info=args.use_rel_info,
                                k_update_steps=args.k_update_steps,
                                update_relations=args.update_relations)
            scene_graph = True
        elif args.architecture == 'transformer':
            decoder = TransDecoder(attention_dim=args.attention_dim,
                                   embed_dim=args.emb_dim,
                                   decoder_dim=args.decoder_dim,
                                   transformer_dim=args.transformer_dim,
                                   vocab_size=len(word_map),
                                   dropout=args.dropout,
                                   n_heads=args.num_heads,
                                   n_layers=args.num_layers)
            scene_graph = False
        else:
            exit('unknown architecture chosen')

        decoder_optimizer = torch.optim.Adamax(params=filter(lambda p: p.requires_grad, decoder.parameters()))
        tracking = {'eval': [], 'test': None}
        start_epoch = 0
        best_epoch = -1
        epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation
        best_stopping_score = 0.  # stopping_score right now
    else:
        checkpoint = torch.load(args.checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        args.stopping_metric = checkpoint['stopping_metric']
        best_stopping_score = checkpoint['metric_score']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        tracking = checkpoint['tracking']
        best_epoch = checkpoint['best_epoch']

    # Move to GPU, if available
    decoder = decoder.to(device)

    # Loss functions
    criterion_ce = nn.CrossEntropyLoss().to(device)
    criterion_dis = nn.MultiLabelMarginLoss().to(device)

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(CaptionDataset(args.data_folder, args.data_name, 'TRAIN',
                                                              scene_graph=scene_graph),
                                               batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(args.data_folder, args.data_name, 'VAL',
                                                            scene_graph=scene_graph),
                                             collate_fn=collate_fn,
                                             # use our specially designed collate function with valid/test only
                                             batch_size=1, shuffle=False,
                                             num_workers=args.workers, pin_memory=True)
    #    batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, args.epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after args.patience epochs without improvement
        if epochs_since_improvement == args.patience:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              decoder=decoder,
              criterion_ce=criterion_ce,
              criterion_dis=criterion_dis,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_results = validate(val_loader=val_loader,
                                  decoder=decoder,
                                  criterion_ce=criterion_ce,
                                  criterion_dis=criterion_dis,
                                  epoch=epoch)
        tracking['eval'] = recent_results
        recent_stopping_score = recent_results[args.stopping_metric]

        # Check if there was an improvement
        is_best = recent_stopping_score > best_stopping_score
        best_stopping_score = max(recent_stopping_score, best_stopping_score)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0
            best_epoch = epoch

        # Save checkpoint
        save_checkpoint(args.data_name, epoch, epochs_since_improvement, decoder, decoder_optimizer,
                        args.stopping_metric, best_stopping_score, tracking, is_best, args.outdir, best_epoch)

    # If needed, run a beam-search evaluation on the test set
    if args.test_at_end:
        checkpoint_file = 'BEST_' + str(best_epoch) + '_' + 'checkpoint_' + args.data_name + '.pth.tar'
        results = beam_evaluate_butd(args.data_name, checkpoint_file, args.data_folder, args.beam_size, args.outdir)
        tracking['test'] = results
    with open(os.path.join(args.outdir, 'TRACKING.'+args.data_name+'.pkl'), 'wb') as f:
        pickle.dump(tracking, f)