def forward(self, sentences, device='cuda'):
    """
    sentences: list[str], len of list: B
    output sent_embs: Tensor B x OUT
    """
    sentences = [WordEncoder.tokenize(s) for s in sentences]
    # e.g. sentences = [['First', 'sentence', '.'], ['Another', '.']]
    # use batch_to_ids to convert sentences to character ids
    character_ids = batch_to_ids(sentences).to(device)
    embeddings = self.elmo(character_ids)
    # embeddings['elmo_representations'] is a length-two list of tensors.
    # Each element contains one layer of ELMo representations with shape
    # (B, max_len, 1024):
    #   B       - the batch size
    #   max_len - the sequence length of the batch
    #   1024    - the length of each ELMo vector
    sent_embeds = embeddings['elmo_representations'][1]  # B x max_len x 1024

    sent_emb_list = list()
    for si in range(len(sentences)):
        sent_len = len(sentences[si])
        sent_embed = torch.mean(sent_embeds[si, :sent_len, :], dim=0)  # 1024
        sent_emb_list.append(sent_embed)
    sent_embs = torch.stack(sent_emb_list, dim=0)  # B x 1024
    return sent_embs

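# Illustrative sketch (not part of the original module): the forward() above mean-pools
# per-token ELMo vectors over each sentence's true length, ignoring padded positions.
# The snippet below reproduces that pooling on a dummy tensor so the shape logic can be
# checked without downloading ELMo weights; all names and values here are hypothetical.
def _demo_masked_mean_pooling():
    import torch
    sentences = [['an', 'example', '.'], ['another', '.']]   # tokenized batch, B = 2
    sent_embeds = torch.randn(2, 3, 1024)                    # stands in for elmo_representations[1]
    pooled = []
    for si, toks in enumerate(sentences):
        # average only the real tokens of each sentence, not the padding
        pooled.append(torch.mean(sent_embeds[si, :len(toks), :], dim=0))
    sent_embs = torch.stack(pooled, dim=0)                   # B x 1024, as returned by forward()
    return sent_embs
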
def add_space_to_cap_dict(cap_dict):
    new_dict = dict()
    for img_name, caps in cap_dict.items():
        new_dict[img_name] = list()
        for cap in caps:
            tokens = WordEncoder.tokenize(cap)
            if len(tokens) > 0:
                new_cap = ' '.join(tokens)
            else:
                new_cap = cap
            new_dict[img_name].append(new_cap)
    return new_dict

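# Illustrative sketch (hypothetical image name and caption): add_space_to_cap_dict()
# re-tokenizes each caption so punctuation is whitespace-separated, which is convenient
# for metric scripts that expect whitespace-tokenized text.
def _demo_add_space_to_cap_dict():
    cap_dict = {'banded/banded_0001.jpg': ['striped,banded fabric.']}
    spaced = add_space_to_cap_dict(cap_dict)
    # With the WordEncoder tokenizer, the caption becomes roughly:
    # {'banded/banded_0001.jpg': ['striped , banded fabric .']}
    return spaced
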
class LSTMEncoder(nn.Module):
    def __init__(self, word_emb_dim, hidden_dim=256, bi_direct=True, word_encoder=None):
        super(LSTMEncoder, self).__init__()
        self.word_encoder = word_encoder
        if self.word_encoder is None:
            self.word_encoder = WordEncoder()
        self.embeds = nn.Embedding(num_embeddings=len(self.word_encoder.word_list),
                                   embedding_dim=word_emb_dim, padding_idx=0)
        self.lstm = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim,
                                  bidirectional=bi_direct, batch_first=True,
                                  num_layers=1, bias=True, dropout=0.0)
        self.out_dim = hidden_dim
        if bi_direct:
            self.out_dim = hidden_dim * 2
        # self.lstm = PytorchSeq2VecWrapper(module=lstm_module)
        # self.out_dim = self.lstm.get_output_dim()

    def forward(self, sentences):
        """
        sentences: list[str], len of list: B
        output sent_embs: Tensor B x OUT
        """
        device = self.embeds.weight.device
        encode_padded, lens = self.word_encoder.encode_pad(sentences)
        encode_padded = torch.as_tensor(encode_padded, dtype=torch.long, device=device)  # B x max_len
        word_embs = self.embeds(encode_padded)  # B x max_len x E
        packed = pack_padded_sequence(word_embs, lens, batch_first=True, enforce_sorted=False)
        lstm_embs_packed, _ = self.lstm(packed)
        lstm_embs, lens = pad_packed_sequence(lstm_embs_packed, batch_first=True)  # B x max_len x OUT
        sent_embs = torch.stack([lstm_embs[i, lens[i] - 1] for i in range(len(lens))])
        # print_tensor_stats(sent_embs, 'sent_embs')
        return sent_embs

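# Minimal usage sketch for LSTMEncoder (illustrative; the sentences and dimensions are
# made up). The encoder pads and packs the batch internally, so callers only pass raw
# strings and get one fixed-size vector per sentence back.
def _demo_lstm_encoder():
    encoder = LSTMEncoder(word_emb_dim=300, hidden_dim=256, bi_direct=True)
    sent_embs = encoder(['a striped red and white texture', 'woven fabric'])
    # out_dim is 512 here because the LSTM is bidirectional (2 * hidden_dim)
    assert sent_embs.shape == (2, encoder.out_dim)
    return sent_embs
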
class MeanEncoder(nn.Module):
    def __init__(self, word_emb_dim, word_encoder=None):
        super(MeanEncoder, self).__init__()
        self.word_encoder = word_encoder
        if self.word_encoder is None:
            self.word_encoder = WordEncoder()
        self.embeds = nn.Embedding(num_embeddings=len(self.word_encoder.word_list),
                                   embedding_dim=word_emb_dim)
        self.out_dim = word_emb_dim

    def forward(self, sentences):
        """
        sentences: list[str], len of list: B
        output: mean_embed: including embed of <end>, not including <start>
        """
        device = self.embeds.weight.device
        # encoded = list()
        # for sent in sentences:
        #     e = self.word_encoder.encode(sent, max_len=-1)
        #     t = torch.as_tensor(e, dtype=torch.long, device=device)
        #     encoded.append(t)
        encoded = [torch.as_tensor(self.word_encoder.encode(sent, max_len=-1)[0],
                                   dtype=torch.long, device=device)
                   for sent in sentences]
        sent_lengths = torch.as_tensor([len(e) for e in encoded], dtype=torch.long, device=device)
        sent_end_ids = torch.cumsum(sent_lengths, dim=0)
        sent_start_ids = torch.empty_like(sent_end_ids)
        sent_start_ids[0] = 0
        sent_start_ids[1:] = sent_end_ids[:-1]

        encoded = torch.cat(encoded)
        embeded = self.embeds(encoded)  # sum_len x E
        sum_embeds = torch.cumsum(embeded, dim=0)  # sum_len x E
        sum_embed = sum_embeds.index_select(dim=0, index=sent_end_ids - 1) - \
            sum_embeds.index_select(dim=0, index=sent_start_ids)  # exclude <start>
        mean_embed = sum_embed / sent_lengths.unsqueeze(-1).float()  # B x E
        return mean_embed

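# Illustrative check (not part of the original module): MeanEncoder.forward() computes
# per-sentence means with a single cumulative sum over the concatenated batch instead of
# a Python loop. The sketch below verifies it against a straightforward per-sentence
# average (dropping <start>, dividing by the full encoded length, as forward() does),
# using made-up sentences.
def _demo_mean_encoder_equivalence():
    import torch
    encoder = MeanEncoder(word_emb_dim=300)
    sentences = ['a striped texture', 'smooth shiny leather surface']
    fast = encoder(sentences)
    slow = []
    for sent in sentences:
        ids = torch.as_tensor(encoder.word_encoder.encode(sent, max_len=-1)[0], dtype=torch.long)
        embs = encoder.embeds(ids)                   # per-token embeddings, <start> at index 0
        slow.append(embs[1:].sum(dim=0) / len(ids))  # drop <start>, divide by full length
    slow = torch.stack(slow, dim=0)
    assert torch.allclose(fast, slow, atol=1e-5)
    return fast
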
def main():
    """
    Training and validation.
    """
    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, word_encoder

    word_encoder = WordEncoder()

    # Initialize / load checkpoint
    if checkpoint is None:
        print("Starting training from scratch!")
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_encoder.word_list),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder: Encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None
        if init_word_emb == 'fast_text':
            word_emb = get_word_embed(word_encoder.word_list, 'fast_text')
            decoder.embedding.weight.data.copy_(torch.from_numpy(word_emb))
    else:
        print("Loading from checkpoint " + checkpoint)
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(split='train', is_train=True, word_encoder=word_encoder,
                       transform=build_transforms(is_train=fine_tune_encoder)),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(split='val', is_train=False, word_encoder=word_encoder,
                       transform=build_transforms(is_train=False)),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True,
        collate_fn=caption_collate)

    print('ready to train.')

    # Epochs
    for epoch in range(start_epoch, epochs):
        # Decay learning rate if there is no improvement for 8 consecutive epochs,
        # and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_dir = 'output/show_attend_tell/checkpoints'
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        save_checkpoint(checkpoint_base_name, epoch, epochs_since_improvement, encoder, decoder,
                        encoder_optimizer, decoder_optimizer, recent_bleu4, is_best,
                        save_dir=save_dir)

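# adjust_learning_rate() above is imported from elsewhere in the project. A minimal sketch
# of what such a helper typically does (shrink every parameter group's learning rate by a
# constant factor) is given below under a different name so it does not shadow the real one;
# this is an assumption about its behavior, not the project's actual implementation.
def _adjust_learning_rate_sketch(optimizer, shrink_factor):
    """Multiply the learning rate of every param group by shrink_factor (e.g. 0.8)."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
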
def compare_visualize(split='test', html_path='visualizations/caption.html', visualize_count=100):
    dataset = TextureDescriptionData()
    word_encoder = WordEncoder()

    # cls_predictions = top_k_caption(top_k=5, model_type='cls', dataset=dataset, split=split)
    # with open('output/naive_classify/v1_35_ft2,4_fc512_tuneTrue/caption_top5_%s.json' % split, 'w') as f:
    #     json.dump(cls_predictions, f)
    # tri_predictions = top_k_caption(top_k=5, model_type='tri', dataset=dataset, split=split)
    # with open('output/triplet_match/c34_bert_l2_s_lr0.00001/caption_top5_%s.json' % split, 'w') as f:
    #     json.dump(tri_predictions, f)

    cls_predictions = json.load(
        open('output/naive_classify/v1_35_ft2,4_fc512_tuneTrue/caption_top5_%s.json' % split))
    tri_predictions = json.load(
        open('output/triplet_match/c34_bert_l2_s_lr0.00001/caption_top5_%s.json' % split))
    sat_predictions = json.load(
        open('output/show_attend_tell/results/pred_v2_last_beam5_%s.json' % split))
    pred_dicts = [cls_predictions, tri_predictions, sat_predictions]

    img_pref = 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/'
    html_str = '''<!DOCTYPE html>
<html lang="en">
<head>
    <title>Caption visualize</title>
    <link rel="stylesheet" href="caption_style.css">
</head>
<body>
<table>
    <col class="column-one">
    <col class="column-two">
    <col class="column-three">
    <tr>
        <th style="text-align: center">Image</th>
        <th>Predicted captions</th>
        <th>Ground-truth descriptions</th>
    </tr>
'''

    for img_i, img_name in enumerate(dataset.img_splits[split]):
        gt_descs = dataset.img_data_dict[img_name]['descriptions']
        gt_desc_str = '|'.join(gt_descs)
        gt_html_str = ''
        for ci, cap in enumerate(gt_descs):
            gt_html_str += '[%d] %s<br>\n' % (ci + 1, cap)

        pred_caps = [pred_dict[img_name][0] for pred_dict in pred_dicts]
        for ci, cap in enumerate(pred_caps):
            tokens = WordEncoder.tokenize(cap)
            for ti, t in enumerate(tokens):
                if t in gt_desc_str and len(t) > 1:
                    tokens[ti] = '<span class="correct">%s</span>' % t
            pred_caps[ci] = word_encoder.detokenize(tokens)

        html_str += '''
    <tr>
        <td>
            <img src="{img_pref}{img_name}" alt="{img_name}">
        </td>
        <td>
            <span class="pred_name">Classifier top 5:</span><br>
            {pred0}<br>
            <span class="pred_name">Triplet top 5:</span><br>
            {pred1}<br>
            <span class="pred_name">Show-attend-tell:</span><br>
            {pred2}<br>
        </td>
        <td>
            {gt}
        </td>
    </tr>
'''.format(img_pref=img_pref, img_name=img_name,
           pred0=pred_caps[0], pred1=pred_caps[1], pred2=pred_caps[2], gt=gt_html_str)

        if img_i >= visualize_count:
            break

    html_str += '</table>\n</body>\n</html>'
    with open(html_path, 'w') as f:
        f.write(html_str)
    return

# checkpoint = 'output/show_attend_tell/checkpoints/BEST_checkpoint_v1_tuneResNetTrue.pth.tar'  # model checkpoint
checkpoint = 'output/show_attend_tell/checkpoints/checkpoint_v2_FastText.pth.tar'  # model checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to True only if inputs to the model are of fixed size; otherwise it adds a lot of computational overhead

# Load model
checkpoint = torch.load(checkpoint)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()

# Load word map (word2ix)
word_encoder = WordEncoder()


def predict(beam_size, img_dataset=None, split=split):
    # DataLoader
    if img_dataset is None:
        img_dataset = ImgOnlyDataset(split=split, transform=build_transforms(is_train=False))
    loader = torch.utils.data.DataLoader(img_dataset, batch_size=1, shuffle=False,
                                         num_workers=1, pin_memory=True)
    # TODO: Batched beam search.
    # Until then, do not use a batch_size greater than 1 - IMPORTANT!

    # Lists to store references (true captions) and hypotheses (predictions) for each image.
    # If for n images we have n hypotheses, and references a, b, c... for each image, we need:
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
def train():
    from torch.utils.tensorboard import SummaryWriter

    # load configs
    parser = argparse.ArgumentParser(description="Triplet Matching Training")
    parser.add_argument('-c', '--config_file', default=None, help="path to config file")
    parser.add_argument('-o', '--opts', default=None, nargs=argparse.REMAINDER,
                        help="Modify config options using the command-line. E.g. TRAIN.INIT_LR 0.01")
    args = parser.parse_args()
    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    if args.opts is not None:
        cfg.merge_from_list(args.opts)
    prepare(cfg)
    cfg.freeze()
    print(cfg.dump())
    if not os.path.exists(cfg.OUTPUT_PATH):
        os.makedirs(cfg.OUTPUT_PATH)
    with open(os.path.join(cfg.OUTPUT_PATH, 'train.yml'), 'w') as f:
        f.write(cfg.dump())

    # set random seed
    torch.manual_seed(cfg.RAND_SEED)
    np.random.seed(cfg.RAND_SEED)
    random.seed(cfg.RAND_SEED)

    # make data_loader, model, criterion, optimizer
    dataset = TripletTrainData(split=cfg.TRAIN_SPLIT,
                               neg_img=cfg.LOSS.IMG_SENT_WEIGHTS[0] > 0,
                               neg_lang=cfg.LOSS.IMG_SENT_WEIGHTS[1] > 0,
                               lang_input=cfg.LANG_INPUT)
    data_loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True,
                             drop_last=True, pin_memory=True)
    img_dataset = ImgOnlyDataset(split=cfg.EVAL_SPLIT, transform=build_transforms(is_train=False),
                                 texture_dataset=dataset)
    eval_img_dataloader = DataLoader(img_dataset, batch_size=1, shuffle=False)
    phrase_dataset = PhraseOnlyDataset(texture_dataset=dataset)
    eval_phrase_dataloader = DataLoader(phrase_dataset, batch_size=32, shuffle=False)

    word_encoder = WordEncoder()
    model: TripletMatch = TripletMatch(vec_dim=cfg.MODEL.VEC_DIM, neg_margin=cfg.LOSS.MARGIN,
                                       distance=cfg.MODEL.DISTANCE, img_feats=cfg.MODEL.IMG_FEATS,
                                       lang_encoder_method=cfg.MODEL.LANG_ENCODER,
                                       word_encoder=word_encoder)
    if cfg.INIT_WORD_EMBED != 'rand' and cfg.MODEL.LANG_ENCODER in ['mean', 'lstm']:
        word_emb = get_word_embed(word_encoder.word_list, cfg.INIT_WORD_EMBED)
        model.lang_embed.embeds.weight.data.copy_(torch.from_numpy(word_emb))
    if len(cfg.LOAD_WEIGHTS) > 0:
        model.load_state_dict(torch.load(cfg.LOAD_WEIGHTS))
    model.train()
    device = torch.device(cfg.DEVICE)
    model.to(device)

    if not cfg.TRAIN.TUNE_RESNET:
        model.resnet_encoder.requires_grad = False
        model.resnet_encoder.eval()
    if not cfg.TRAIN.TUNE_LANG_ENCODER:
        model.lang_embed.requires_grad = False
        model.lang_embed.eval()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=cfg.TRAIN.INIT_LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                 betas=(cfg.TRAIN.ADAM.ALPHA, cfg.TRAIN.ADAM.BETA),
                                 eps=cfg.TRAIN.ADAM.EPSILON)

    # make tensorboard writer and dirs
    checkpoint_dir = os.path.join(cfg.OUTPUT_PATH, 'checkpoints')
    vis_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_%s_LAST' % cfg.EVAL_SPLIT)
    best_vis_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_%s_BEST' % cfg.EVAL_SPLIT)
    tb_dir = os.path.join(cfg.OUTPUT_PATH, 'tensorboard')
    tb_writer = SummaryWriter(tb_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if not os.path.exists(tb_dir):
        os.makedirs(tb_dir)

    # training loop
    step = 1
    epoch = 1
    epoch_float = 0.0
    epoch_per_step = cfg.TRAIN.BATCH_SIZE * 1.0 / len(dataset)
    best_eval_metric = 0
    best_metrics = None
    best_eval_count = 0
    early_stop = False
    while epoch <= cfg.TRAIN.MAX_EPOCH and not early_stop:
        # for pos_imgs, pos_langs, neg_imgs, neg_langs in tqdm(data_loader, desc='TRAIN epoch %d' % epoch):
        for pos_imgs, pos_langs, neg_imgs, neg_langs in data_loader:
            pos_imgs = pos_imgs.to(device)
            if neg_imgs is not None and neg_imgs[0] is not None:
                neg_imgs = neg_imgs.to(device)
            verbose = step <= 5 or step % 50 == 0
            neg_img_loss, neg_lang_loss = model(pos_imgs, pos_langs, neg_imgs, neg_langs, verbose=verbose)
            loss = cfg.LOSS.IMG_SENT_WEIGHTS[0] * neg_img_loss + cfg.LOSS.IMG_SENT_WEIGHTS[1] * neg_lang_loss
            loss /= sum(cfg.LOSS.IMG_SENT_WEIGHTS)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            tb_writer.add_scalar('train/loss', loss, step)
            tb_writer.add_scalar('train/neg_img_loss', neg_img_loss, step)
            tb_writer.add_scalar('train/neg_lang_loss', neg_lang_loss, step)
            tb_writer.add_scalar('train/lr', lr, step)
            if verbose:
                print('[%s] epoch-%d step-%d: loss %.4f (neg_img: %.4f, neg_lang: %.4f); lr %.1E'
                      % (time.strftime('%m/%d %H:%M:%S'), epoch, step, loss, neg_img_loss, neg_lang_loss, lr))

            # if epoch == 1 and step == 2:  # debug eval
            #     visualize_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_debug')
            #     do_eval(model, eval_img_dataloader, eval_phrase_dataloader, device, split=cfg.EVAL_SPLIT,
            #             visualize_path=visualize_path, add_to_summary_name=None)

            if epoch_float % cfg.TRAIN.EVAL_EVERY_EPOCH < epoch_per_step and epoch_float > 0:
                p2i_result, i2p_result = do_eval(model, eval_img_dataloader, eval_phrase_dataloader, device,
                                                 split=cfg.EVAL_SPLIT, visualize_path=vis_path)
                for m, v in p2i_result.items():
                    tb_writer.add_scalar('eval_p2i/%s' % m, v, step)
                for m, v in i2p_result.items():
                    tb_writer.add_scalar('eval_i2p/%s' % m, v, step)

                eval_metric = p2i_result['mean_average_precision'] + i2p_result['mean_average_precision']
                if eval_metric > best_eval_metric:
                    print('EVAL: new best!')
                    best_eval_metric = eval_metric
                    best_metrics = (p2i_result, i2p_result)
                    best_eval_count = 0
                    copy_tree(vis_path, best_vis_path, update=1)
                    with open(os.path.join(checkpoint_dir, 'epoch_step.txt'), 'w') as f:
                        f.write('BEST: epoch {}, step {}\n'.format(epoch, step))
                    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'BEST_checkpoint.pth'))
                else:
                    best_eval_count += 1
                    print('EVAL: since last best: %d' % best_eval_count)

                if epoch_float % cfg.TRAIN.CHECKPOINT_EVERY_EPOCH < epoch_per_step and epoch_float > 0:
                    with open(os.path.join(checkpoint_dir, 'epoch_step.txt'), 'a') as f:
                        f.write('LAST: epoch {}, step {}\n'.format(epoch, step))
                    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'LAST_checkpoint.pth'))

                if best_eval_count % cfg.TRAIN.LR_DECAY_EVAL_COUNT == 0 and best_eval_count > 0:
                    print('EVAL: lr decay triggered')
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= cfg.TRAIN.LR_DECAY_GAMMA

                if best_eval_count % cfg.TRAIN.EARLY_STOP_EVAL_COUNT == 0 and best_eval_count > 0:
                    print('EVAL: early stop triggered')
                    early_stop = True
                    break

                model.train()
                if not cfg.TRAIN.TUNE_RESNET:
                    model.resnet_encoder.eval()
                if not cfg.TRAIN.TUNE_LANG_ENCODER:
                    model.lang_embed.eval()

            step += 1
            epoch_float += epoch_per_step

        epoch += 1

    tb_writer.close()
    if best_metrics is not None:
        exp_name = '%s:%s' % (cfg.OUTPUT_PATH, cfg.EVAL_SPLIT)
        log_to_summary(exp_name, best_metrics[0], best_metrics[1])
    return best_metrics

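# The TripletMatch model above returns two scalar losses: one for negative images and one
# for negative sentences, combined with the configured weights. A hedged sketch of how a
# margin-based loss of this kind is commonly computed is shown below; the real model's
# distance function and feature extractors live inside TripletMatch, so this is only an
# illustration of the loss shape, not the project's implementation.
def _triplet_margin_loss_sketch(pos_dist, neg_dist, margin=1.0):
    """Hinge loss: penalize when the negative pair is not at least `margin` farther than the positive pair."""
    import torch
    return torch.clamp(pos_dist - neg_dist + margin, min=0.0).mean()
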
def caption_visualize(pred_dicts=None, filter=False):
    if pred_dicts is None:
        pred_dicts = np.load('applications/synthetic_imgs/visualizations/results/caption.npy')
    syn_dataset = SyntheticData()
    word_encoder = WordEncoder()
    img_pref = '../modified_imgs/'
    html_str = '''<!DOCTYPE html>
<html lang="en">
<head>
    <title>Caption visualize</title>
    <style>
        .correct { font-weight: bold; }
        .pred_name { color: ROYALBLUE; font-weight: bold; }
        img { width: 3cm }
        table { border-collapse: collapse; }
        tr { border-bottom: 1px solid lightgray; }
    </style>
</head>
<body>
<table>
    <col class="column-one">
    <col class="column-two">
    <tr>
        <th style="text-align: center">Image</th>
        <th>Predicted captions</th>
    </tr>
'''

    for idx in range(len(syn_dataset)):
        pred_caps = [pred_dict[idx][0] for pred_dict in pred_dicts]
        img_i, c1_i, c2_i = syn_dataset.unravel_index(idx)
        if filter:
            good_cap = False
            c1 = syn_dataset.color_names[c1_i]
            c2 = syn_dataset.color_names[c2_i]
            for ci, cap in enumerate(pred_caps):
                if c1 in cap and c2 in cap:
                    good_cap = True
                    break
            if not good_cap:
                continue

        img_name = '%s_%s_%s.jpg' % (syn_dataset.img_names[img_i].split('.')[0],
                                     syn_dataset.color_names[c1_i],
                                     syn_dataset.color_names[c2_i])
        gt_desc = syn_dataset.get_desc(img_i, c1_i, c2_i)
        for ci, cap in enumerate(pred_caps):
            tokens = WordEncoder.tokenize(cap)
            for ti, t in enumerate(tokens):
                if t in gt_desc and len(t) > 1:
                    tokens[ti] = '<span class="correct">%s</span>' % t
            pred_caps[ci] = word_encoder.detokenize(tokens)

        html_str += '''
    <tr>
        <td>
            <img src="{img_pref}{img_name}" alt="{img_name}">
        </td>
        <td>
            <span class="pred_name">Synthetic Ground-truth Description:</span><br>
            {gt}<br>
            <span class="pred_name">Classifier top 5:</span><br>
            {pred0}<br>
            <span class="pred_name">Triplet top 5:</span><br>
            {pred1}<br>
            <span class="pred_name">Show-attend-tell:</span><br>
            {pred2}<br>
        </td>
    </tr>
'''.format(img_pref=img_pref, img_name=img_name,
           pred0=pred_caps[0], pred1=pred_caps[1], pred2=pred_caps[2], gt=gt_desc)

    html_str += '</table>\n</body>\n</html>'
    html_name = 'caption.html'
    if filter:
        html_name = 'caption_filtered.html'
    with open('applications/synthetic_imgs/visualizations/results/' + html_name, 'w') as f:
        f.write(html_str)
    return

class CaptionDataset(Dataset, TextureDescriptionData):
    """
    A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
    """

    def __init__(self, split, transform=None, word_encoder=None, is_train=True, caption_max_len=35):
        """
        :param split: split, one of 'train', 'val', 'test'
        :param transform: image transform pipeline
        """
        TextureDescriptionData.__init__(self, phid_format=None)
        self.transform = transform
        self.is_train = is_train
        self.caption_max_len = caption_max_len
        self.split = split
        assert self.split in ('train', 'val', 'test')

        self.word_encoder = word_encoder
        if self.word_encoder is None:
            self.word_encoder = WordEncoder()

        self.img_desc_ids = list()
        for img_i, img_name in enumerate(self.img_splits[split]):
            desc_num = len(self.img_data_dict[img_name]['descriptions'])
            self.img_desc_ids += [(img_i, desc_i) for desc_i in range(desc_num)]

    def __getitem__(self, i):
        img_i, desc_i = self.img_desc_ids[i]
        img_data = self.get_split_data(split=self.split, img_idx=img_i, load_img=True)
        img = img_data['image']
        if self.transform is not None:
            img = self.transform(img)

        desc = img_data['descriptions'][desc_i]
        caption, caplen = self.word_encoder.encode(lang_input=desc, max_len=self.caption_max_len)
        caplen = torch.as_tensor([caplen], dtype=torch.long)
        caption = torch.as_tensor(caption, dtype=torch.long)

        if self.is_train:
            return img, caption, caplen
        else:
            # For validation or testing, also return all 'captions_per_image' captions
            # to compute the BLEU-4 score
            all_captions = list()
            mlen = 0
            for desc in img_data['descriptions']:
                c, cl = self.word_encoder.encode(lang_input=desc, max_len=self.caption_max_len)
                all_captions.append(c)
                mlen = max(mlen, cl)
            all_captions_np = np.zeros((len(all_captions), mlen))
            for ci, c in enumerate(all_captions):
                cl = min(len(c), mlen)
                all_captions_np[ci, :cl] = c[:cl]
            all_captions = torch.as_tensor(all_captions_np, dtype=torch.long)
            return img, caption, caplen, all_captions

    def __len__(self):
        return len(self.img_desc_ids)

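# Minimal usage sketch (illustrative): in training mode CaptionDataset yields fixed-length
# (img, caption, caplen) tuples, so it can be wrapped directly in a DataLoader with the
# default collate function, as done in main() above. The batch size and transform choice
# here are arbitrary.
def _demo_caption_dataset():
    dataset = CaptionDataset(split='train', is_train=True,
                             transform=build_transforms(is_train=False))
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    img, caption, caplen = next(iter(loader))
    # img: B x 3 x H x W, caption: B x caption_max_len, caplen: B x 1
    return img.shape, caption.shape, caplen.shape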