Exemple #1
0
    def __init__(self,
                 split,
                 transform=None,
                 word_encoder=None,
                 is_train=True,
                 caption_max_len=35):
        """
        Index every (image, description) pair of one dataset split.

        :param split: split, one of 'train', 'val', 'test'
        :param transform: image transform pipeline
        :param word_encoder: caption encoder; a new WordEncoder() is built when None
        :param is_train: stored on the instance as ``is_train``
        :param caption_max_len: stored on the instance as ``caption_max_len``
        """
        TextureDescriptionData.__init__(self, phid_format=None)
        self.transform = transform
        self.is_train = is_train
        self.caption_max_len = caption_max_len
        self.split = split
        assert self.split in ('train', 'val', 'test')

        self.word_encoder = word_encoder if word_encoder is not None else WordEncoder()

        # One entry per description: (image index within the split, description index).
        self.img_desc_ids = [
            (img_idx, desc_idx)
            for img_idx, name in enumerate(self.img_splits[split])
            for desc_idx in range(len(self.img_data_dict[name]['descriptions']))
        ]
Exemple #2
0
 def __init__(self, word_emb_dim, word_encoder=None):
     """Create the word-embedding table used for mean pooling.

     :param word_emb_dim: embedding size; also exposed as ``out_dim``
     :param word_encoder: vocabulary provider exposing ``word_list``; a new
         WordEncoder() is built when None
     """
     super(MeanEncoder, self).__init__()
     if word_encoder is None:
         word_encoder = WordEncoder()
     self.word_encoder = word_encoder
     vocab_size = len(word_encoder.word_list)
     self.embeds = nn.Embedding(num_embeddings=vocab_size,
                                embedding_dim=word_emb_dim)
     self.out_dim = word_emb_dim
Exemple #3
0
 def forward(self, sentences, device='cuda'):
     """
     Embed each sentence as the average of its ELMo token vectors.

     sentences: list[str], len of list: B
     device: where the ELMo character ids are placed before the forward pass
     output sent_embs: Tensor B x OUT (1024 for the ELMo layer used here)
     """
     tokenized = [WordEncoder.tokenize(text) for text in sentences]
     # batch_to_ids turns the token lists into padded ELMo character ids.
     char_ids = batch_to_ids(tokenized).to(device)
     # 'elmo_representations' is a list with one tensor per ELMo output layer,
     # each B x max_len x 1024; the second one is pooled here.
     layer = self.elmo(char_ids)['elmo_representations'][1]
     # Average over each sentence's real tokens only, ignoring padding.
     pooled = [layer[row, :len(tokens), :].mean(dim=0)
               for row, tokens in enumerate(tokenized)]
     return torch.stack(pooled, dim=0)  # B x 1024
Exemple #4
0
def add_space_to_cap_dict(cap_dict):
    """Return a copy of ``cap_dict`` whose captions are re-joined from their tokens.

    Each caption is tokenized with ``WordEncoder.tokenize`` and the tokens are
    joined with single spaces. A caption that yields no tokens is kept unchanged.
    Keys and caption order are preserved.
    """
    def _respace(caption):
        pieces = WordEncoder.tokenize(caption)
        return ' '.join(pieces) if len(pieces) > 0 else caption

    return {name: [_respace(c) for c in captions]
            for name, captions in cap_dict.items()}
Exemple #5
0
 def __init__(self,
              word_emb_dim,
              hidden_dim=256,
              bi_direct=True,
              word_encoder=None):
     """Build an embedding layer followed by a single-layer LSTM.

     :param word_emb_dim: word embedding size (LSTM input size)
     :param hidden_dim: LSTM hidden size per direction
     :param bi_direct: use a bidirectional LSTM (doubles ``out_dim``)
     :param word_encoder: vocabulary provider exposing ``word_list``; a new
         WordEncoder() is built when None
     """
     super(LSTMEncoder, self).__init__()
     if word_encoder is None:
         word_encoder = WordEncoder()
     self.word_encoder = word_encoder
     # Index 0 is treated as padding, so its embedding receives no gradient.
     self.embeds = nn.Embedding(num_embeddings=len(word_encoder.word_list),
                                embedding_dim=word_emb_dim,
                                padding_idx=0)
     self.lstm = torch.nn.LSTM(input_size=word_emb_dim,
                               hidden_size=hidden_dim,
                               num_layers=1,
                               bias=True,
                               batch_first=True,
                               dropout=0.0,
                               bidirectional=bi_direct)
     # A bidirectional LSTM concatenates the features of both directions.
     self.out_dim = hidden_dim * 2 if bi_direct else hidden_dim
Exemple #6
0
class LSTMEncoder(nn.Module):
    """Sentence encoder: word embedding -> single-layer (bi)LSTM -> output at the last real token.

    ``out_dim`` is ``hidden_dim``, doubled when ``bi_direct`` is set.
    """
    def __init__(self,
                 word_emb_dim,
                 hidden_dim=256,
                 bi_direct=True,
                 word_encoder=None):
        """
        :param word_emb_dim: word embedding size (LSTM input size)
        :param hidden_dim: LSTM hidden size per direction
        :param bi_direct: use a bidirectional LSTM
        :param word_encoder: object exposing ``word_list`` and ``encode_pad``;
            a new WordEncoder() is built when None
        """
        super(LSTMEncoder, self).__init__()
        self.word_encoder = word_encoder
        if self.word_encoder is None:
            self.word_encoder = WordEncoder()
        # Index 0 is the padding id used by encode_pad, so it gets no gradient.
        self.embeds = nn.Embedding(num_embeddings=len(
            self.word_encoder.word_list),
                                   embedding_dim=word_emb_dim,
                                   padding_idx=0)
        self.lstm = torch.nn.LSTM(input_size=word_emb_dim,
                                  hidden_size=hidden_dim,
                                  bidirectional=bi_direct,
                                  batch_first=True,
                                  num_layers=1,
                                  bias=True,
                                  dropout=0.0)
        self.out_dim = hidden_dim * 2 if bi_direct else hidden_dim

    def forward(self, sentences):
        """
        sentences: list[str], len of list: B
        output sent_embs: Tensor B x OUT, the LSTM output at each sentence's
            last non-padded time step.

        NOTE(review): for a bidirectional LSTM, the output at the last token
        concatenates the forward state (which has read the whole sentence)
        with the backward state at that same step (which has read only the
        last token). The final hidden states ``h_n`` would summarize both
        directions over the full sentence. This is left unchanged here to
        preserve existing behavior.
        """
        device = self.embeds.weight.device
        encode_padded, lens = self.word_encoder.encode_pad(sentences)
        encode_padded = torch.as_tensor(encode_padded,
                                        dtype=torch.long,
                                        device=device)  # B x max_len
        word_embs = self.embeds(encode_padded)  # B x max_len x E
        packed = pack_padded_sequence(word_embs,
                                      lens,
                                      batch_first=True,
                                      enforce_sorted=False)
        lstm_embs_packed, _ = self.lstm(packed)
        # With enforce_sorted=False, pad_packed_sequence returns both the outputs
        # and the lengths in the caller's original batch order.
        lstm_embs, lens = pad_packed_sequence(
            lstm_embs_packed, batch_first=True)  # B x max_len x OUT
        # Gather every row's output at its last real step in a single indexing op
        # instead of a Python loop of B slices plus torch.stack.
        batch_idx = torch.arange(lstm_embs.size(0), device=lstm_embs.device)
        last_idx = (lens - 1).to(lstm_embs.device)
        sent_embs = lstm_embs[batch_idx, last_idx]  # B x OUT
        return sent_embs
Exemple #7
0
class MeanEncoder(nn.Module):
    """Sentence encoder that averages word embeddings.

    Each sentence is encoded with ``word_encoder.encode(sent, max_len=-1)`` and
    represented by the mean embedding of every token except the first one
    (documented as ``<start>``).
    """
    def __init__(self, word_emb_dim, word_encoder=None):
        """
        :param word_emb_dim: embedding size; also exposed as ``out_dim``
        :param word_encoder: object exposing ``word_list`` and ``encode``;
            a new WordEncoder() is built when None
        """
        super(MeanEncoder, self).__init__()
        self.word_encoder = word_encoder
        if self.word_encoder is None:
            self.word_encoder = WordEncoder()
        self.embeds = nn.Embedding(num_embeddings=len(
            self.word_encoder.word_list),
                                   embedding_dim=word_emb_dim)
        self.out_dim = word_emb_dim

    def forward(self, sentences):
        """
        sentences: list[str], len of list: B
        output: mean_embed, B x E: mean over each sentence's tokens, including
            the embed of <end> and not including <start>. An empty batch yields
            a 0 x E tensor.

        NOTE: the summed span covers len - 1 tokens, so it is divided by
        len - 1 (clamped to at least 1). Earlier code divided by len, which
        shrank every embedding by (len - 1) / len. Weights trained with that
        scaling will now see slightly larger sentence embeddings.
        """
        device = self.embeds.weight.device
        encoded = [
            torch.as_tensor(self.word_encoder.encode(sent, max_len=-1)[0],
                            dtype=torch.long,
                            device=device) for sent in sentences
        ]
        if not encoded:
            # torch.cat would raise on an empty list.
            return self.embeds.weight.new_zeros((0, self.out_dim))

        sent_lengths = torch.as_tensor([len(e) for e in encoded],
                                       dtype=torch.long,
                                       device=device)
        # Offsets of each sentence inside the concatenated token sequence.
        sent_end_ids = torch.cumsum(sent_lengths, dim=0)
        sent_start_ids = sent_end_ids - sent_lengths

        embeded = self.embeds(torch.cat(encoded))  # sum_len x E
        sum_embeds = torch.cumsum(embeded, dim=0)  # sum_len x E
        # cumsum[end - 1] - cumsum[start] sums tokens start+1 .. end-1,
        # i.e. everything except the first (<start>) token.
        sum_embed = sum_embeds.index_select(dim=0, index=sent_end_ids - 1) - \
                    sum_embeds.index_select(dim=0, index=sent_start_ids)
        n_summed = (sent_lengths - 1).clamp(min=1)
        mean_embed = sum_embed / n_summed.unsqueeze(-1).float()  # B x E
        return mean_embed
def main():
    """
    Training and validation.

    Builds the show-attend-tell encoder/decoder pair, or restores it from the
    global ``checkpoint`` path. It then trains for epochs ``start_epoch`` ..
    ``epochs - 1`` and validates after each one. Training stops early once
    ``epochs_since_improvement`` reaches 20. Every 8 stagnant epochs it calls
    ``adjust_learning_rate(..., 0.8)``. A checkpoint is saved after every epoch.

    Reads module-level settings (attention_dim, emb_dim, decoder_dim, dropout,
    decoder_lr, encoder_lr, init_word_emb, device, batch_size, workers, epochs,
    checkpoint_base_name) and rebinds the globals declared below.
    """
    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, word_encoder
    word_encoder = WordEncoder()

    # Initialize / load checkpoint
    if checkpoint is None:
        print("Starting training from scratch!")
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_encoder.word_list),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, decoder.parameters()),
            lr=decoder_lr)
        encoder: Encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None
        if init_word_emb == 'fast_text':
            word_emb = get_word_embed(word_encoder.word_list, 'fast_text')
            decoder.embedding.weight.data.copy_(torch.from_numpy(word_emb))

    else:
        print("Loading from checkpoint " + checkpoint)
        # map_location lets a checkpoint written on a GPU host be restored on a
        # CPU-only one; both modules are moved to `device` below anyway.
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # The stored run may not have fine-tuned the encoder; create its optimizer now.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        split='train',
        is_train=True,
        word_encoder=word_encoder,
        transform=build_transforms(is_train=fine_tune_encoder)),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True,
                                               drop_last=True)
    # Validation items also carry every reference caption (variable count per
    # image), hence the custom collate function.
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        split='val',
        is_train=False,
        word_encoder=word_encoder,
        transform=build_transforms(is_train=False)),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True,
                                             collate_fn=caption_collate)

    # Single source of truth for the checkpoint directory (it was previously
    # computed here and then repeated as a literal in save_checkpoint).
    save_dir = 'output/show_attend_tell/checkpoints'
    os.makedirs(save_dir, exist_ok=True)

    print('ready to train.')
    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(checkpoint_base_name,
                        epoch,
                        epochs_since_improvement,
                        encoder,
                        decoder,
                        encoder_optimizer,
                        decoder_optimizer,
                        recent_bleu4,
                        is_best,
                        save_dir=save_dir)
Exemple #9
0
def compare_visualize(split='test',
                      html_path='visualizations/caption.html',
                      visualize_count=100):
    """Write an HTML table comparing three captioning models with ground truth.

    For the first ``visualize_count`` images of ``split``, each row shows the
    image thumbnail, the first prediction of each model (classifier top 5,
    triplet top 5, show-attend-tell) and the numbered ground-truth
    descriptions. Predicted tokens longer than one character that occur as
    substrings of the joined ground truth are wrapped in
    ``<span class="correct">``.

    :param split: dataset split whose prediction JSON files are read
    :param html_path: output HTML file
    :param visualize_count: number of images (table rows) to render
    """
    dataset = TextureDescriptionData()
    word_encoder = WordEncoder()
    # The classifier / triplet files can be regenerated with
    # top_k_caption(top_k=5, model_type='cls' | 'tri', dataset=dataset, split=split).
    pred_paths = [
        'output/naive_classify/v1_35_ft2,4_fc512_tuneTrue/caption_top5_%s.json' % split,
        'output/triplet_match/c34_bert_l2_s_lr0.00001/caption_top5_%s.json' % split,
        'output/show_attend_tell/results/pred_v2_last_beam5_%s.json' % split,
    ]
    pred_dicts = []
    for path in pred_paths:
        # Context manager so the file handles are closed promptly.
        with open(path) as f:
            pred_dicts.append(json.load(f))
    img_pref = 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/'

    html_str = '''<!DOCTYPE html>
<html lang="en">
<head>
    <title>Caption visualize</title>
    <link rel="stylesheet" href="caption_style.css">
</head>
<body>
<table>
    <col class="column-one">
    <col class="column-two">
    <col class="column-three">
    <tr>
        <th style="text-align: center">Image</th>
        <th>Predicted captions</th>
        <th>Ground-truth descriptions</th>
    </tr>
'''

    rows = []
    for img_i, img_name in enumerate(dataset.img_splits[split]):
        # Check before rendering so exactly `visualize_count` rows are written
        # (checking after appending produced one extra row).
        if img_i >= visualize_count:
            break
        gt_descs = dataset.img_data_dict[img_name]['descriptions']
        gt_desc_str = '|'.join(gt_descs)
        gt_html_str = ''.join('[%d] %s<br>\n' % (ci + 1, cap)
                              for ci, cap in enumerate(gt_descs))

        pred_caps = [pred_dict[img_name][0] for pred_dict in pred_dicts]
        for ci, cap in enumerate(pred_caps):
            tokens = WordEncoder.tokenize(cap)
            for ti, t in enumerate(tokens):
                if t in gt_desc_str and len(t) > 1:
                    tokens[ti] = '<span class="correct">%s</span>' % t
            pred_caps[ci] = word_encoder.detokenize(tokens)
        rows.append('''
<tr>
    <td>
        <img src={img_pref}{img_name} alt="{img_name}">
    </td>
    <td>
        <span class="pred_name">Classifier top 5:</span><br>
        {pred0}<br>
        <span class="pred_name">Triplet top 5:</span><br>
        {pred1}<br>
        <span class="pred_name">Show-attend-tell:</span><br>
        {pred2}<br>
    </td>
    <td>
        {gt}
    </td>
</tr>
'''.format(img_pref=img_pref,
           img_name=img_name,
           pred0=pred_caps[0],
           pred1=pred_caps[1],
           pred2=pred_caps[2],
           gt=gt_html_str))

    html_str += ''.join(rows)
    # Previously '</body\n>' split the closing tag across lines.
    html_str += '</table>\n</body>\n</html>'
    with open(html_path, 'w') as f:
        f.write(html_str)

    return
Exemple #10
0
# checkpoint = 'output/show_attend_tell/checkpoints/BEST_checkpoint_v1_tuneResNetTrue.pth.tar'  # model checkpoint
checkpoint = 'output/show_attend_tell/checkpoints/checkpoint_v2_FastText.pth.tar'  # model checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Load model.
# map_location remaps every stored tensor to `device`, so a checkpoint saved on a
# GPU still loads when CUDA is unavailable (a plain torch.load would fail there).
# NOTE(review): the checkpoint stores whole pickled modules; PyTorch >= 2.6
# defaults torch.load to weights_only=True, which rejects them. Pass
# weights_only=False there, and only for trusted files.
checkpoint = torch.load(checkpoint, map_location=device)
decoder = checkpoint['decoder'].to(device)
decoder.eval()
encoder = checkpoint['encoder'].to(device)
encoder.eval()

# Load word map (word2ix)
word_encoder = WordEncoder()


def predict(beam_size, img_dataset=None, split=split):
    """Caption images with beam search (the remainder of this function is not visible here).

    :param beam_size: beam width; presumably consumed by the search code that
        follows — not visible here, confirm.
    :param img_dataset: image dataset to caption; when None, an ImgOnlyDataset
        for ``split`` with eval-mode transforms is built.
    :param split: NOTE(review): the default is bound once, at definition time,
        to a module-level name ``split`` that is not visible here; confirm it exists.
    """
    # DataLoader
    if img_dataset is None:
        img_dataset = ImgOnlyDataset(split=split, transform=build_transforms(is_train=False))
    # batch_size=1 is required: the beam search below handles one image at a time.
    loader = torch.utils.data.DataLoader(img_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

    # TODO: Batched Beam Search
    # Therefore, do not use a batch_size greater than 1 - IMPORTANT!

    # Lists to store references (true captions), and hypothesis (prediction) for each image
    # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
Exemple #11
0
def train():
    """Train a TripletMatch image/phrase model configured from the command line.

    Parses ``-c/--config_file`` and ``-o/--opts``, merges them into the global
    ``cfg``, then trains with periodic evaluation (``do_eval``). Learning-rate
    decay and early stopping are driven by the sum of the phrase->image and
    image->phrase mean average precision. Checkpoints, eval visualizations and
    TensorBoard logs are written under ``cfg.OUTPUT_PATH``.

    :return: ``(p2i_result, i2p_result)`` from the best evaluation, or None if
        no evaluation ran or none scored above 0.
    """
    from torch.utils.tensorboard import SummaryWriter
    # load configs
    parser = argparse.ArgumentParser(description="Triplet Matching Training")
    parser.add_argument('-c', '--config_file', default=None, help="path to config file")
    parser.add_argument('-o', '--opts', default=None, nargs=argparse.REMAINDER,
                        help="Modify config options using the command-line. E.g. TRAIN.INIT_LR 0.01",)
    args = parser.parse_args()

    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    if args.opts is not None:
        cfg.merge_from_list(args.opts)

    prepare(cfg)
    cfg.freeze()
    print(cfg.dump())

    os.makedirs(cfg.OUTPUT_PATH, exist_ok=True)
    with open(os.path.join(cfg.OUTPUT_PATH, 'train.yml'), 'w') as f:
        f.write(cfg.dump())

    # set random seed
    torch.manual_seed(cfg.RAND_SEED)
    np.random.seed(cfg.RAND_SEED)
    random.seed(cfg.RAND_SEED)

    # make data_loader, model, criterion, optimizer
    dataset = TripletTrainData(split=cfg.TRAIN_SPLIT, neg_img=cfg.LOSS.IMG_SENT_WEIGHTS[0] > 0,
                               neg_lang=cfg.LOSS.IMG_SENT_WEIGHTS[1] > 0, lang_input=cfg.LANG_INPUT)
    data_loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=True)

    img_datset = ImgOnlyDataset(split=cfg.EVAL_SPLIT, transform=build_transforms(is_train=False),
                                texture_dataset=dataset)
    eval_img_dataloader = DataLoader(img_datset, batch_size=1, shuffle=False)

    phrase_dataset = PhraseOnlyDataset(texture_dataset=dataset)
    eval_phrase_dataloader = DataLoader(phrase_dataset, batch_size=32, shuffle=False)

    word_encoder = WordEncoder()
    model: TripletMatch = TripletMatch(vec_dim=cfg.MODEL.VEC_DIM, neg_margin=cfg.LOSS.MARGIN,
                                       distance=cfg.MODEL.DISTANCE, img_feats=cfg.MODEL.IMG_FEATS,
                                       lang_encoder_method=cfg.MODEL.LANG_ENCODER, word_encoder=word_encoder)

    if cfg.INIT_WORD_EMBED != 'rand' and cfg.MODEL.LANG_ENCODER in ['mean', 'lstm']:
        word_emb = get_word_embed(word_encoder.word_list, cfg.INIT_WORD_EMBED)
        model.lang_embed.embeds.weight.data.copy_(torch.from_numpy(word_emb))
    if len(cfg.LOAD_WEIGHTS) > 0:
        # Load onto CPU first: load_state_dict copies into the model's own tensors,
        # and this avoids failing on CPU-only hosts for GPU-saved weights.
        model.load_state_dict(torch.load(cfg.LOAD_WEIGHTS, map_location='cpu'))

    model.train()
    device = torch.device(cfg.DEVICE)
    model.to(device)

    # Freeze sub-modules that should not be tuned. Assigning
    # `module.requires_grad = False` only sets a plain Python attribute and leaves
    # every parameter trainable. requires_grad_(False) clears the flag on all
    # parameters, so the `p.requires_grad` filter below excludes them from the optimizer.
    if not cfg.TRAIN.TUNE_RESNET:
        model.resnet_encoder.requires_grad_(False)
        model.resnet_encoder.eval()
    if not cfg.TRAIN.TUNE_LANG_ENCODER:
        model.lang_embed.requires_grad_(False)
        model.lang_embed.eval()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=cfg.TRAIN.INIT_LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                 betas=(cfg.TRAIN.ADAM.ALPHA, cfg.TRAIN.ADAM.BETA),
                                 eps=cfg.TRAIN.ADAM.EPSILON)

    # make tensorboard writer and dirs
    checkpoint_dir = os.path.join(cfg.OUTPUT_PATH, 'checkpoints')
    vis_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_%s_LAST' % cfg.EVAL_SPLIT)
    best_vis_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_%s_BEST' % cfg.EVAL_SPLIT)
    tb_dir = os.path.join(cfg.OUTPUT_PATH, 'tensorboard')
    tb_writer = SummaryWriter(tb_dir)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(tb_dir, exist_ok=True)

    # training loop
    step = 1
    epoch = 1
    epoch_float = 0.0
    epoch_per_step = cfg.TRAIN.BATCH_SIZE * 1.0 / len(dataset)
    best_eval_metric = 0
    best_metrics = None
    best_eval_count = 0
    early_stop = False
    while epoch <= cfg.TRAIN.MAX_EPOCH and not early_stop:
        for pos_imgs, pos_langs, neg_imgs, neg_langs in data_loader:
            pos_imgs = pos_imgs.to(device)
            if neg_imgs is not None and neg_imgs[0] is not None:
                neg_imgs = neg_imgs.to(device)

            verbose = step <= 5 or step % 50 == 0

            neg_img_loss, neg_lang_loss = model(pos_imgs, pos_langs, neg_imgs, neg_langs, verbose=verbose)

            loss = cfg.LOSS.IMG_SENT_WEIGHTS[0] * neg_img_loss + cfg.LOSS.IMG_SENT_WEIGHTS[1] * neg_lang_loss
            loss /= sum(cfg.LOSS.IMG_SENT_WEIGHTS)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            tb_writer.add_scalar('train/loss', loss, step)
            tb_writer.add_scalar('train/neg_img_loss', neg_img_loss, step)
            tb_writer.add_scalar('train/neg_lang_loss', neg_lang_loss, step)
            tb_writer.add_scalar('train/lr', lr, step)

            if verbose:
                print('[%s] epoch-%d step-%d: loss %.4f (neg_img: %.4f, neg_lang: %.4f); lr %.1E'
                      % (time.strftime('%m/%d %H:%M:%S'), epoch, step, loss, neg_img_loss, neg_lang_loss, lr))

            # True on the first step after every EVAL_EVERY_EPOCH boundary.
            if epoch_float % cfg.TRAIN.EVAL_EVERY_EPOCH < epoch_per_step and epoch_float > 0:
                p2i_result, i2p_result = do_eval(model, eval_img_dataloader, eval_phrase_dataloader, device,
                                                 split=cfg.EVAL_SPLIT, visualize_path=vis_path)

                for m, v in p2i_result.items():
                    tb_writer.add_scalar('eval_p2i/%s' % m, v, step)
                for m, v in i2p_result.items():
                    tb_writer.add_scalar('eval_i2p/%s' % m, v, step)

                eval_metric = p2i_result['mean_average_precision'] + i2p_result['mean_average_precision']
                if eval_metric > best_eval_metric:
                    print('EVAL: new best!')
                    best_eval_metric = eval_metric
                    best_metrics = (p2i_result, i2p_result)
                    best_eval_count = 0
                    copy_tree(vis_path, best_vis_path, update=1)
                    with open(os.path.join(checkpoint_dir, 'epoch_step.txt'), 'w') as f:
                        f.write('BEST: epoch {}, step {}\n'.format(epoch, step))
                    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'BEST_checkpoint.pth'))

                else:
                    best_eval_count += 1
                    print('EVAL: since last best: %d' % best_eval_count)
                    if epoch_float % cfg.TRAIN.CHECKPOINT_EVERY_EPOCH < epoch_per_step and epoch_float > 0:
                        with open(os.path.join(checkpoint_dir, 'epoch_step.txt'), 'a') as f:
                            f.write('LAST: epoch {}, step {}\n'.format(epoch, step))
                        torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'LAST_checkpoint.pth'))

                if best_eval_count % cfg.TRAIN.LR_DECAY_EVAL_COUNT == 0 and best_eval_count > 0:
                    print('EVAL: lr decay triggered')
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= cfg.TRAIN.LR_DECAY_GAMMA

                if best_eval_count % cfg.TRAIN.EARLY_STOP_EVAL_COUNT == 0 and best_eval_count > 0:
                    print('EVAL: early stop triggered')
                    early_stop = True
                    break

                # do_eval may have switched modes; restore training mode but keep
                # frozen sub-modules in eval mode.
                model.train()
                if not cfg.TRAIN.TUNE_RESNET:
                    model.resnet_encoder.eval()
                if not cfg.TRAIN.TUNE_LANG_ENCODER:
                    model.lang_embed.eval()

            step += 1
            epoch_float += epoch_per_step

        epoch += 1

    tb_writer.close()
    if best_metrics is not None:
        exp_name = '%s:%s' % (cfg.OUTPUT_PATH, cfg.EVAL_SPLIT)
        log_to_summary(exp_name, best_metrics[0], best_metrics[1])
    return best_metrics
Exemple #12
0
def caption_visualize(pred_dicts=None, filter=False):
    """Write an HTML page of synthetic images with their predicted captions.

    :param pred_dicts: three prediction containers rendered in order as
        classifier, triplet and show-attend-tell. Each is indexed by the
        synthetic-sample index and the first entry of the result is shown.
        Loaded from ``applications/synthetic_imgs/visualizations/results/caption.npy``
        when None.
    :param filter: if True, keep only samples where at least one predicted
        caption contains both of the sample's colour names, and write
        ``caption_filtered.html`` instead of ``caption.html``.
    """
    if pred_dicts is None:
        # The .npy holds Python objects (dicts), which np.load only unpickles
        # with allow_pickle=True (its default has been False since NumPy 1.16.3).
        # Only load files this project wrote itself.
        pred_dicts = np.load('applications/synthetic_imgs/visualizations/results/caption.npy',
                             allow_pickle=True)
    syn_dataset = SyntheticData()
    word_encoder = WordEncoder()
    img_pref = '../modified_imgs/'
    html_str = '''<!DOCTYPE html>
    <html lang="en">
    <head>
        <title>Caption visualize</title>
        <style>
        .correct {
            font-weight: bold;
        }
        .pred_name {
            color: ROYALBLUE;
            font-weight: bold;
        }
        img {
           width: 3cm
        }
        table {
            border-collapse: collapse;
        }
        tr {
            border-bottom: 1px solid lightgray;
        }

        </style>
    </head>
    <body>
    <table>
        <col class="column-one">
        <col class="column-two">
        <tr>
            <th style="text-align: center">Image</th>
            <th>Predicted captions</th>
        </tr>
    '''
    rows = []
    for idx in range(len(syn_dataset)):
        pred_caps = [pred_dict[idx][0] for pred_dict in pred_dicts]
        img_i, c1_i, c2_i = syn_dataset.unravel_index(idx)
        c1 = syn_dataset.color_names[c1_i]
        c2 = syn_dataset.color_names[c2_i]

        # In filter mode, skip samples where no caption names both colours.
        if filter and not any(c1 in cap and c2 in cap for cap in pred_caps):
            continue

        img_name = '%s_%s_%s.jpg' % (syn_dataset.img_names[img_i].split('.')[0], c1, c2)
        gt_desc = syn_dataset.get_desc(img_i, c1_i, c2_i)

        for ci, cap in enumerate(pred_caps):
            tokens = WordEncoder.tokenize(cap)
            for ti, t in enumerate(tokens):
                if t in gt_desc and len(t) > 1:
                    tokens[ti] = '<span class="correct">%s</span>' % t
            pred_caps[ci] = word_encoder.detokenize(tokens)
        rows.append('''
    <tr>
        <td>
            <img src={img_pref}{img_name} alt="{img_name}">
        </td>
        <td>
            <span class="pred_name">Synthetic Ground-truth Description:</span><br>
            {gt}<br>
            <span class="pred_name">Classifier top 5:</span><br>
            {pred0}<br>
            <span class="pred_name">Triplet top 5:</span><br>
            {pred1}<br>
            <span class="pred_name">Show-attend-tell:</span><br>
            {pred2}<br>
        </td>
    </tr>
    '''.format(img_pref=img_pref, img_name=img_name, pred0=pred_caps[0], pred1=pred_caps[1], pred2=pred_caps[2],
               gt=gt_desc))
    html_str += ''.join(rows)
    # Previously '</body\n>' split the closing tag across lines.
    html_str += '</table>\n</body>\n</html>'

    html_name = 'caption_filtered.html' if filter else 'caption.html'
    with open('applications/synthetic_imgs/visualizations/results/' + html_name, 'w') as f:
        f.write(html_str)
    return
Exemple #13
0
class CaptionDataset(Dataset, TextureDescriptionData):
    """
    A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.

    Every item is one (image, description) pair from a single split.
    """
    def __init__(self,
                 split,
                 transform=None,
                 word_encoder=None,
                 is_train=True,
                 caption_max_len=35):
        """
        :param split: split, one of 'train', 'val', 'test'
        :param transform: image transform pipeline
        :param word_encoder: caption encoder; a new WordEncoder() is built when None
        :param is_train: when False, items also include every caption of the image
        :param caption_max_len: ``max_len`` passed to ``word_encoder.encode``
        """
        TextureDescriptionData.__init__(self, phid_format=None)
        self.transform = transform
        self.is_train = is_train
        self.caption_max_len = caption_max_len
        self.split = split
        assert self.split in ('train', 'val', 'test')

        self.word_encoder = word_encoder
        if self.word_encoder is None:
            self.word_encoder = WordEncoder()

        # Flat index over every description: (image index within split, description index).
        self.img_desc_ids = list()
        for img_i, img_name in enumerate(self.img_splits[split]):
            desc_num = len(self.img_data_dict[img_name]['descriptions'])
            self.img_desc_ids += [(img_i, desc_i)
                                  for desc_i in range(desc_num)]

    def __getitem__(self, i):
        """Return the i-th (image, description) sample.

        Training: ``(img, caption, caplen)``, where caption is a LongTensor of
        encoded ids and caplen is a 1-element LongTensor.
        Otherwise: ``(img, caption, caplen, all_captions)``, where all_captions
        holds every description of the image, encoded and zero-padded into a
        LongTensor with one row per description and as many columns as the
        longest encoded length.
        """
        img_i, desc_i = self.img_desc_ids[i]
        img_data = self.get_split_data(split=self.split,
                                       img_idx=img_i,
                                       load_img=True)
        img = img_data['image']
        if self.transform is not None:
            img = self.transform(img)

        desc = img_data['descriptions'][desc_i]
        caption, caplen = self.word_encoder.encode(
            lang_input=desc, max_len=self.caption_max_len)
        caplen = torch.as_tensor([caplen], dtype=torch.long)
        caption = torch.as_tensor(caption, dtype=torch.long)

        if self.is_train:
            return img, caption, caplen

        # For validation or testing, also return all captions of the image to find BLEU-4 score
        all_captions = list()
        mlen = 0
        for other_desc in img_data['descriptions']:
            c, cl = self.word_encoder.encode(lang_input=other_desc,
                                             max_len=self.caption_max_len)
            all_captions.append(c)
            mlen = max(mlen, cl)

        # Integer buffer from the start: token ids need no float64 round trip.
        all_captions_np = np.zeros((len(all_captions), mlen), dtype=np.int64)
        for ci, c in enumerate(all_captions):
            cl = min(len(c), mlen)
            all_captions_np[ci, :cl] = c[:cl]
        all_captions = torch.as_tensor(all_captions_np, dtype=torch.long)
        return img, caption, caplen, all_captions

    def __len__(self):
        """Number of (image, description) pairs in the split."""
        return len(self.img_desc_ids)