def main():
    """Evaluate a saved SCAN checkpoint on one data split.

    Command-line options:
      --data_path   dataset root; overrides the value stored in the checkpoint
      --model_path  checkpoint file to load (must contain 'opt' and 'model')
      --split       split name forwarded to the test loader ('val' / 'test')
      --gpuid       value written to CUDA_VISIBLE_DEVICES
      --fold5       forwarded to evaluation.evalrank as ``fold5``
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='./data/',
                        help='path to datasets')
    parser.add_argument('--model_path',
                        default='./data/',
                        help='path to model')
    parser.add_argument('--split', default='test', help='val/test')
    # The default must be a string: argparse only applies `type` to string
    # defaults, so a float default (0.) reached CUDA_VISIBLE_DEVICES as "0.0",
    # which is not a valid device list.
    parser.add_argument('--gpuid', default='0', type=str, help='gpuid')
    parser.add_argument('--fold5', action='store_true', help='fold5')
    opts = parser.parse_args()

    device_id = opts.gpuid
    print("use GPU:", device_id)
    # After restricting visibility, the chosen GPU is device 0 in-process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    torch.cuda.set_device(0)

    # load model and options; command-line values override the saved ones
    checkpoint = torch.load(opts.model_path)
    opt = checkpoint['opt']
    opt.loss_verbose = False
    opt.split = opts.split
    opt.data_path = opts.data_path
    opt.fold5 = opts.fold5

    # load vocabulary used by the model
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # construct model
    model = SCAN(opt)
    model.cuda()
    # Wrapped before loading; presumably the checkpoint keys carry the
    # DataParallel 'module.' prefix — confirm against the training script.
    model = nn.DataParallel(model)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = data.get_test_loader(opt.split, opt.data_name, vocab,
                                       opt.batch_size, opt.workers, opt)

    print(opt)
    print('Computing results...')

    # Evaluate the unwrapped module.
    evaluation.evalrank(model.module,
                        data_loader,
                        opt,
                        split=opt.split,
                        fold5=opt.fold5)
def load_model(model_path, device):
    """Rebuild a SCAN model from a checkpoint file.

    Args:
        model_path: path to a checkpoint holding 'opt' (the stored options)
            and 'model' (the state dict).
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(model, opt)``: the model with its weights loaded, and the stored
        options. ``opt.div_transform`` is set to False only when the
        checkpoint's options lack it.
    """
    # load model and options
    checkpoint = torch.load(model_path, map_location=device)
    opt = checkpoint['opt']

    # Checkpoints saved before the `div_transform` option existed do not carry
    # it, so default it to False. Do not overwrite a value the checkpoint
    # does record; forcing False there could build a model inconsistent
    # with the saved weights.
    if not hasattr(opt, 'div_transform'):
        opt.div_transform = False

    # construct model
    model = SCAN(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])
    return model, opt
Exemple #3
0
def load_model(model_path, device):
    """Restore a SCAN model together with the options it was trained with.

    ``device`` is forwarded to ``torch.load`` as ``map_location``. The stored
    options are used as-is. Older checkpoints may lack options such as
    layernorm, div_transform, net, txt_enc or diversity_loss; defaults for
    those were once patched in here. Re-add them if an old checkpoint fails
    to build.

    Returns:
        ``(model, opt)``: the weight-loaded model and the stored options.
    """
    saved = torch.load(model_path, map_location=device)
    options = saved['opt']

    net = SCAN(options)
    net.load_state_dict(saved['model'])
    return net, options
def main():
    """Train SCAN with the options returned by ``opts.parse_opt()``.

    Steps:
      * restrict CUDA to ``opt.gpuid``;
      * load the vocabulary and data loaders;
      * wrap the model in ``nn.DataParallel``;
      * optionally resume from ``opt.resume``;
      * train from ``start_epoch`` to ``opt.num_epochs``.

    ``evalrank`` runs before training, every ``opt.val_step`` batches, and at
    the end of each epoch. A checkpoint is saved after every epoch under
    ``opt.model_name``.
    """
    opt = opts.parse_opt()

    device_id = opt.gpuid
    # gpuid may be a comma-separated list such as "0,1".
    device_count = len(str(device_id).split(","))
    print("use GPU:", device_id, "GPUs_count", device_count, flush=True)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    # After restricting visibility, the first chosen GPU is device 0.
    torch.cuda.set_device(0)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(
        opt.data_name, vocab, opt.batch_size, opt.workers, opt)

    # Construct the model
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)

    # Loss and Optimizer
    criterion = ContrastiveLoss(opt=opt, margin=opt.margin, max_violation=opt.max_violation)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)

    os.makedirs(opt.model_name, exist_ok=True)
    start_epoch = 0
    best_rsum = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
    # Baseline evaluation before any (further) training.
    evalrank(model.module, val_loader, opt)

    print(opt, flush=True)

    log_every = 100  # batches between training-progress prints

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        message = "epoch: %d, model name: %s\n" % (epoch, opt.model_name)
        log_file = os.path.join(opt.logger_name, "performance.log")
        logging_func(log_file, message)
        print("model name: ", opt.model_name, flush=True)
        adjust_learning_rate(opt, optimizer, epoch)
        run_time = 0.0
        timed_batches = 0
        for i, (images, captions, lengths, masks, ids, _) in enumerate(train_loader):
            start_time = time.time()
            model.train()

            optimizer.zero_grad()

            if device_count != 1:
                # NOTE(review): images are tiled device_count times before
                # DataParallel scatters them; presumably so every replica gets
                # the whole image batch — confirm against SCAN.forward.
                images = images.repeat(device_count, 1, 1)

            score = model(images, captions, lengths, masks, ids)
            loss = criterion(score)

            loss.backward()
            if opt.grad_clip > 0:
                clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            run_time += time.time() - start_time
            timed_batches += 1

            if i % log_every == 0:
                # Average over the batches actually timed since the last
                # print (only one at i == 0), not a fixed 100.
                log = "epoch: %d; batch: %d/%d; loss: %.4f; time: %.4f" % (
                    epoch, i, len(train_loader), loss.item(),
                    run_time / timed_batches)
                print(log, flush=True)
                run_time = 0.0
                timed_batches = 0
            # validate at every val_step
            if (i + 1) % opt.val_step == 0:
                evalrank(model.module, val_loader, opt)

        print("-------- performance at epoch: %d --------" % (epoch))
        # evaluate on validation set
        rsum = evalrank(model.module, val_loader, opt)
        filename = 'model_' + str(epoch) + '.pth.tar'
        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
        }, is_best, filename=filename, prefix=opt.model_name + '/')
Exemple #5
0
    def train(self, resume_path='models/model_weights_5.t7'):
        """Train SCAN with a margin loss and hard negative mining.

        For every batch of ``self.data_loader.training_data_loader`` the
        negatives are re-sampled with ``hard_negative_mining``. Three pairs
        are then scored (positive, negative caption, negative image) and Adam
        is stepped on ``MarginLoss``.

        After each epoch the weights are saved to
        ``<model_dir>/model_weights_<epoch>.t7``. Every ``validate_every``
        epochs R@1/5/10 are computed; when R@1 improves, the weights are also
        saved as ``<model_dir>/best_model_weights_<epoch>.t7``. A Ctrl-C
        saves ``<model_dir>/model_weights_interrupt.t7``.

        Args:
            resume_path: checkpoint whose state dict is loaded after Xavier
                initialisation. Defaults to the previously hard-coded path;
                pass a falsy value to start from the fresh initialisation.
        """
        import os

        model = SCAN(self.params)
        model.apply(init_xavier)
        if resume_path:
            # The model is still on the CPU here, so load the tensors there;
            # this also lets a GPU-saved checkpoint load without CUDA.
            model.load_state_dict(torch.load(resume_path, map_location='cpu'))
        loss_function = MarginLoss(self.params.margin)
        if torch.cuda.is_available():
            model = model.cuda()
            loss_function = loss_function.cuda()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.params.lr,
                                     weight_decay=self.params.wdecay)
        model_dir = self.params.model_dir
        try:
            prev_best = 0
            for epoch in range(self.params.num_epochs):
                losses = []
                start_time = timer()
                for (caption, mask, image, neg_cap, neg_mask,
                     neg_image) in tqdm(self.data_loader.training_data_loader):

                    # Sample according to hard negative mining
                    (caption, mask, image, neg_cap, neg_mask,
                     neg_image) = self.data_loader.hard_negative_mining(
                         model, caption, mask, image, neg_cap, neg_mask,
                         neg_image)
                    model.train()
                    optimizer.zero_grad()
                    # Positive pair, then a negative caption with the positive
                    # image, then the positive caption with a negative image.
                    similarity = model(to_variable(caption), to_variable(mask),
                                       to_variable(image), False)
                    similarity_neg_1 = model(to_variable(neg_cap),
                                             to_variable(neg_mask),
                                             to_variable(image), False)
                    similarity_neg_2 = model(to_variable(caption),
                                             to_variable(mask),
                                             to_variable(neg_image), False)

                    loss = loss_function(similarity, similarity_neg_1,
                                         similarity_neg_2)
                    loss.backward()
                    # backward() without arguments needs a scalar loss, so
                    # .item() is safe.
                    losses.append(loss.item())
                    if self.params.clip_value > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.params.clip_value)
                    optimizer.step()

                # Parenthesised on purpose: `epoch + 1 % step_size` parses as
                # `epoch + (1 % step_size)` and never triggered.
                if (epoch + 1) % self.params.step_size == 0:
                    # NOTE(review): the LR is divided by gamma, so it only
                    # decays when gamma > 1 — confirm the configured value.
                    for group in optimizer.param_groups:
                        group['lr'] = group['lr'] / self.params.gamma

                torch.save(
                    model.state_dict(),
                    os.path.join(model_dir,
                                 'model_weights_{}.t7'.format(epoch + 1)))

                # float() replaces np.asscalar, which NumPy 1.23 removed.
                mean_loss = float(np.mean(losses))

                # Calculate r@k after each validation interval
                if (epoch + 1) % self.params.validate_every == 0:
                    r_at_1, r_at_5, r_at_10 = self.evaluator.recall(
                        model, is_test=False)

                    print(
                        "Epoch {} : Training Loss: {:.5f}, R@1 : {}, R@5 : {}, R@10 : {}, Time elapsed {:.2f} mins"
                        .format(epoch + 1, mean_loss,
                                r_at_1, r_at_5, r_at_10,
                                (timer() - start_time) / 60))
                    if r_at_1 > prev_best:
                        print("Recall at 1 increased....saving weights !!")
                        prev_best = r_at_1
                        # os.path.join supplies the separator the old string
                        # concatenation was missing.
                        torch.save(
                            model.state_dict(),
                            os.path.join(
                                model_dir,
                                'best_model_weights_{}.t7'.format(epoch + 1)))
                else:
                    print("Epoch {} : Training Loss: {:.5f}".format(
                        epoch + 1, mean_loss))
        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(model.state_dict(),
                       os.path.join(model_dir, 'model_weights_interrupt.t7'))
Exemple #6
0
def evalrank(input_string,
             img_feature,
             how_many,
             model_path,
             data_path=None,
             split='dev',
             fold5=False,
             gpu_num=None,
             vocab_path='/scan/SCAN/data/vocab'):
    """Rank precomputed image embeddings against a free-text query.

    Loads the SCAN checkpoint at ``model_path``. The query is lower-cased,
    tokenized with NLTK, wrapped in ``<start>``/``<end>`` and repeated
    ``opt.batch_size`` times to fill one batch. It is then scored against
    ``img_feature`` using the checkpoint's ``opt.cross_attn`` ('t2i' or
    'i2t').

    Args:
        input_string: query text.
        img_feature: precomputed image embeddings; must expose ``.shape``.
        how_many: number of top-ranked rows of the similarity matrix to keep
            (``fold5=False`` path; 0 yields an empty list).
        model_path: checkpoint containing 'opt' and 'model'.
        data_path: overrides ``opt.data_path`` when given.
        split: accepted for call compatibility; unused.
        fold5: if True, score the images in 10 shards and collect each
            shard's top 10 as global indices. Despite the name, this is not
            5-fold cross-validation.
        gpu_num: accepted for call compatibility; unused.
        vocab_path: directory containing ``<data_name>_vocab.json``.
            Defaults to the previously hard-coded docker location.

    Returns:
        A list of image indices ordered best-first within each ranking.

    Raises:
        NotImplementedError: if ``opt.cross_attn`` is neither 't2i' nor 'i2t'.
    """
    # load model and options
    s_t = time.time()
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    print("%s seconds taken to load checkpoint" % (time.time() - s_t))
    if data_path is not None:
        opt.data_path = data_path

    # construct model and load its weights
    model = SCAN(opt)
    model.load_state_dict(checkpoint['model'])

    # load vocabulary used by the model
    opt.vocab_path = vocab_path
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    print("Loading npy file")
    start_time = time.time()
    img_embs = img_feature
    print("%s seconds takes to load npy file" % (time.time() - start_time))

    # Tokenize the query string itself. Tokenizing str([input_string]), the
    # repr of a list, would inject "[", "]" and quote tokens.
    text = str(input_string).lower()
    if isinstance(text, bytes):  # Python 2 byte string; no-op on Python 3
        text = text.decode('utf-8')
    tokens = nltk.tokenize.word_tokenize(text)
    caption = [vocab('<start>')]
    caption.extend(vocab(token) for token in tokens)
    caption.append(vocab('<end>'))
    # The same caption fills every row of a single batch.
    target = torch.Tensor([caption] * opt.batch_size).long()

    print('Calculating results...')
    start_time = time.time()
    cap_embs, cap_len = encode_data(model, target, opt.batch_size)
    cap_lens = cap_len[0]
    print("%s seconds takes to calculate results" % (time.time() - start_time))
    print("Caption length with start and end index : ", cap_lens)
    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    if not fold5:
        img_embs = np.array(img_embs)
        start = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        else:
            raise NotImplementedError
        end = time.time()
        print("calculate similarity time:", end - start)

        # Descending order along axis 0, keep the first how_many rows.
        # Same as [-how_many:][::-1] for how_many > 0, but how_many == 0 now
        # yields nothing instead of everything.
        # NOTE(review): if sims has one column per (identical) caption copy,
        # flatten() repeats each index per column — confirm callers expect it.
        top_n = np.argsort(sims, axis=0)[::-1][:how_many].flatten()
        final_result = list(top_n)

    else:
        # Score the images in 10 shards; shard indices are offset by the
        # shard start so the collected results are global image indices.
        final_result = []
        n_images = img_embs.shape[0]
        shard_len = n_images // 10
        for i in range(10):
            shard_start = i * shard_len
            shard_end = (i + 1) * shard_len if i < 9 else n_images
            img_embs_shard = img_embs[shard_start:shard_end]
            cap_embs_shard = cap_embs
            cap_lens_shard = cap_lens
            start = time.time()
            if opt.cross_attn == 't2i':
                sims = shard_xattn_t2i(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            elif opt.cross_attn == 'i2t':
                sims = shard_xattn_i2t(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            else:
                raise NotImplementedError
            end = time.time()
            print("calculate similarity time:", end - start)

            top_10 = np.argsort(sims, axis=0)[-10:][::-1].flatten() + shard_start

            print("Top 10 list for iteration #%d : " % (i + 1) +
                  str(top_10))
            final_result.extend(list(top_10))

    return final_result
Exemple #7
0
def _mean_fold_metrics(results):
    """Average the per-fold metric rows collected by ``evalrank``.

    Each row is ``list(r) + list(ri) + [ar, ari, rsum]``, where ``r`` and
    ``ri`` are the 5-value i2t / t2i result tuples. The averaged ``ar``,
    ``ari`` and ``rsum`` therefore sit at indices 10, 11 and 12.

    Returns:
        ``(mean_r, mean_ri, mean_ar, mean_ari, mean_rsum)``, with ``mean_r``
        and ``mean_ri`` as 5-tuples.
    """
    mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
    return (mean_metrics[:5], mean_metrics[5:10], mean_metrics[10],
            mean_metrics[11], mean_metrics[12])


def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Prints recall metrics and saves the i2t / t2i ranks (of the full run, or
    of the first fold when ``fold5``) to 'ranks.pth.tar'.

    Raises:
        NotImplementedError: if ``opt.cross_attn`` is neither 't2i' nor 'i2t'.
    """
    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # construct model
    model = SCAN(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
    # img_embs holds one row per caption (5 per image), hence the / 5.
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if not fold5:
        # no cross-validation, full evaluation; keep one row per image
        img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
        start = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        else:
            raise NotImplementedError
        end = time.time()
        print("calculate similarity time:", end - start)

        r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
        ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
    else:
        # 5fold cross-validation, only for MSCOCO (5000 captions per fold)
        results = []
        for i in range(5):
            img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
            cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
            start = time.time()
            if opt.cross_attn == 't2i':
                sims = shard_xattn_t2i(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            elif opt.cross_attn == 'i2t':
                sims = shard_xattn_i2t(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            else:
                raise NotImplementedError
            end = time.time()
            print("calculate similarity time:", end - start)

            r, rt0 = i2t(img_embs_shard,
                         cap_embs_shard,
                         cap_lens_shard,
                         sims,
                         return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(img_embs_shard,
                           cap_embs_shard,
                           cap_lens_shard,
                           sims,
                           return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

            if i == 0:
                rt, rti = rt0, rti0
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        # Indices 10/11/12 are ar/ari/rsum; the previous summary printed
        # ar*6 as rsum and shifted both averages by one slot.
        mean_r, mean_ri, mean_ar, mean_ari, mean_rsum = _mean_fold_metrics(results)
        print("rsum: %.1f" % mean_rsum)
        print("Average i2t Recall: %.1f" % mean_ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % mean_r)
        print("Average t2i Recall: %.1f" % mean_ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % mean_ri)

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
Exemple #8
0
def main():
    """Parse hyper-parameters and train SCAN.

    If ``--resume`` names an existing checkpoint, its epoch, best R@ sum,
    weights and ``Eiters`` are restored. Training then continues from the
    saved epoch rather than restarting at 0. A checkpoint is written to
    ``--model_name`` after every epoch.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='./data/',
                        help='path to datasets')
    parser.add_argument('--data_name',
                        default='precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--vocab_path',
                        default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin',
                        default=0.2,
                        type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs',
                        default=30,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=15,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=10,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=500,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name',
                        default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation',
                        action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm',
                        action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument(
        '--raw_feature_norm',
        default="clipped_l2norm",
        help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--agg_func',
                        default="LogSumExp",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--cross_attn', default="t2i", help='t2i|i2t')
    parser.add_argument('--precomp_enc_type',
                        default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--bi_gru',
                        action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse',
                        default=6.,
                        type=float,
                        help='LogSumExp temp.')
    parser.add_argument('--lambda_softmax',
                        default=9.,
                        type=float,
                        help='Attention softmax temperature.')
    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SCAN(opt)

    # Defaults for a fresh run; overwritten below when resuming. (Previously
    # best_rsum was reset to 0 after resuming and the epoch loop ignored
    # start_epoch, so resuming restarted training from epoch 0.)
    start_epoch = 0
    best_rsum = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs also creates missing parent directories.
        if not os.path.exists(opt.model_name):
            os.makedirs(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
Exemple #9
0
def main():
    """Train SCAN from command-line hyper-parameters.

    Parses the options below, loads the pickled vocabulary and the train/val
    loaders, optionally restores weights from ``--resume``, then trains for
    ``--num_epochs`` epochs. After every epoch the model is validated and a
    checkpoint ``checkpoint_<epoch>.pth.tar`` is written under
    ``--model_name``; ``save_checkpoint`` is told whether the validation R@
    sum is the best seen so far.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        default='/data3/zhangyf/cross_modal_retrieval/SCAN/data',
        help='path to datasets')
    parser.add_argument('--data_name',
                        default='f30k_precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument(
        '--vocab_path',
        default='/data3/zhangyf/cross_modal_retrieval/SCAN/vocab/',
        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin',
                        default=0.2,
                        type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs',
                        default=20,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--decoder_dim',
                        default=512,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=10,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=4,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=30,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=500,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name',
                        default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument(
        '--resume',
        default=
        '/data3/zhangyf/cross_modal_retrieval/vsepp_next_train_12_31_f30k/run/coco_vse++_ft_128_f30k_next/model_best.pth.tar',
        type=str,
        metavar='PATH',
        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation',
                        action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm',
                        action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument('--precomp_enc_type',
                        default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--reset_train',
                        action='store_true',
                        help='Ensure the training is always done in '
                        'train mode (Not recommended).')
    parser.add_argument('--finetune',
                        action='store_true',
                        help='Fine-tune the image encoder.')
    parser.add_argument('--cnn_type',
                        default='resnet152',
                        help="""The CNN used for image encoder
                        (e.g. vgg19, resnet152)""")
    parser.add_argument('--crop_size',
                        default=224,
                        type=int,
                        help='Size of an image crop as the CNN input.')

    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper. The context manager closes the file handle,
    # which a bare open() passed straight to pickle.load would leak.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    vocab_file = os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    with open(vocab_file, 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SCAN(opt)

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model.
    # NOTE(review): best_rsum is reset and the loop restarts at epoch 0 even
    # after a successful resume, so the restored start_epoch/best_rsum are
    # only printed. That fits fine-tuning from another run's weights (as the
    # --resume default path suggests) -- confirm before changing.
    best_rsum = 0
    for epoch in range(opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch. train() receives the running best R@ sum and
        # its result was meant to update it (it used to be stored in the
        # misspelled, never-read name `bset_rsum`). Presumably it returns the
        # best R@ sum seen during in-epoch validation -- confirm in train().
        train_best = train(opt, train_loader, model, epoch, val_loader,
                           best_rsum)
        if train_best is not None:
            best_rsum = max(best_rsum, train_best)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs also creates missing parents; os.mkdir raised
        # FileNotFoundError when the parent directory did not exist.
        os.makedirs(opt.model_name, exist_ok=True)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
Exemple #10
0
        if (epoch % save_epoch == 0) or (epoch == training_epochs - 1):
            torch.save(scan.state_dict(),
                       '{}/scan_epoch_{}.pth'.format(exp, epoch))


# Prepare the training data source.
data_manager = DataManager()
data_manager.prepare()

dae, vae, scan = DAE(), VAE(), SCAN()

# Restore pretrained weights. Without CUDA every tensor is remapped onto the
# storage it was deserialised into (i.e. kept on the CPU) via map_location.
# NOTE(review): the CUDA path restores SCAN from a fixed epoch-1499 file while
# the CPU path uses exp + '/' + opt.load -- confirm which one is intended.
if use_cuda:
    load_kwargs = {}
    scan_weights = 'save/scan/scan_epoch_1499.pth'
else:
    load_kwargs = {'map_location': lambda storage, loc: storage}
    scan_weights = exp + '/' + opt.load

dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth', **load_kwargs))
vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth', **load_kwargs))
scan.load_state_dict(torch.load(scan_weights, **load_kwargs))

if use_cuda:
    dae, vae, scan = dae.cuda(), vae.cuda(), scan.cuda()

# Optionally keep training SCAN on top of the frozen DAE/VAE weights.
if opt.train:
    scan_optimizer = optim.Adam(scan.parameters(), lr=1e-4, eps=1e-8)
    train_scan(dae, vae, scan, data_manager, scan_optimizer)
Exemple #11
0
def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """Evaluate a trained SCAN checkpoint on image-text retrieval.

    Loads the checkpoint (its 'opt' and 'model' entries), rebuilds the
    vocabulary and the model (which also receives the array loaded from
    ``opt.caption_np + 'caption_np.npy'``), encodes ``split`` and prints
    R@K metrics. Ranks are saved to ``ranks.pth.tar`` in the working
    directory.

    Args:
        model_path: path of the checkpoint to evaluate.
        data_path: optional override for ``opt.data_path``.
        split: data split passed to ``get_test_loader``.
        fold5: score five shards of 5000 captions and average the metrics
            instead of scoring the whole split at once.

    NOTE(review): ``label`` is only produced by ``shard_xattn_all``; with
    ``opt.cross_attn`` set to 't2i' or 'i2t' (and fold5=False) the
    ``i2t(label, ...)`` call below hits an unbound name. The fold5 branch
    calls ``i2t``/``t2i`` without ``label`` at all, unlike the non-fold5
    branch -- check both against the current i2t/t2i signatures.
    """
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path

    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Caption array handed to the SCAN constructor, moved to the GPU first.
    captions_w = np.load(opt.caption_np + 'caption_np.npy')
    captions_w = torch.from_numpy(captions_w)

    captions_w = captions_w.cuda()

    model = SCAN(opt, captions_w)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if not fold5:
        # Keep every 5th image row; presumably each image is repeated once
        # per caption (5 captions per image) -- confirm in encode_data.
        img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
        start = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        elif opt.cross_attn == 'all':
            sims, label = shard_xattn_all(model,
                                          img_embs,
                                          cap_embs,
                                          cap_lens,
                                          opt,
                                          shard_size=128)
        else:
            raise NotImplementedError
        end = time.time()
        print("calculate similarity time:", end - start)
        # np.save appends ".npy": writes sim_stage1.npy to the working dir.
        np.save('sim_stage1', sims)

        r, rt = i2t(label,
                    img_embs,
                    cap_embs,
                    cap_lens,
                    sims,
                    return_ranks=True)
        ri, rti = t2i(label,
                      img_embs,
                      cap_embs,
                      cap_lens,
                      sims,
                      return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
    else:

        results = []
        for i in range(5):
            img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
            cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
            start = time.time()
            if opt.cross_attn == 't2i':
                sims = shard_xattn_t2i(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            elif opt.cross_attn == 'i2t':
                sims = shard_xattn_i2t(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            else:
                raise NotImplementedError
            end = time.time()
            print("calculate similarity time:", end - start)

            r, rt0 = i2t(img_embs_shard,
                         cap_embs_shard,
                         cap_lens_shard,
                         sims,
                         return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(img_embs_shard,
                           cap_embs_shard,
                           cap_lens_shard,
                           sims,
                           return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

            # Only the first fold's ranks are saved below.
            if i == 0:
                rt, rti = rt0, rti0
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
        # Each row of `results` is list(r) + list(ri) + [ar, ari, rsum], so
        # the summary values are the last three columns. (Previously ar * 6
        # was printed as rsum and both averages were shifted one column.)
        mean_ar, mean_ari, mean_rsum = mean_metrics[-3:]
        print("rsum: %.1f" % mean_rsum)
        print("Average i2t Recall: %.1f" % mean_ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[5:10])

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
Exemple #12
0
def start_experiment(opt, seed):
    """Seed the RNGs, build data and model from ``opt``, train, and report.

    Args:
        opt: options namespace (vocab/data paths, training hyper-parameters,
            ``resume``, ``logger_name``, ``model_name``, ...).
        seed: seed applied to torch, numpy and Python's ``random``.

    Returns:
        The best validation R@ sum reached, which includes the value
        restored from ``opt.resume`` when a checkpoint is loaded.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    print("Let's use", torch.cuda.device_count(), "GPUs!")
    print("Number threads:", torch.get_num_threads())

    # Load Vocabulary Wrapper, create dictionary that can switch between ids and words
    vocab = deserialize_vocab("{}/{}/{}_vocab_{}.json".format(
        opt.vocab_path, opt.clothing, opt.data_name, opt.version))

    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data_ken.get_loaders(opt.data_name, vocab,
                                                    opt.batch_size,
                                                    opt.workers, opt)

    # Construct the model
    model = SCAN(opt)

    # save hyperparameters in file
    save_hyperparameters(opt.logger_name, opt)

    best_rsum = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            # Checkpoints store the epoch that just finished, so resume at
            # the following one.
            start_epoch = checkpoint['epoch'] + 1
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)
        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs also creates missing parent directories; os.mkdir raised
        # FileNotFoundError when the parent of opt.model_name did not exist.
        os.makedirs(opt.model_name, exist_ok=True)

        last_epoch = epoch == opt.num_epochs - 1

        # only save when best epoch, or last epoch for further training
        if is_best or last_epoch:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'best_rsum': best_rsum,
                    'opt': opt,
                    'Eiters': model.Eiters,
                },
                is_best,
                last_epoch,
                filename='checkpoint_{}.pth.tar'.format(epoch),
                prefix=opt.model_name + '/')
    return best_rsum
Exemple #13
0
def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Args:
        model_path: checkpoint file holding the training options ('opt') and
            the model weights ('model').
        data_path: if given, overrides ``opt.data_path``.
        split: split name passed to ``get_test_loader``.
        fold5: if True, score five consecutive 5000-row shards and print the
            per-fold and mean metrics.

    The ranks of the full run (or of the first fold) are saved to
    ``ranks.pth.tar`` in the current working directory.
    """
    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    with open(os.path.join(opt.vocab_path,
                           '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model
    model = SCAN(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab,
                                  opt.batch_size, opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if not fold5:
        # no cross-validation, full evaluation
        r, rt = i2t(img_embs, cap_embs, measure=opt.measure, return_ranks=True)
        ri, rti = t2i(img_embs, cap_embs,
                      measure=opt.measure, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
    else:
        # 5fold cross-validation, only for MSCOCO
        results = []
        for i in range(5):
            lo, hi = i * 5000, (i + 1) * 5000
            img_fold, cap_fold = img_embs[lo:hi], cap_embs[lo:hi]
            r, rt0 = i2t(img_fold, cap_fold, measure=opt.measure,
                         return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(img_fold, cap_fold, measure=opt.measure,
                           return_ranks=True)
            # Only the first fold's ranks are saved below.
            if i == 0:
                rt, rti = rt0, rti0
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        mean_metrics = tuple(numpy.array(results).mean(axis=0).flatten())
        # Each row of `results` is list(r) + list(ri) + [ar, ari, rsum], so
        # the summary values are the last three columns. (Previously ar * 6
        # was printed as rsum and both averages were shifted one column.)
        mean_ar, mean_ari, mean_rsum = mean_metrics[-3:]
        print("rsum: %.1f" % mean_rsum)
        print("Average i2t Recall: %.1f" % mean_ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
              mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
              mean_metrics[5:10])

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
Exemple #14
0
def evalrank(model_path,
             run,
             data_path=None,
             split='dev',
             fold5=False,
             vocab_path="../vocab/",
             change=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Args:
        model_path: checkpoint with 'opt' and 'model' entries.
        run: run identifier embedded in the saved ranks file name.
        data_path: optional override for ``opt.data_path``.
        split: split passed to ``get_test_loader``.
        fold5: accepted for interface compatibility; not read in this body.
        vocab_path: directory prefix for the per-clothing vocabulary json.
        change: if True, evaluate on the "dresses" clothing category after
            the model has been built.

    Returns:
        ``(rt, rti, attn, r, ri)``. The same ranks/attention plus a
        ``t2i_switch`` flag are saved under ``plots_trans`` or
        ``plots_scan`` depending on ``opt.trans``.
    """
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)

    if data_path is not None:
        opt.data_path = data_path

    # Vocabulary is stored per clothing category and dataset version.
    vocab_file = "{}{}/{}_vocab_{}.json".format(vocab_path, opt.clothing,
                                                opt.data_name, opt.version)
    vocab = deserialize_vocab(vocab_file)
    opt.vocab_size = len(vocab)
    print(opt.vocab_size)

    model = SCAN(opt)
    model.load_state_dict(checkpoint['model'])

    # Switching category happens after model construction, so only the test
    # loader below sees the new clothing type.
    if change:
        opt.clothing = "dresses"

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens, freqs = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    # Pick the cross-attention direction; t2i_switch records which one ran.
    if opt.cross_attn == 't2i':
        shard_fn, t2i_switch = shard_xattn_t2i, True
    elif opt.cross_attn == 'i2t':
        shard_fn, t2i_switch = shard_xattn_i2t, False
    else:
        raise NotImplementedError
    sims, attn = shard_fn(img_embs,
                          cap_embs,
                          cap_lens,
                          freqs,
                          opt,
                          shard_size=128)

    # r = (r1, r2, r5, medr, meanr), rt= (ranks, top1)
    r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    ari = (ri[0] + ri[1] + ri[2]) / 3
    rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
    print("rsum: %.1f" % rsum)
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f %.1f %.1f %.1f %.1f %.1f %.1f" % r)
    print("Average t2i Recall: %.1f" % ari)
    print("Text to image: %.1f %.1f %.1f %.1f %.1f %.1f %.1f" % ri)

    save_dir = "plots_trans" if opt.trans else "plots_scan"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    payload = {
        'rt': rt,
        'rti': rti,
        "attn": attn,
        "t2i_switch": t2i_switch
    }
    torch.save(payload,
               '{}/ranks_{}_{}.pth.tar'.format(save_dir, run, opt.version))
    return rt, rti, attn, r, ri
Exemple #15
0
def main(args):
    """Visualise retrieval matches and attention for a trained SCAN run.

    Loads the checkpoint ``<model_path>/<run>/seed1/checkpoint/<checkpoint>``
    on the CPU, restores (or recomputes) cached embeddings, reads the ranks
    file written by ``evalrank`` and plots sampled i2t/t2i matches and their
    attention maps under ``<plot_folder>/<version>_<run>``.

    Args:
        args: parsed CLI namespace (model_path, run, checkpoint, data_path,
            data_name, nr_examples, image_folder, vocab_path, focus_subset,
            word_asked, ...).
    """
    import pickle  # only for the exception type caught below

    model_path = "{}/{}/seed1/checkpoint/{}".format(args.model_path, args.run,
                                                    args.checkpoint)

    # load model and options
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    opt = checkpoint['opt']

    # The stored options may lack `basic`; set it so SCAN(opt) can be built.
    d = vars(opt)
    d['basic'] = False

    run = args.run
    data_path = "{}/{}".format(args.data_path, args.data_name)
    nr_examples = args.nr_examples
    version = opt.version
    clothing = opt.clothing

    plot_folder = "plots_trans" if opt.trans else "plots_scan"

    plot_path = '{}/{}_{}'.format(plot_folder, version, run)
    caption_test_path = "{}/{}/data_captions_{}_test.txt".format(
        data_path, clothing, version)
    image_path = "{}".format(args.image_folder)
    vocab_path = "{}/{}".format(args.vocab_path, clothing)
    data_folder = "../data"

    if not os.path.exists(plot_path):
        os.makedirs(plot_path)

    # change image paths from lisa folders to local folders
    opt.data_path = data_folder
    opt.image_path = image_path
    opt.vocab_path = vocab_path
    print(opt)

    # construct model
    model = SCAN(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    # Reuse cached embeddings when available. Only cache misses / unreadable
    # caches fall back to recomputation; a bare `except:` here also swallowed
    # KeyboardInterrupt and SystemExit.
    embs_path = "{}/embs/embs_{}_{}.pth.tar".format(plot_folder, run, version)
    try:
        embs = torch.load(embs_path, map_location='cpu')
        print("loading embeddings")
        img_embs = embs["img_embs"]
        cap_embs = embs["cap_embs"]
        cap_lens = embs["cap_lens"]
        freqs = embs["freqs"]
    except (OSError, KeyError, RuntimeError, EOFError,
            pickle.UnpicklingError):
        print("Create embeddings")
        img_embs, cap_embs, cap_lens, freqs = get_embs(opt,
                                                       model,
                                                       run,
                                                       version,
                                                       data_path,
                                                       plot_folder,
                                                       vocab_path=vocab_path)

    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    # Ranks written by evalrank: rt/rti are (ranks, top1) pairs.
    temp = torch.load("{}/ranks_{}_{}.pth.tar".format(plot_folder, run,
                                                      version),
                      map_location='cpu')

    rt = temp["rt"]
    rti = temp["rti"]
    t2i_switch = temp["t2i_switch"]

    r_i2t = calculate_r(rt[0], "i2t")
    r_t2i = calculate_r(rti[0], "t2i")

    top1_rt = rt[1]
    top1_rti = rti[1]

    if args.focus_subset:
        indx = get_indx_subset(caption_test_path, args.word_asked)
        rs_i2t = calculate_r(rt[0][indx], "i2t")
        rs_t2i = calculate_r(rti[0][indx], "t2i")
        print_result_subset(rs_i2t, r_i2t, "i2t", args.word_asked)
        print_result_subset(rs_t2i, r_t2i, "t2i", args.word_asked)
        rnd_indx = get_random_indx(nr_examples, len(indx))
        rnd = [indx[i] for i in rnd_indx]
    else:
        rnd = get_random_indx(nr_examples, len(top1_rt))

    # dictionary to turn test_ids to data_ids
    test_id2data = {}

    # find the caption and image with every id in the test file {caption_id : (image_id, caption)}
    with open(caption_test_path, newline='') as file:
        caption_reader = csv.reader(file, delimiter='\t')
        for i, line in enumerate(caption_reader):
            test_id2data[i] = (line[0], line[1])

    h5_images = get_h5_images(args.data_name, data_path)

    # get the matches
    matches_i2t = get_matches_i2t(top1_rt, test_id2data, nr_examples, rnd)
    matches_t2i = get_matches_t2i(top1_rti, test_id2data, nr_examples, rnd)

    # get id for file name
    unique_id = get_id(plot_path, "i2t", run)

    # plot image and caption together
    show_plots(matches_i2t, len(matches_i2t), "i2t", run, version, plot_path,
               args, clothing, h5_images, unique_id)
    show_plots(matches_t2i, len(matches_t2i), "t2i", run, version, plot_path,
               args, clothing, h5_images, unique_id)

    for wanted_id in rnd:
        target_id = get_target_id(top1_rt, top1_rti, t2i_switch, wanted_id)

        attn = get_attn(img_embs, cap_embs, cap_lens, wanted_id, target_id,
                        opt, t2i_switch, freqs)

        if t2i_switch:
            words_caption = get_captions(test_id2data, wanted_id)
            image_segs = get_image_segs(target_id, test_id2data, args, opt,
                                        model, h5_images)
            match_t2i_viz(attn, wanted_id, target_id, test_id2data, run,
                          version, plot_path, clothing, words_caption,
                          image_segs)
        else:
            words_caption = get_captions(test_id2data, target_id)
            image_segs = get_image_segs(wanted_id, test_id2data, args, opt,
                                        model, h5_images)
            match_i2t_viz(attn, wanted_id, target_id, test_id2data, run,
                          version, plot_path, words_caption, image_segs)