Example #1
0
def test(**kwargs):
    """Evaluate trained image/text hashing models and print cross-modal MAP.

    Keyword arguments are forwarded to ``opt.parse`` to override config values.
    """
    opt.parse(kwargs)

    # Only the query/retrieval splits are used here; the train split is
    # loaded but not referenced afterwards.
    (train_L, query_L, retrieval_L,
     train_x, query_x, retrieval_x,
     train_y, query_y, retrieval_y) = load_data(opt.data_path)

    y_dim = query_y.shape[1]
    print('...loading and splitting data finish')

    img_model = ImgModule(opt.bit)
    txt_model = TxtModule(y_dim, opt.bit)

    # Restore pretrained weights when checkpoint paths are configured.
    if opt.load_img_path:
        img_model.load(opt.load_img_path)
    if opt.load_txt_path:
        txt_model.load(opt.load_txt_path)

    if opt.use_gpu:
        img_model, txt_model = img_model.cuda(), txt_model.cuda()

    # Binary hash codes for queries and retrieval set, both modalities.
    qBX = generate_image_code(img_model, query_x, opt.bit)
    qBY = generate_text_code(txt_model, query_y, opt.bit)
    rBX = generate_image_code(img_model, retrieval_x, opt.bit)
    rBY = generate_text_code(txt_model, retrieval_y, opt.bit)

    if opt.use_gpu:
        query_L, retrieval_L = query_L.cuda(), retrieval_L.cuda()

    # Cross-modal retrieval: image query vs. text database and vice versa.
    mapi2t = calc_map_k(qBX, rBY, query_L, retrieval_L)
    mapt2i = calc_map_k(qBY, rBX, query_L, retrieval_L)
    print('...test MAP: MAP(i->t): %3.3f, MAP(t->i): %3.3f' % (mapi2t, mapt2i))
Example #2
0
    def calc_maps_k(self, qBX, qBY, rBX, rBY, qLX, qLY, rLX, rLY, k):
        """
        Compute mAP@k for all four retrieval directions and log them.

        :param: qBX: query hashes, modality X
        :param: qBY: query hashes, modality Y
        :param: rBX: response hashes, modality X
        :param: rBY: response hashes, modality Y
        :param: qLX: query labels, modality X
        :param: qLY: query labels, modality Y
        :param: rLX: response labels, modality X
        :param: rLY: response labels, modality Y
        :param: k: cut-off rank

        :returns: (i2t, t2i, i2i, t2t, average) as plain floats
        """
        # One mAP per direction; unwrap to Python floats immediately.
        mapi2t = calc_map_k(qBX, rBY, qLX, rLY, k).item()
        mapt2i = calc_map_k(qBY, rBX, qLY, rLX, k).item()
        mapi2i = calc_map_k(qBX, rBX, qLX, rLX, k).item()
        mapt2t = calc_map_k(qBY, rBY, qLY, rLY, k).item()

        # Plain arithmetic mean of the four directions.
        mapavg = (mapi2t + mapt2i + mapi2i + mapt2t) * 0.25

        s = 'Valid: mAP@{}, avg: {:3.3f}, i->t: {:3.3f}, t->i: {:3.3f}, i->i: {:3.3f}, t->t: {:3.3f}'
        self.logger.info(s.format(k, mapavg, mapi2t, mapt2i, mapi2i, mapt2t))

        return mapi2t, mapt2i, mapi2i, mapt2t, mapavg
Example #3
0
def valid(img_model, txt_model, query_x, retrieval_x, query_y, retrieval_y, query_L, retrieval_L):
    """Return (mAP image->text, mAP text->image) on the given splits."""
    # Hash both modalities for the query and retrieval sets.
    query_img_codes = generate_image_code(img_model, query_x, opt.bit)
    query_txt_codes = generate_text_code(txt_model, query_y, opt.bit)
    db_img_codes = generate_image_code(img_model, retrieval_x, opt.bit)
    db_txt_codes = generate_text_code(txt_model, retrieval_y, opt.bit)

    # Each query modality is scored against the opposite database modality.
    map_i2t = calc_map_k(query_img_codes, db_txt_codes, query_L, retrieval_L)
    map_t2i = calc_map_k(query_txt_codes, db_img_codes, query_L, retrieval_L)
    return map_i2t, map_t2i
Example #4
0
def valid(img_model,
          txt_model,
          img_memory_net,
          text_memory_net,
          img_centroids,
          text_centoids,
          query_x,
          retrieval_x,
          query_y,
          retrieval_y,
          query_L,
          retrieval_L,
          save=False):
    """Validate memory-augmented cross-modal hashing models.

    Writes a per-class mAP listing to ``<dir>/maplist.txt`` and, when
    ``save`` is True, also dumps the raw hash-code matrices as ``.npy``.

    :param save: additionally persist qBX/qBY/rBX/rBY hash codes to disk
    :returns: (mAP image->text, mAP text->image) over the full query set
    """
    qBX = generate_image_code(img_model, img_memory_net, query_x,
                              img_centroids, opt.bit)
    qBY = generate_text_code(txt_model, text_memory_net, query_y,
                             text_centoids, opt.bit)
    rBX = generate_image_code(img_model, img_memory_net, retrieval_x,
                              img_centroids, opt.bit)
    rBY = generate_text_code(txt_model, text_memory_net, retrieval_y,
                             text_centoids, opt.bit)

    # Timestamped output dir (minute resolution). exist_ok avoids the
    # FileExistsError the original raised when valid() ran twice within
    # the same minute.
    dir_name = 'hashcodes' + time.strftime("%Y_%m_%d_%H%M", time.localtime())
    os.makedirs(dir_name, exist_ok=True)
    if save:
        np.save(os.path.join(dir_name, 'qBX.npy'), qBX.cpu().numpy())
        np.save(os.path.join(dir_name, 'qBY.npy'), qBY.cpu().numpy())
        np.save(os.path.join(dir_name, 'rBX.npy'), rBX.cpu().numpy())
        np.save(os.path.join(dir_name, 'rBY.npy'), rBY.cpu().numpy())
        print("Hash codes saved.\n")

    num_class = query_L.shape[1]
    num_test = query_L.shape[0]
    index = np.arange(num_test)
    mlist_i2t = []
    mlist_t2i = []

    # Per-class mAP: restrict the queries to those labelled with class i,
    # keep the full retrieval set.
    for i in range(num_class):
        ind = index[query_L[:, i] == 1]
        class_mapi2t = calc_map_k(qBX[ind, :], rBY, query_L[ind, :],
                                  retrieval_L)
        class_mapt2i = calc_map_k(qBY[ind, :], rBX, query_L[ind, :],
                                  retrieval_L)
        mlist_i2t.append(class_mapi2t)
        mlist_t2i.append(class_mapt2i)

    # Context manager guarantees the file is closed (the original opened
    # it with open() and never closed it).
    with open(os.path.join(dir_name, 'maplist.txt'), "w") as f:
        f.write("i2t: " + str(mlist_i2t))
        f.write("\n")
        f.write("t2i: " + str(mlist_t2i))

    mapi2t = calc_map_k(qBX, rBY, query_L, retrieval_L)
    mapt2i = calc_map_k(qBY, rBX, query_L, retrieval_L)
    return mapi2t, mapt2i
Example #5
0
def test(**kwargs):
    """Evaluate a trained GEN checkpoint and print cross-modal MAP.

    Keyword arguments are forwarded to ``opt.parse`` to override config.
    """
    opt.parse(kwargs)

    # Resolve the torch device: explicit setting > first GPU > CPU.
    if opt.device is not None:
        opt.device = torch.device(opt.device)
    else:
        opt.device = torch.device(0 if opt.gpus else 'cpu')

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    generator = GEN(opt.dropout,
                    opt.image_dim,
                    opt.text_dim,
                    opt.hidden_dim,
                    opt.bit,
                    pretrain_model=pretrain_model).to(opt.device)

    # Weights live under checkpoints/<dataset>_<bit>.
    load_model(generator, 'checkpoints/' + opt.dataset + '_' + str(opt.bit))
    generator.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    # One dataset + loader per (modality, split); shuffle must stay off so
    # codes line up with labels.
    def _make_loader(split):
        data = Dataset(opt, images, tags, labels, test=split)
        return data, DataLoader(data, opt.batch_size, shuffle=False)

    i_query_data, i_query_dataloader = _make_loader('image.query')
    _, i_db_dataloader = _make_loader('image.db')
    _, t_query_dataloader = _make_loader('text.query')
    _, t_db_dataloader = _make_loader('text.db')

    qBX = generate_img_code(generator, i_query_dataloader, opt.query_size)
    qBY = generate_txt_code(generator, t_query_dataloader, opt.query_size)
    rBX = generate_img_code(generator, i_db_dataloader, opt.db_size)
    rBY = generate_txt_code(generator, t_db_dataloader, opt.db_size)

    query_labels, db_labels = i_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))
Example #6
0
def valid(model, x_query_dataloader, x_db_dataloader, y_query_dataloader, y_db_dataloader,
          query_labels, db_labels):
    """Compute validation mAP in both cross-modal directions.

    Switches the model to eval mode for hashing and restores train mode
    before returning. Returns (mAP i->t, mAP t->i) as plain floats.
    """
    model.eval()

    # Hash codes for queries and database, image (X) and text (Y) sides.
    query_img = generate_img_code(model, x_query_dataloader, opt.query_size)
    query_txt = generate_txt_code(model, y_query_dataloader, opt.query_size)
    db_img = generate_img_code(model, x_db_dataloader, opt.db_size)
    db_txt = generate_txt_code(model, y_db_dataloader, opt.db_size)

    map_i2t = calc_map_k(query_img, db_txt, query_labels, db_labels)
    map_t2i = calc_map_k(query_txt, db_img, query_labels, db_labels)

    model.train()
    return map_i2t.item(), map_t2i.item()
Example #7
0
def test(**kwargs):
    """Evaluate pretrained image/text models on a freshly split dataset."""
    opt.parse(kwargs)

    images, tags, labels = load_data(opt.data_path)
    y_dim = tags.shape[1]

    X, Y, L = split_data(images, tags, labels)
    print('...loading and splitting data finish')

    img_model = ImgModule(opt.bit)
    txt_model = TxtModule(y_dim, opt.bit)

    # Restore pretrained weights when checkpoint paths are configured.
    if opt.load_img_path:
        img_model.load(opt.load_img_path)
    if opt.load_txt_path:
        txt_model.load(opt.load_txt_path)

    if opt.use_gpu:
        img_model, txt_model = img_model.cuda(), txt_model.cuda()
    print('-----------------------')

    # Wrap the numpy splits as torch tensors.
    query_L = torch.from_numpy(L['query'])
    query_x = torch.from_numpy(X['query'])
    query_y = torch.from_numpy(Y['query'])
    retrieval_L = torch.from_numpy(L['retrieval'])
    retrieval_x = torch.from_numpy(X['retrieval'])
    retrieval_y = torch.from_numpy(Y['retrieval'])

    # Binary codes for queries and retrieval set, both modalities.
    qBX = generate_image_code(img_model, query_x, opt.bit)
    qBY = generate_text_code(txt_model, query_y, opt.bit)
    rBX = generate_image_code(img_model, retrieval_x, opt.bit)
    rBY = generate_text_code(txt_model, retrieval_y, opt.bit)

    if opt.use_gpu:
        query_L, retrieval_L = query_L.cuda(), retrieval_L.cuda()

    mapi2t = calc_map_k(qBX, rBY, query_L, retrieval_L)
    mapt2i = calc_map_k(qBY, rBX, query_L, retrieval_L)
    print('...test MAP: MAP(i->t): %3.3f, MAP(t->i): %3.3f' % (mapi2t, mapt2i))
Example #8
0
def train_model(model, dataloader, criterion, criterion_hash, optimizer, scheduler, num_epochs, bits, classes, log_file):
    """Train a hashing model with attention-guided image hiding.

    :param model: network returning (feature_map, class_logits, code_logits)
    :param dataloader: dict of loaders with keys 'train', 'val', 'base'
    :param criterion: classification loss
    :param criterion_hash: hash-code regression loss
    :param optimizer: optimizer over model parameters
    :param scheduler: LR scheduler, stepped once per epoch
    :param num_epochs: number of training epochs
    :param bits: hash-code length
    :param classes: number of classes (for one-hot label conversion)
    :param log_file: open writable file object for per-epoch logging
    """
    # Target binary codes for the training set, fixed before training.
    train_codes = calc_train_codes(dataloader, bits, classes)

    for epoch in range(num_epochs):

        model.train()
        ce_loss = 0.0

        for batch_cnt, (inputs, labels, item) in enumerate(dataloader['train']):

            codes = torch.tensor(train_codes[item, :]).float().cuda()
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            feature_map, outputs_class, outputs_codes = model(inputs)

            # Attention-guided hiding: sum feature channels into a saliency
            # map, upsample to input size, then zero out the most salient
            # pixels (per-sample random threshold in [0.9, 1.0]) and require
            # the model to classify the occluded image correctly too.
            attention = torch.sum(feature_map.detach(), dim=1, keepdim=True)
            attention = nn.functional.interpolate(attention, size=(224, 224), mode='bilinear', align_corners=True)
            masks = []
            for i in range(labels.size()[0]):
                threshold = random.uniform(0.9, 1.0)
                mask = (attention[i] < threshold * attention[i].max()).float()
                masks.append(mask)

            masks = torch.stack(masks)
            hide_imgs = inputs * masks
            _, outputs_hide, _ = model(hide_imgs)

            loss_class = criterion(outputs_class, labels)
            loss_class_hide = criterion(outputs_hide, labels)
            loss_codes = criterion_hash(outputs_codes, codes)
            loss = loss_class + loss_codes + loss_class_hide
            loss.backward()
            optimizer.step()
            ce_loss += loss.item() * inputs.size(0)

        epoch_loss = ce_loss / dataloader['train'].total_item_len
        scheduler.step()

        # Evaluate every epoch. The original guard `(epoch+1) % 1 == 0`
        # was always true, so it has been removed.
        ground_q, code_q = eval_turn(model, dataloader['val'])
        ground_d, code_d = eval_turn(model, dataloader['base'])

        labels_onehot_q = label2onehot(ground_q.cpu(), classes)
        labels_onehot_d = label2onehot(ground_d.cpu(), classes)

        map_1 = calc_map_k(torch.sign(code_q), torch.tensor(train_codes).float().cuda(), labels_onehot_q, labels_onehot_d)

        # Build the message once for both console and log file.
        msg = 'epoch:{}:  loss:{:.4f},  MAP:{:.4f}'.format(epoch + 1, epoch_loss, map_1)
        print(msg)
        log_file.write(msg + '\n')
Example #9
0
def valid(opt,
          img_model: nn.Module,
          txt_model: nn.Module,
          dataset: DatasetMirflckr25KValid,
          return_hash=False):
    """Validate on the MIRFlickr-25K query/retrieval splits.

    :param return_hash: also return the hash codes (on CPU) and labels
    :returns: (mAP i->t, mAP t->i) — plus codes and labels if requested
    """
    # Query split: hash codes for both modalities plus ground-truth labels.
    dataset.query()
    qB_img = get_img_code(opt, img_model, dataset)
    qB_txt = get_txt_code(opt, txt_model, dataset)
    query_label = dataset.get_all_label()

    # Retrieval split, same procedure.
    dataset.retrieval()
    rB_img = get_img_code(opt, img_model, dataset)
    rB_txt = get_txt_code(opt, txt_model, dataset)
    retrieval_label = dataset.get_all_label()

    mAPi2t = calc_map_k(qB_img, rB_txt, query_label, retrieval_label)
    mAPt2i = calc_map_k(qB_txt, rB_img, query_label, retrieval_label)

    if not return_hash:
        return mAPi2t, mAPt2i
    return (mAPi2t, mAPt2i, qB_img.cpu(), qB_txt.cpu(),
            rB_img.cpu(), rB_txt.cpu(), query_label, retrieval_label)
Example #10
0
def test(**kwargs):
    """Evaluate an AGAH checkpoint: PR curves, P@K and MAP, saved to disk."""
    opt.parse(kwargs)

    # Resolve the torch device: explicit setting > first GPU > CPU.
    if opt.device is not None:
        opt.device = torch.device(opt.device)
    elif opt.gpus:
        opt.device = torch.device(0)
    else:
        opt.device = torch.device('cpu')

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    model = AGAH(opt.bit,
                 opt.tag_dim,
                 opt.num_label,
                 opt.emb_dim,
                 lambd=opt.lambd,
                 pretrain_model=pretrain_model).to(opt.device)

    # Checkpoint directory; also used below for saving metrics (the
    # original rebuilt the identical string a second time).
    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    load_model(model, path)
    FEATURE_MAP = torch.load(os.path.join(path,
                                          'feature_map.pth')).to(opt.device)

    model.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    x_query_data = Dataset(opt, images, tags, labels, test='image.query')
    x_db_data = Dataset(opt, images, tags, labels, test='image.db')
    y_query_data = Dataset(opt, images, tags, labels, test='text.query')
    y_db_data = Dataset(opt, images, tags, labels, test='text.db')

    # shuffle must stay off so codes align with labels.
    x_query_dataloader = DataLoader(x_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    x_db_dataloader = DataLoader(x_db_data, opt.batch_size, shuffle=False)
    y_query_dataloader = DataLoader(y_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    y_db_dataloader = DataLoader(y_db_data, opt.batch_size, shuffle=False)

    qBX = generate_img_code(model, x_query_dataloader, opt.query_size,
                            FEATURE_MAP)
    qBY = generate_txt_code(model, y_query_dataloader, opt.query_size,
                            FEATURE_MAP)
    rBX = generate_img_code(model, x_db_dataloader, opt.db_size, FEATURE_MAP)
    rBY = generate_txt_code(model, y_db_dataloader, opt.db_size, FEATURE_MAP)

    query_labels, db_labels = x_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    # Precision/recall curves for both cross-modal directions.
    p_i2t, r_i2t = pr_curve(qBX, rBY, query_labels, db_labels)
    p_t2i, r_t2i = pr_curve(qBY, rBX, query_labels, db_labels)

    # Precision at several top-K cut-offs.
    K = [1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    pk_i2t = p_topK(qBX, rBY, query_labels, db_labels, K)
    pk_t2i = p_topK(qBY, rBX, query_labels, db_labels, K)

    # Persist all metrics next to the checkpoint (path reused from above).
    np.save(os.path.join(path, 'P_i2t.npy'), p_i2t.numpy())
    np.save(os.path.join(path, 'R_i2t.npy'), r_i2t.numpy())
    np.save(os.path.join(path, 'P_t2i.npy'), p_t2i.numpy())
    np.save(os.path.join(path, 'R_t2i.npy'), r_t2i.numpy())
    np.save(os.path.join(path, 'P_at_K_i2t.npy'), pk_i2t.numpy())
    np.save(os.path.join(path, 'P_at_K_t2i.npy'), pk_t2i.numpy())

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))
Example #11
0
def test(**kwargs):
    """Evaluate a CPAH checkpoint: PR curves, P@K and MAP for all four
    retrieval directions, pickled next to the checkpoint."""
    opt.parse(kwargs)

    # Resolve the torch device: explicit setting > first GPU > CPU.
    if opt.device is not None:
        opt.device = torch.device(opt.device)
    elif opt.gpus:
        opt.device = torch.device(0)
    else:
        opt.device = torch.device('cpu')

    with torch.no_grad():
        model = CPAH(opt.image_dim, opt.text_dim, opt.hidden_dim, opt.bit, opt.num_label).to(opt.device)

        # Checkpoint directory; also used below for saving metrics (the
        # original rebuilt the identical string a second time).
        path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit) + str(opt.proc)
        load_model(model, path)

        model.eval()

        images, tags, labels = load_data(opt.data_path, opt.dataset)

        i_query_data = Dataset(opt, images, tags, labels, test='image.query')
        i_db_data = Dataset(opt, images, tags, labels, test='image.db')
        t_query_data = Dataset(opt, images, tags, labels, test='text.query')
        t_db_data = Dataset(opt, images, tags, labels, test='text.db')

        # shuffle must stay off so codes align with labels.
        i_query_dataloader = DataLoader(i_query_data, opt.batch_size, shuffle=False)
        i_db_dataloader = DataLoader(i_db_data, opt.batch_size, shuffle=False)
        t_query_dataloader = DataLoader(t_query_data, opt.batch_size, shuffle=False)
        t_db_dataloader = DataLoader(t_db_data, opt.batch_size, shuffle=False)

        qBX = generate_img_code(model, i_query_dataloader, opt.query_size)
        qBY = generate_txt_code(model, t_query_dataloader, opt.query_size)
        rBX = generate_img_code(model, i_db_dataloader, opt.db_size)
        rBY = generate_txt_code(model, t_db_dataloader, opt.db_size)

        query_labels, db_labels = i_query_data.get_labels()
        query_labels = query_labels.to(opt.device)
        db_labels = db_labels.to(opt.device)

        # Precision/recall curves for all four retrieval directions.
        p_i2t, r_i2t = pr_curve(qBX, rBY, query_labels, db_labels, tqdm_label='I2T')
        p_t2i, r_t2i = pr_curve(qBY, rBX, query_labels, db_labels, tqdm_label='T2I')
        p_i2i, r_i2i = pr_curve(qBX, rBX, query_labels, db_labels, tqdm_label='I2I')
        p_t2t, r_t2t = pr_curve(qBY, rBY, query_labels, db_labels, tqdm_label='T2T')

        # Precision at top-K cut-offs.
        K = [1, 10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]
        pk_i2t = p_top_k(qBX, rBY, query_labels, db_labels, K, tqdm_label='I2T')
        pk_t2i = p_top_k(qBY, rBX, query_labels, db_labels, K, tqdm_label='T2I')
        pk_i2i = p_top_k(qBX, rBX, query_labels, db_labels, K, tqdm_label='I2I')
        pk_t2t = p_top_k(qBY, rBY, query_labels, db_labels, K, tqdm_label='T2T')

        mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
        mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
        mapi2i = calc_map_k(qBX, rBX, query_labels, db_labels)
        mapt2t = calc_map_k(qBY, rBY, query_labels, db_labels)

        pr_dict = {'pi2t': p_i2t.cpu().numpy(), 'ri2t': r_i2t.cpu().numpy(),
                   'pt2i': p_t2i.cpu().numpy(), 'rt2i': r_t2i.cpu().numpy(),
                   'pi2i': p_i2i.cpu().numpy(), 'ri2i': r_i2i.cpu().numpy(),
                   'pt2t': p_t2t.cpu().numpy(), 'rt2t': r_t2t.cpu().numpy()}

        pk_dict = {'k': K,
                   'pki2t': pk_i2t.cpu().numpy(),
                   'pkt2i': pk_t2i.cpu().numpy(),
                   'pki2i': pk_i2i.cpu().numpy(),
                   'pkt2t': pk_t2t.cpu().numpy()}

        map_dict = {'mapi2t': float(mapi2t.cpu().numpy()),
                    'mapt2i': float(mapt2i.cpu().numpy()),
                    'mapi2i': float(mapi2i.cpu().numpy()),
                    'mapt2t': float(mapt2t.cpu().numpy())}

        print('   Test MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(mapi2t, mapt2i, mapi2i, mapt2t))

        # Pickle everything into the checkpoint directory (path from above).
        write_pickle(os.path.join(path, 'pr_dict.pkl'), pr_dict)
        write_pickle(os.path.join(path, 'pk_dict.pkl'), pk_dict)
        write_pickle(os.path.join(path, 'map_dict.pkl'), map_dict)
Example #12
0
    def test(self):
        """Run the full evaluation suite (PR curves, P@K, mAP@K) for all
        four retrieval directions and pickle the results into the model
        directory."""
        self.ImgNet.eval().cuda()
        self.TxtNet.eval().cuda()

        # Hash codes + labels for the database (re_*) and query (qu_*) sets.
        re_BI, re_BT, re_LT, qu_BI, qu_BT, qu_LT = generate_hashes_from_dataloader(
            self.database_loader, self.test_loader, self.ImgNet, self.TxtNet,
            self.cfg.LABEL_DIM)

        # The image side is subsampled to every 5th element.
        qu_BI = self.get_each_5th_element(qu_BI)
        re_BI = self.get_each_5th_element(re_BI)
        qu_LI = self.get_each_5th_element(qu_LT)
        re_LI = self.get_each_5th_element(re_LT)

        # Precision/recall curves, one per retrieval direction.
        prec_i2t, rec_i2t = pr_curve(qu_BI, re_BT, qu_LI, re_LT, tqdm_label='I2T')
        prec_t2i, rec_t2i = pr_curve(qu_BT, re_BI, qu_LT, re_LI, tqdm_label='T2I')
        prec_i2i, rec_i2i = pr_curve(qu_BI, re_BI, qu_LI, re_LI, tqdm_label='I2I')
        prec_t2t, rec_t2t = pr_curve(qu_BT, re_BT, qu_LT, re_LT, tqdm_label='T2T')

        # Top-K cut-offs: 1, 10, 50, then 100..900 and 1000..10000.
        K = [1, 10, 50] + list(range(100, 1000, 100)) + list(
            range(1000, 10001, 1000))
        topk_i2t = p_top_k(qu_BI, re_BT, qu_LI, re_LT, K, tqdm_label='I2T')
        topk_t2i = p_top_k(qu_BT, re_BI, qu_LT, re_LI, K, tqdm_label='T2I')
        topk_i2i = p_top_k(qu_BI, re_BI, qu_LI, re_LI, K, tqdm_label='I2I')
        topk_t2t = p_top_k(qu_BT, re_BT, qu_LT, re_LT, K, tqdm_label='T2T')

        # mAP@K for each direction.
        map_i2t = calc_map_k(qu_BI, re_BT, qu_LI, re_LT, self.cfg.MAP_K)
        map_t2i = calc_map_k(qu_BT, re_BI, qu_LT, re_LI, self.cfg.MAP_K)
        map_i2i = calc_map_k(qu_BI, re_BI, qu_LI, re_LI, self.cfg.MAP_K)
        map_t2t = calc_map_k(qu_BT, re_BT, qu_LT, re_LT, self.cfg.MAP_K)

        pr_dict = {
            'pi2t': prec_i2t.cpu().numpy(),
            'ri2t': rec_i2t.cpu().numpy(),
            'pt2i': prec_t2i.cpu().numpy(),
            'rt2i': rec_t2i.cpu().numpy(),
            'pi2i': prec_i2i.cpu().numpy(),
            'ri2i': rec_i2i.cpu().numpy(),
            'pt2t': prec_t2t.cpu().numpy(),
            'rt2t': rec_t2t.cpu().numpy()
        }

        pk_dict = {
            'k': K,
            'pki2t': topk_i2t.cpu().numpy(),
            'pkt2i': topk_t2i.cpu().numpy(),
            'pki2i': topk_i2i.cpu().numpy(),
            'pkt2t': topk_t2t.cpu().numpy()
        }

        map_dict = {
            'mapi2t': float(map_i2t.cpu().numpy()),
            'mapt2i': float(map_t2i.cpu().numpy()),
            'mapi2i': float(map_i2i.cpu().numpy()),
            'mapt2t': float(map_t2t.cpu().numpy())
        }

        self.logger.info(
            'mAP I->T: %.3f, mAP T->I: %.3f, mAP I->I: %.3f, mAP T->T: %.3f' %
            (map_i2t, map_t2i, map_i2i, map_t2t))

        # Pickle all three result dicts into the model directory.
        for fname, payload in (('pr_dict.pkl', pr_dict),
                               ('pk_dict.pkl', pk_dict),
                               ('map_dict.pkl', map_dict)):
            write_pickle(osp.join(self.cfg.MODEL_DIR, self.path, fname),
                         payload)