def __init__(self, split='train', lang_input='phrase', neg_img=True, neg_lang=True):
    data.Dataset.__init__(self)
    TextureDescriptionData.__init__(self, phid_format='str')
    self.split = split
    self.lang_input = lang_input
    self.neg_img = neg_img
    self.neg_lang = neg_lang
    self.img_transform = build_transforms(is_train=False)

    # Collect all positive (image, language) pairs over the chosen split.
    self.pos_pairs = list()
    for img_i, img_name in enumerate(self.img_splits[self.split]):
        img_data = self.img_data_dict[img_name]
        if self.lang_input == 'phrase':
            self.pos_pairs += [(img_i, ph) for ph in img_data['phrase_ids']]
        elif self.lang_input == 'description':
            self.pos_pairs += [(img_i, desc_idx)
                               for desc_idx in range(len(img_data['descriptions']))]
        else:
            raise NotImplementedError
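# Usage sketch (not part of the original source): assuming TripletTrainData is
# the class this __init__ belongs to, the positive pairs can be served through
# a standard DataLoader; the 4-tuple unpacking mirrors the training loop below.
def _example_triplet_loader():
    from torch.utils.data import DataLoader
    dataset = TripletTrainData(split='train', lang_input='phrase',
                               neg_img=True, neg_lang=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
    pos_imgs, pos_langs, neg_imgs, neg_langs = next(iter(loader))
    return pos_imgs.shape  # e.g. torch.Size([32, 3, H, W]) after img_transform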
def __init__(self, model=None, img_transform=None, device='cuda',
             trained_path='output/triplet_match/c34_bert_l2_s_lr0.00001',
             model_file='BEST_checkpoint.pth', split_to_phrases=False, dataset=None):
    if model is None:
        model, device = load_model(trained_path, model_file)
    model.eval()
    self.model = model
    self.device = device

    if img_transform is None:
        img_transform = build_transforms(is_train=False)
    self.img_transform = img_transform

    self.split_to_phrases = split_to_phrases
    if dataset is None:
        dataset = TextureDescriptionData(phid_format=None)
    self.dataset = dataset

    # Pre-compute one vector per phrase if retrieval is done phrase-by-phrase.
    self.ph_vec_dict = None
    if self.split_to_phrases:
        ph_vecs = get_phrase_vecs(self.model, self.dataset)
        self.ph_vec_dict = {dataset.phrases[i]: ph_vecs[i]
                            for i in range(len(dataset.phrases))}
    self.img_vecs = dict()
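# Usage sketch (hypothetical; the retriever's class name and scoring methods
# are not shown here). Construction only touches what this __init__ defines:
# with split_to_phrases=True a phrase -> vector table is precomputed so a
# description can later be scored phrase-by-phrase against cached image vectors.
def _example_build_retriever(RetrieverCls):
    retriever = RetrieverCls(split_to_phrases=True)  # loads BEST_checkpoint.pth
    return len(retriever.ph_vec_dict)  # one vector per phrase in the dataset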
def __init__(self, split='train', is_train=True, cached_resnet_feats=None):
    data.Dataset.__init__(self)
    TextureDescriptionData.__init__(self, phid_format='set')
    self.split = split
    self.is_train = is_train
    self.cached_resnet_feats = cached_resnet_feats
    self.use_cache = (self.cached_resnet_feats is not None
                      and len(self.cached_resnet_feats) > 0)
    self.transform = None
    if not self.use_cache:
        self.transform = build_transforms(is_train)
    print('PhraseClassifyDataset initialized.')
def __init__(self, model=None, img_transform=None, dataset=None, device='cuda'):
    if model is None:
        model, device = load_model()
    self.model = model
    self.device = device
    if img_transform is None:
        img_transform = build_transforms(is_train=False)
    self.img_transform = img_transform
    if dataset is None:
        dataset = TextureDescriptionData(phid_format=None)
    self.dataset = dataset
    self.ph_vecs = get_phrase_vecs(self.model, self.dataset)
def __init__(self, model=None, img_transform=None, device='cuda', dataset=None):
    if dataset is None:
        dataset = TextureDescriptionData(phid_format=None)
    self.dataset = dataset
    if model is None:
        model, device = load_model(dataset=self.dataset)
    model.eval()
    self.model = model
    self.device = device
    if img_transform is None:
        img_transform = build_transforms(is_train=False)
    self.img_transform = img_transform
    self.img_ph_scores = dict()
def __init__(self):
    self.img_names = [
        'cobwebbed/cobwebbed_0088.jpg', 'lined/lined_0084.jpg',
        'polka-dotted/polka-dotted_0215.jpg', 'swirly/swirly_0135.jpg',
        'chequered/chequered_0103.jpg', 'honeycombed/honeycombed_0059.jpg',
        'dotted/dotted_0131.jpg', 'striped/striped_0035.jpg',
        'zigzagged/zigzagged_0064.jpg', 'banded/banded_0138.jpg'
    ]
    self.colors = {
        'white': (245, 245, 230), 'black': (20, 20, 20),
        'brown': (120, 70, 20), 'green': (30, 160, 30),
        'blue': (30, 100, 200), 'red': (120, 30, 30),
        'yellow': (240, 240, 30), 'pink': (240, 140, 190),
        'orange': (240, 140, 20), 'gray': (100, 100, 100),
        'purple': (140, 30, 200),
        # 'silver': (192, 192, 192),
    }
    self.color_names = list(self.colors.keys())
    # Parallel to img_names/patterns: whether each pattern distinguishes
    # a foreground and a background color.
    self.is_fore_back = [True, True, True, True, False,
                         False, True, False, False, False]
    self.patterns = ['web', 'lines', 'polka-dots', 'swirls', 'squares',
                     'hexagon', 'dots', 'stripes', 'zigzagged', 'banded']
    # All ordered pairs of distinct colors.
    self.color_tuples = list()
    for i in range(len(self.color_names)):
        for j in range(len(self.color_names)):
            if i != j:
                self.color_tuples.append((i, j))
    self.img_transform = build_transforms(is_train=False)
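# Usage sketch (assumed semantics, not from the source): 'is_fore_back' appears
# to flag patterns with distinct foreground/background colors; the actual
# description templates live elsewhere, so the strings below are illustrative.
def _example_color_pattern_pairs(gen):
    # 'gen' is an instance of the class above: 11 colors -> 110 ordered pairs.
    i, j = gen.color_tuples[0]
    two_color = '%s %s on %s' % (gen.color_names[i], gen.patterns[0],
                                 gen.color_names[j])  # is_fore_back[0] is True
    one_color = '%s %s' % (gen.color_names[i], gen.patterns[4])
    return two_color, one_color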
def main():
    """
    Training and validation.
    """
    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, \
        fine_tune_encoder, word_encoder

    word_encoder = WordEncoder()

    # Initialize / load checkpoint
    if checkpoint is None:
        print("Starting training from scratch!")
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_encoder.word_list),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, decoder.parameters()),
            lr=decoder_lr)
        encoder: Encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None
        if init_word_emb == 'fast_text':
            word_emb = get_word_embed(word_encoder.word_list, 'fast_text')
            decoder.embedding.weight.data.copy_(torch.from_numpy(word_emb))
    else:
        print("Loading from checkpoint " + checkpoint)
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(split='train', is_train=True, word_encoder=word_encoder,
                       transform=build_transforms(is_train=fine_tune_encoder)),
        batch_size=batch_size, shuffle=True, num_workers=workers,
        pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(split='val', is_train=False, word_encoder=word_encoder,
                       transform=build_transforms(is_train=False)),
        batch_size=batch_size, shuffle=True, num_workers=workers,
        pin_memory=True, collate_fn=caption_collate)
    print('ready to train.')

    # Epochs
    for epoch in range(start_epoch, epochs):
        # Decay the learning rate if there is no improvement for 8 consecutive
        # epochs; terminate training after 20 epochs without improvement.
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader, encoder=encoder, decoder=decoder,
              criterion=criterion, encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer, epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader, encoder=encoder,
                                decoder=decoder, criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % epochs_since_improvement)
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_dir = 'output/show_attend_tell/checkpoints'
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        save_checkpoint(checkpoint_base_name, epoch, epochs_since_improvement,
                        encoder, decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best, save_dir=save_dir)
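# The captioning trainer above follows the show-attend-tell recipe and reads
# its hyperparameters from module-level globals. An illustrative set of the
# definitions it needs to run (values assumed, not taken from the source):
attention_dim = 512        # size of the attention network
emb_dim = 512              # word embedding size
decoder_dim = 512          # decoder LSTM size
dropout = 0.5
decoder_lr = 4e-4
encoder_lr = 1e-4
fine_tune_encoder = False
init_word_emb = 'fast_text'
batch_size = 32
workers = 4
epochs = 120
start_epoch = 0
best_bleu4 = 0.0
epochs_since_improvement = 0
checkpoint = None          # path to a checkpoint file, or None to train fresh
checkpoint_base_name = 'show_attend_tell'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')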
def __init__(self, split=None, data_path='other_datasets/CUB_200_2011', val_ratio=0.1):
    self.data_path = data_path
    self.split = split
    self.val_ratio = val_ratio
    if split not in ['train', 'val', 'test']:
        self.split = None
    self.img_transform = build_transforms(is_train=False)

    # Official train/test split; all ids are converted to be 0-based.
    self.img_splits = {'train': list(), 'test': list()}
    with open(os.path.join(data_path, 'train_test_split.txt'), 'r') as f:
        for line in f:
            img_id, is_train = line.split(' ')
            img_id = int(img_id.strip()) - 1
            is_train = int(is_train.strip())
            if is_train:
                self.img_splits['train'].append(img_id)
            else:
                self.img_splits['test'].append(img_id)

    # Hold out a fixed random validation subset from the training split.
    if val_ratio > 0:
        val_len = int(len(self.img_splits['train']) * val_ratio)
        np.random.seed(0)
        self.img_splits['val'] = np.random.choice(self.img_splits['train'],
                                                  val_len, replace=False)
        train_ids = [i for i in self.img_splits['train']
                     if i not in self.img_splits['val']]
        self.img_splits['train'] = train_ids

    self.class_names = []
    with open(os.path.join(data_path, 'classes.txt'), 'r') as f:
        for li, line in enumerate(f):
            cls_id, name = line.split(' ')
            cls_id = int(cls_id.strip()) - 1
            assert cls_id == li
            self.class_names.append(name.strip())

    self.att_names = []
    with open(os.path.join(data_path, 'attributes/attributes.txt'), 'r') as f:
        for li, line in enumerate(f):
            att_id, name = line.split(' ')
            att_id = int(att_id.strip()) - 1
            assert att_id == li
            self.att_names.append(name.strip())

    # Group attribute indices by coarse type.
    self.att_types = dict()
    self.att_types['shape'] = [i for i, n in enumerate(self.att_names)
                               if '_shape:' in n or 'length:' in n or 'size:' in n]
    for att_type in ['color', 'pattern']:
        self.att_types[att_type] = [i for i, n in enumerate(self.att_names)
                                    if '_%s:' % att_type in n]

    self.img_data_list = []
    with open(os.path.join(data_path, 'images.txt'), 'r') as f:
        for li, line in enumerate(f):
            img_id, name = line.split(' ')
            img_id = int(img_id.strip()) - 1
            assert img_id == li
            self.img_data_list.append({'img_name': name.strip()})

    with open(os.path.join(data_path, 'image_class_labels.txt'), 'r') as f:
        for line in f:
            img_id, cls = line.split(' ')
            img_id = int(img_id.strip()) - 1
            cls = int(cls.strip()) - 1
            self.img_data_list[img_id]['class_label'] = cls

    with open(os.path.join(data_path, 'bounding_boxes.txt'), 'r') as f:
        for line in f:
            numbers = [float(n.strip()) for n in line.split(' ')]
            img_id = int(numbers[0]) - 1
            self.img_data_list[img_id]['box'] = numbers[1:]  # x, y, w, h

    # Per-image binary attribute labels.
    for img_data in self.img_data_list:
        img_data['att_labels'] = np.zeros(len(self.att_names))
    self.gt_att_labels = np.zeros((len(self.img_data_list), len(self.att_names)))
    with open(os.path.join(data_path, 'attributes/image_attribute_labels.txt'),
              'r') as f:
        # Each line: <image_id> <attribute_id> <is_present> <certainty_id> <time>
        for line in f:
            numbers = [int(n.strip()) for n in line.split(' ')[:3]]
            img_id = numbers[0] - 1
            att_id = numbers[1] - 1
            is_present = numbers[2]
            if is_present:
                self.img_data_list[img_id]['att_labels'][att_id] = 1
                self.gt_att_labels[img_id, att_id] = 1
    # TODO non-localized

    # Class-level continuous attribute labels: 200 classes x 312 attributes.
    self.class_att_labels = np.zeros((200, 312))
    with open(os.path.join(data_path,
                           'attributes/class_attribute_labels_continuous.txt'),
              'r') as f:
        # 200 lines and 312 space-separated columns
        for li, line in enumerate(f):
            self.class_att_labels[li] = np.array(
                [float(n.strip()) for n in line.split(' ')])

    print('CUB dataset ready.')
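# Usage sketch: querying the structures built above for one image. Everything
# referenced here is defined in __init__; the dataset class name itself and
# the helper's name are assumptions.
def _example_cub_lookup(cub):
    img = cub.img_data_list[0]
    cls_name = cub.class_names[img['class_label']]
    present = [cub.att_names[i]
               for i, v in enumerate(img['att_labels']) if v]
    color_atts = [cub.att_names[i] for i in cub.att_types['color']]
    return cls_name, present[:5], len(color_atts)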
def predict(beam_size, img_dataset=None, split=split):
    # Note: the default for 'split' binds the module-level 'split' at
    # definition time.
    # DataLoader
    if img_dataset is None:
        img_dataset = ImgOnlyDataset(split=split,
                                     transform=build_transforms(is_train=False))
    loader = torch.utils.data.DataLoader(img_dataset, batch_size=1, shuffle=False,
                                         num_workers=1, pin_memory=True)
    # TODO: Batched beam search.
    # Until then, do not use a batch_size greater than 1 - IMPORTANT!

    # predictions maps each image name to its beam_size decoded captions.
    predictions = dict()

    with torch.no_grad():
        for bi, (img_name, image) in enumerate(
                tqdm(loader, desc="PREDICTING AT BEAM SIZE " + str(beam_size))):
            img_name = img_name[0]
            k = beam_size

            # Move to GPU device, if available
            image = image.to(device)  # (1, 3, 256, 256)

            # Encode
            encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
            enc_image_size = encoder_out.size(1)
            encoder_dim = encoder_out.size(3)

            # Flatten encoding
            encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
            num_pixels = encoder_out.size(1)

            # We'll treat the problem as having a batch size of k
            encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

            # Tensor to store top k previous words at each step; now they're just <start>
            k_prev_words = torch.as_tensor(
                [[word_encoder.word_map['<start>']]] * k,
                dtype=torch.long).to(device)  # (k, 1)

            # Tensor to store top k sequences; now they're just <start>
            seqs = k_prev_words  # (k, 1)

            # Tensor to store top k sequences' scores; now they're just 0
            top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

            # Lists to store completed sequences and scores
            complete_seqs = list()
            complete_seqs_scores = list()

            # Start decoding
            step = 1
            h, c = decoder.init_hidden_state(encoder_out)

            # s <= k: sequences leave the beam once they hit <end>
            while True:
                embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)
                awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)
                gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
                awe = gate * awe
                h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1),
                                           (h, c))  # (s, decoder_dim)
                scores = decoder.fc(h)  # (s, vocab_size)
                scores = F.log_softmax(scores, dim=1)

                # Add cumulative beam scores
                scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

                # For the first step, all k points will have the same scores
                # (since same k previous words, h, c)
                if step == 1:
                    top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
                else:
                    # Unroll and find top scores, and their unrolled indices
                    top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

                # Convert unrolled indices to actual indices of scores.
                # Integer division is required here: true division ('/') yields
                # float indices and fails on recent PyTorch versions.
                vocab_size = len(word_encoder.word_list)
                prev_word_inds = top_k_words // vocab_size  # (s)
                next_word_inds = top_k_words % vocab_size  # (s)

                # Add new words to sequences
                seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)],
                                 dim=1)  # (s, step+1)

                # Which sequences are incomplete (didn't reach <end>)?
                incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds)
                                   if next_word != word_encoder.word_map['<end>']]
                complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

                # Set aside complete sequences
                if len(complete_inds) > 0:
                    complete_seqs.extend(seqs[complete_inds].tolist())
                    complete_seqs_scores.extend(top_k_scores[complete_inds].cpu().numpy())

                # Proceed with incomplete sequences
                if len(incomplete_inds) == 0:
                    break
                seqs = seqs[incomplete_inds]
                h = h[prev_word_inds[incomplete_inds]]
                c = c[prev_word_inds[incomplete_inds]]
                encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
                top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
                k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

                # Break if things have been going on too long
                if step > 50:
                    break
                step += 1

            # If nothing reached <end>, fall back to the live beam.
            if len(complete_seqs_scores) == 0:
                complete_seqs_scores = top_k_scores.squeeze().cpu().numpy()
                complete_seqs = seqs.cpu().tolist()

            # Keep the top beam_size sequences, highest-scoring first.
            sorted_idxs = np.argsort(np.asarray(complete_seqs_scores) * -1.0)
            if len(sorted_idxs) > beam_size:
                sorted_idxs = sorted_idxs[:beam_size]
            sorted_seqs = [complete_seqs[i] for i in sorted_idxs]

            ignore_wids = [word_encoder.word_map[w]
                           for w in ['<start>', '<end>', '<pad>']]
            if bi == 0:  # debug: print the decoded beam for the first image
                print('top k:')
                for i, idx in enumerate(sorted_idxs):
                    seq = sorted_seqs[i]
                    tokens = [word_encoder.word_list[wid] for wid in seq
                              if wid not in ignore_wids]
                    caption = word_encoder.detokenize(tokens)
                    print(caption, complete_seqs_scores[idx])

            predictions[img_name] = list()
            for seq in sorted_seqs:
                tokens = [word_encoder.word_list[wid] for wid in seq
                          if wid not in ignore_wids]
                caption = word_encoder.detokenize(tokens)
                predictions[img_name].append(caption)
            # Pad with empty strings so every image has exactly beam_size captions.
            if len(predictions[img_name]) < beam_size:
                predictions[img_name] += [''] * (beam_size - len(predictions[img_name]))

    return predictions
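# Usage sketch (depends on the script's globals: encoder, decoder, word_encoder,
# device, split): each image name maps to exactly beam_size caption strings,
# padded with '' when the beam produces fewer.
def _example_predict():
    predictions = predict(beam_size=5)
    img_name, captions = next(iter(predictions.items()))
    assert len(captions) == 5
    return img_name, captions[0]  # highest-scoring caption first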
def train():
    from torch.utils.tensorboard import SummaryWriter

    # load configs
    parser = argparse.ArgumentParser(description="Triplet Matching Training")
    parser.add_argument('-c', '--config_file', default=None,
                        help="path to config file")
    parser.add_argument('-o', '--opts', default=None, nargs=argparse.REMAINDER,
                        help="Modify config options using the command-line. "
                             "E.g. TRAIN.INIT_LR 0.01")
    args = parser.parse_args()
    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    if args.opts is not None:
        cfg.merge_from_list(args.opts)
    prepare(cfg)
    cfg.freeze()
    print(cfg.dump())
    if not os.path.exists(cfg.OUTPUT_PATH):
        os.makedirs(cfg.OUTPUT_PATH)
    with open(os.path.join(cfg.OUTPUT_PATH, 'train.yml'), 'w') as f:
        f.write(cfg.dump())

    # set random seed
    torch.manual_seed(cfg.RAND_SEED)
    np.random.seed(cfg.RAND_SEED)
    random.seed(cfg.RAND_SEED)

    # make data_loader, model, criterion, optimizer
    dataset = TripletTrainData(split=cfg.TRAIN_SPLIT,
                               neg_img=cfg.LOSS.IMG_SENT_WEIGHTS[0] > 0,
                               neg_lang=cfg.LOSS.IMG_SENT_WEIGHTS[1] > 0,
                               lang_input=cfg.LANG_INPUT)
    data_loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
                             shuffle=True, drop_last=True, pin_memory=True)

    img_dataset = ImgOnlyDataset(split=cfg.EVAL_SPLIT,
                                 transform=build_transforms(is_train=False),
                                 texture_dataset=dataset)
    eval_img_dataloader = DataLoader(img_dataset, batch_size=1, shuffle=False)
    phrase_dataset = PhraseOnlyDataset(texture_dataset=dataset)
    eval_phrase_dataloader = DataLoader(phrase_dataset, batch_size=32, shuffle=False)

    word_encoder = WordEncoder()
    model: TripletMatch = TripletMatch(vec_dim=cfg.MODEL.VEC_DIM,
                                       neg_margin=cfg.LOSS.MARGIN,
                                       distance=cfg.MODEL.DISTANCE,
                                       img_feats=cfg.MODEL.IMG_FEATS,
                                       lang_encoder_method=cfg.MODEL.LANG_ENCODER,
                                       word_encoder=word_encoder)
    if cfg.INIT_WORD_EMBED != 'rand' and cfg.MODEL.LANG_ENCODER in ['mean', 'lstm']:
        word_emb = get_word_embed(word_encoder.word_list, cfg.INIT_WORD_EMBED)
        model.lang_embed.embeds.weight.data.copy_(torch.from_numpy(word_emb))
    if len(cfg.LOAD_WEIGHTS) > 0:
        model.load_state_dict(torch.load(cfg.LOAD_WEIGHTS))
    model.train()
    device = torch.device(cfg.DEVICE)
    model.to(device)

    if not cfg.TRAIN.TUNE_RESNET:
        # Freeze parameters (setting .requires_grad on the Module itself is a no-op).
        for p in model.resnet_encoder.parameters():
            p.requires_grad = False
        model.resnet_encoder.eval()
    if not cfg.TRAIN.TUNE_LANG_ENCODER:
        for p in model.lang_embed.parameters():
            p.requires_grad = False
        model.lang_embed.eval()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=cfg.TRAIN.INIT_LR,
                                 weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                 betas=(cfg.TRAIN.ADAM.ALPHA, cfg.TRAIN.ADAM.BETA),
                                 eps=cfg.TRAIN.ADAM.EPSILON)

    # make tensorboard writer and dirs
    checkpoint_dir = os.path.join(cfg.OUTPUT_PATH, 'checkpoints')
    vis_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_%s_LAST' % cfg.EVAL_SPLIT)
    best_vis_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_%s_BEST' % cfg.EVAL_SPLIT)
    tb_dir = os.path.join(cfg.OUTPUT_PATH, 'tensorboard')
    tb_writer = SummaryWriter(tb_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if not os.path.exists(tb_dir):
        os.makedirs(tb_dir)

    # training loop
    step = 1
    epoch = 1
    epoch_float = 0.0
    epoch_per_step = cfg.TRAIN.BATCH_SIZE * 1.0 / len(dataset)
    best_eval_metric = 0
    best_metrics = None
    best_eval_count = 0
    early_stop = False
    while epoch <= cfg.TRAIN.MAX_EPOCH and not early_stop:
        for pos_imgs, pos_langs, neg_imgs, neg_langs in data_loader:
            pos_imgs = pos_imgs.to(device)
            if neg_imgs is not None and neg_imgs[0] is not None:
                neg_imgs = neg_imgs.to(device)

            verbose = step <= 5 or step % 50 == 0
            neg_img_loss, neg_lang_loss = model(pos_imgs, pos_langs, neg_imgs,
                                                neg_langs, verbose=verbose)
            loss = (cfg.LOSS.IMG_SENT_WEIGHTS[0] * neg_img_loss
                    + cfg.LOSS.IMG_SENT_WEIGHTS[1] * neg_lang_loss)
            loss /= sum(cfg.LOSS.IMG_SENT_WEIGHTS)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            tb_writer.add_scalar('train/loss', loss, step)
            tb_writer.add_scalar('train/neg_img_loss', neg_img_loss, step)
            tb_writer.add_scalar('train/neg_lang_loss', neg_lang_loss, step)
            tb_writer.add_scalar('train/lr', lr, step)
            if verbose:
                print('[%s] epoch-%d step-%d: loss %.4f (neg_img: %.4f, neg_lang: %.4f); lr %.1E'
                      % (time.strftime('%m/%d %H:%M:%S'), epoch, step, loss,
                         neg_img_loss, neg_lang_loss, lr))

            # periodic evaluation
            if epoch_float % cfg.TRAIN.EVAL_EVERY_EPOCH < epoch_per_step and epoch_float > 0:
                p2i_result, i2p_result = do_eval(model, eval_img_dataloader,
                                                 eval_phrase_dataloader, device,
                                                 split=cfg.EVAL_SPLIT,
                                                 visualize_path=vis_path)
                for m, v in p2i_result.items():
                    tb_writer.add_scalar('eval_p2i/%s' % m, v, step)
                for m, v in i2p_result.items():
                    tb_writer.add_scalar('eval_i2p/%s' % m, v, step)

                eval_metric = (p2i_result['mean_average_precision']
                               + i2p_result['mean_average_precision'])
                if eval_metric > best_eval_metric:
                    print('EVAL: new best!')
                    best_eval_metric = eval_metric
                    best_metrics = (p2i_result, i2p_result)
                    best_eval_count = 0
                    copy_tree(vis_path, best_vis_path, update=1)
                    with open(os.path.join(checkpoint_dir, 'epoch_step.txt'), 'w') as f:
                        f.write('BEST: epoch {}, step {}\n'.format(epoch, step))
                    torch.save(model.state_dict(),
                               os.path.join(checkpoint_dir, 'BEST_checkpoint.pth'))
                else:
                    best_eval_count += 1
                    print('EVAL: since last best: %d' % best_eval_count)

                if epoch_float % cfg.TRAIN.CHECKPOINT_EVERY_EPOCH < epoch_per_step and epoch_float > 0:
                    with open(os.path.join(checkpoint_dir, 'epoch_step.txt'), 'a') as f:
                        f.write('LAST: epoch {}, step {}\n'.format(epoch, step))
                    torch.save(model.state_dict(),
                               os.path.join(checkpoint_dir, 'LAST_checkpoint.pth'))

                if best_eval_count % cfg.TRAIN.LR_DECAY_EVAL_COUNT == 0 and best_eval_count > 0:
                    print('EVAL: lr decay triggered')
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= cfg.TRAIN.LR_DECAY_GAMMA

                if best_eval_count % cfg.TRAIN.EARLY_STOP_EVAL_COUNT == 0 and best_eval_count > 0:
                    print('EVAL: early stop triggered')
                    early_stop = True
                    break

                # do_eval puts the model in eval mode; restore training mode,
                # keeping frozen submodules in eval mode.
                model.train()
                if not cfg.TRAIN.TUNE_RESNET:
                    model.resnet_encoder.eval()
                if not cfg.TRAIN.TUNE_LANG_ENCODER:
                    model.lang_embed.eval()

            step += 1
            epoch_float += epoch_per_step
        epoch += 1

    tb_writer.close()
    if best_metrics is not None:
        exp_name = '%s:%s' % (cfg.OUTPUT_PATH, cfg.EVAL_SPLIT)
        log_to_summary(exp_name, best_metrics[0], best_metrics[1])
    return best_metrics
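# Usage sketch: train() parses sys.argv itself, so it is normally driven from
# the command line (the config path below is hypothetical):
#   python train.py -c configs/triplet_match.yml -o TRAIN.INIT_LR 0.0001
# A programmatic equivalent for a quick run:
def _example_train_cli():
    import sys
    sys.argv = ['train.py', '-o', 'TRAIN.INIT_LR', '0.0001', 'EVAL_SPLIT', 'val']
    return train()  # returns (p2i_result, i2p_result) of the best checkpoint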
def main_eval():
    # load configs
    parser = argparse.ArgumentParser(description="Triplet (phrase) retrieval evaluation")
    parser.add_argument('-p', '--trained_path',
                        help="path to the trained model (where the cfg file lives)",
                        default='output/triplet_match/c34_bert_l2_s_lr0.00001')
    parser.add_argument('-m', '--model_file', help='file name of the cached model',
                        default='BEST_checkpoint.pth')
    parser.add_argument('-o', '--opts', default=None, nargs=argparse.REMAINDER,
                        help="e.g. EVAL_SPLIT test")
    args = parser.parse_args()
    cfg.merge_from_file(os.path.join(args.trained_path, 'train.yml'))
    if args.opts is not None:
        cfg.merge_from_list(args.opts)

    # set random seed
    torch.manual_seed(cfg.RAND_SEED)
    np.random.seed(cfg.RAND_SEED)
    random.seed(cfg.RAND_SEED)

    dataset = TextureDescriptionData(phid_format=None)
    img_dataset = ImgOnlyDataset(split=cfg.EVAL_SPLIT,
                                 transform=build_transforms(is_train=False),
                                 texture_dataset=dataset)
    img_dataloader = DataLoader(img_dataset, batch_size=1, shuffle=False)
    phrase_dataset = PhraseOnlyDataset(texture_dataset=dataset)
    phrase_dataloader = DataLoader(phrase_dataset, batch_size=32, shuffle=False)

    model: TripletMatch = TripletMatch(vec_dim=cfg.MODEL.VEC_DIM,
                                       neg_margin=cfg.LOSS.MARGIN,
                                       distance=cfg.MODEL.DISTANCE,
                                       img_feats=cfg.MODEL.IMG_FEATS,
                                       lang_encoder_method=cfg.MODEL.LANG_ENCODER)
    model_path = os.path.join(args.trained_path, 'checkpoints', args.model_file)
    model.load_state_dict(torch.load(model_path))
    device = torch.device(cfg.DEVICE)
    model.to(device)

    do_eval(model, img_dataloader, phrase_dataloader, device, split=cfg.EVAL_SPLIT,
            visualize_path=os.path.join(args.trained_path,
                                        'eval_visualize_%s' % cfg.EVAL_SPLIT),
            add_to_summary_name=model_path + ':' + cfg.EVAL_SPLIT)