Exemplo n.º 1
0
def test(**kwargs):
    """Print test-set mAP for a trained generator checkpoint.

    Parses ``kwargs`` into the module-level ``opt``, resolves the compute
    device, builds the generator, loads ``checkpoints/<dataset>_<bit>`` into
    it via ``load_model``, encodes the image/text query and database splits,
    and prints image->text and text->image mAP as computed by ``calc_map_k``.
    """
    opt.parse(kwargs)

    # An explicit device setting wins; otherwise use GPU 0 when ``opt.gpus``
    # is truthy, falling back to the CPU.
    if opt.device is None:
        opt.device = torch.device(0) if opt.gpus else torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    pretrained = load_pretrain_model(opt.pretrain_model_path)
    generator = GEN(opt.dropout, opt.image_dim, opt.text_dim, opt.hidden_dim,
                    opt.bit, pretrain_model=pretrained).to(opt.device)

    checkpoint_dir = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    load_model(generator, checkpoint_dir)
    generator.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    # One evaluation dataset and one non-shuffled loader per split name.
    split_names = ('image.query', 'image.db', 'text.query', 'text.db')
    datasets = {name: Dataset(opt, images, tags, labels, test=name)
                for name in split_names}
    loaders = {name: DataLoader(data, opt.batch_size, shuffle=False)
               for name, data in datasets.items()}

    query_img = generate_img_code(generator, loaders['image.query'],
                                  opt.query_size)
    query_txt = generate_txt_code(generator, loaders['text.query'],
                                  opt.query_size)
    db_img = generate_img_code(generator, loaders['image.db'], opt.db_size)
    db_txt = generate_txt_code(generator, loaders['text.db'], opt.db_size)

    query_labels, db_labels = (
        t.to(opt.device) for t in datasets['image.query'].get_labels())

    map_i2t = calc_map_k(query_img, db_txt, query_labels, db_labels)
    map_t2i = calc_map_k(query_txt, db_img, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (map_i2t, map_t2i))
Exemplo n.º 2
0
Arquivo: main.py Projeto: bei21/DADH
def train(**kwargs):
    """Train the DADH generator/discriminator pair and track retrieval mAP.

    Options are read from the module-level ``opt`` after
    ``opt.parse(kwargs)``. Each batch first updates the feature and hash
    discriminators (image outputs treated as "real", text outputs as "fake",
    plus a WGAN-GP style gradient penalty), then updates the generator with
    a weighted sum of triplet, quantization and adversarial losses. After
    each epoch the binary codes ``B_i``/``B_t`` are re-solved from the label
    matrix ``L`` and the continuous codes ``H_i``/``H_t`` gathered during
    the epoch.

    With ``opt.valid`` mAP is evaluated every ``opt.valid_freq`` epochs and
    the generator with the best average mAP is saved; otherwise the final
    generator is saved and evaluated once. Validation timings and mAP
    histories are pickled to ``checkpoints/<dataset>_<bit>/result.pkl``.
    """
    opt.parse(kwargs)

    if opt.vis_env:
        vis = Visualizer(opt.vis_env, port=opt.vis_port)

    # ``==`` rather than ``is``: identity against a string literal only works
    # by accident of interning (and is a SyntaxWarning on Python 3.8+).
    if opt.device is None or opt.device == 'cpu':
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)
    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)
    L = train_data.get_labels()
    L = L.to(opt.device)
    # test
    i_query_data = Dataset(opt, images, tags, labels, test='image.query')
    i_db_data = Dataset(opt, images, tags, labels, test='image.db')
    t_query_data = Dataset(opt, images, tags, labels, test='text.query')
    t_db_data = Dataset(opt, images, tags, labels, test='text.db')

    i_query_dataloader = DataLoader(i_query_data, opt.batch_size, shuffle=False)
    i_db_dataloader = DataLoader(i_db_data, opt.batch_size, shuffle=False)
    t_query_dataloader = DataLoader(t_query_data, opt.batch_size, shuffle=False)
    t_db_dataloader = DataLoader(t_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = i_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    generator = GEN(opt.dropout, opt.image_dim, opt.text_dim, opt.hidden_dim, opt.bit, opt.num_label, pretrain_model=pretrain_model).to(opt.device)

    discriminator = DIS(opt.hidden_dim//4, opt.hidden_dim//8, opt.bit).to(opt.device)

    optimizer = Adam([
        # {'params': generator.cnn_f.parameters()},     ## froze parameters of cnn_f
        {'params': generator.image_module.parameters()},
        {'params': generator.text_module.parameters()},
        {'params': generator.hash_module.parameters()}
    ], lr=opt.lr, weight_decay=0.0005)

    optimizer_dis = {
        'feature': Adam(discriminator.feature_dis.parameters(), lr=opt.lr, betas=(0.5, 0.9), weight_decay=0.0001),
        'hash': Adam(discriminator.hash_dis.parameters(), lr=opt.lr, betas=(0.5, 0.9), weight_decay=0.0001)
    }

    tri_loss = TripletLoss(opt, reduction='sum')

    # Numerically stable log(sigmoid(x)); log(1 - sigmoid(x)) == logsigmoid(-x).
    # The naive torch.log(torch.sigmoid(x)) underflows to -inf for very
    # negative x and then yields NaN gradients.
    log_sigmoid = torch.nn.functional.logsigmoid

    loss = []

    max_mapi2t = 0.
    max_mapt2i = 0.
    max_average = 0.

    mapt2i_list = []
    mapi2t_list = []
    train_times = []

    B_i = torch.randn(opt.training_size, opt.bit).sign().to(opt.device)
    B_t = B_i
    H_i = torch.zeros(opt.training_size, opt.bit).to(opt.device)
    H_t = torch.zeros(opt.training_size, opt.bit).to(opt.device)

    # L never changes, so the ridge-regression projection
    # (L^T L + lamb * I)^-1 L^T used to re-solve the binary codes is computed
    # once here instead of twice per epoch.
    label_proj = torch.inverse(
        L.t() @ L + opt.lamb * torch.eye(opt.num_label, device=opt.device)) @ L.t()

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        # Plain float: summing the ``err`` tensors themselves would keep every
        # batch's autograd graph alive for the rest of the epoch.
        e_loss = 0.
        for ind, img, txt, label in tqdm(train_dataloader):
            imgs = img.to(opt.device)
            txt = txt.to(opt.device)
            labels = label.to(opt.device)

            batch_size = len(ind)

            h_i, h_t, f_i, f_t = generator(imgs, txt)
            # Remember this batch's continuous codes for the epoch-end B update.
            H_i[ind, :] = h_i.data
            H_t[ind, :] = h_t.data
            h_t_detach = generator.generate_txt_code(txt)

            #####
            # train feature discriminator
            #####
            D_real_feature = discriminator.dis_feature(f_i.detach())
            D_real_feature = -opt.gamma * log_sigmoid(D_real_feature).mean()
            optimizer_dis['feature'].zero_grad()
            D_real_feature.backward()

            # train with fake
            D_fake_feature = discriminator.dis_feature(f_t.detach())
            D_fake_feature = -opt.gamma * log_sigmoid(-D_fake_feature).mean()
            D_fake_feature.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.hidden_dim//4).to(opt.device)
            interpolates = alpha * f_i.detach() + (1 - alpha) * f_t.detach()
            interpolates.requires_grad_()
            disc_interpolates = discriminator.dis_feature(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones_like(disc_interpolates),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            feature_gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            feature_gradient_penalty.backward()

            optimizer_dis['feature'].step()

            #####
            # train hash discriminator
            #####
            D_real_hash = discriminator.dis_hash(h_i.detach())
            D_real_hash = -opt.gamma * log_sigmoid(D_real_hash).mean()
            optimizer_dis['hash'].zero_grad()
            D_real_hash.backward()

            # train with fake
            D_fake_hash = discriminator.dis_hash(h_t.detach())
            D_fake_hash = -opt.gamma * log_sigmoid(-D_fake_hash).mean()
            D_fake_hash.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.bit).to(opt.device)
            interpolates = alpha * h_i.detach() + (1 - alpha) * h_t.detach()
            interpolates.requires_grad_()
            disc_interpolates = discriminator.dis_hash(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones_like(disc_interpolates),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)

            hash_gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            hash_gradient_penalty.backward()

            optimizer_dis['hash'].step()

            #####
            # train generator: make text outputs look "real" to both critics
            #####
            loss_adver_feature = -log_sigmoid(discriminator.dis_feature(f_t)).mean()
            loss_adver_hash = -log_sigmoid(discriminator.dis_hash(h_t_detach)).mean()

            tri_i2t = tri_loss(h_i, labels, target=h_t, margin=opt.margin)
            tri_t2i = tri_loss(h_t, labels, target=h_i, margin=opt.margin)
            weighted_cos_tri = tri_i2t + tri_t2i

            # Pull continuous codes towards the current binary codes.
            i_ql = torch.sum(torch.pow(B_i[ind, :] - h_i, 2))
            t_ql = torch.sum(torch.pow(B_t[ind, :] - h_t, 2))
            loss_quant = i_ql + t_ql
            err = opt.alpha * weighted_cos_tri + \
                  opt.beta * loss_quant + opt.gamma * (loss_adver_feature + loss_adver_hash)

            optimizer.zero_grad()
            err.backward()
            optimizer.step()

            e_loss += err.item()

        # Re-solve the binary codes from labels and this epoch's continuous codes.
        P_i = label_proj @ B_i
        P_t = label_proj @ B_t

        B_i = (L @ P_i + opt.mu * H_i).sign()
        B_t = (L @ P_t + opt.mu * H_t).sign()
        loss.append(e_loss)
        print('...epoch: %3d, loss: %3.3f' % (epoch + 1, loss[-1]))
        delta_t = time.time() - t1

        if opt.vis_env:
            vis.plot('loss', loss[-1])

        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i = valid(generator, i_query_dataloader, i_db_dataloader, t_query_dataloader, t_db_dataloader,
                                   query_labels, db_labels)
            print('...epoch: %3d, valid MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (epoch + 1, mapi2t, mapt2i))

            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            train_times.append(delta_t)

            if 0.5 * (mapi2t + mapt2i) > max_average:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                max_average = 0.5 * (mapi2t + mapt2i)
                save_model(generator)

            if opt.vis_env:
                vis.plot('mapi2t', mapi2t)
                vis.plot('mapt2i', mapt2i)

        # NOTE(review): ``epoch % 100 == 0`` also fires right after the first
        # epoch (epoch 0); ``(epoch + 1) % 100 == 0`` may have been intended.
        # Kept as-is to preserve existing training behavior.
        if epoch % 100 == 0:
            for params in optimizer.param_groups:
                params['lr'] = max(params['lr'] * 0.8, 1e-6)

    if not opt.valid:
        save_model(generator)

    print('...training procedure finish')
    if opt.valid:
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (max_mapi2t, max_mapt2i))
    else:
        mapi2t, mapt2i = valid(generator, i_query_dataloader, i_db_dataloader, t_query_dataloader, t_db_dataloader,
                               query_labels, db_labels)
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    with open(os.path.join(path, 'result.pkl'), 'wb') as f:
        pickle.dump([train_times, mapi2t_list, mapt2i_list], f)
Exemplo n.º 3
0
def test(**kwargs):
    """Evaluate a saved AGAH checkpoint and save PR / precision@K curves.

    Parses ``kwargs`` into the module-level ``opt``, resolves the compute
    device, builds the model, loads ``checkpoints/<dataset>_<bit>`` into it
    via ``load_model`` together with the saved ``feature_map.pth``, and
    encodes the image/text query and database splits. Precision/recall
    curves and precision@K (for the fixed ``K`` list below) are written as
    ``.npy`` files into the checkpoint directory, and both mAP directions
    are printed.
    """
    opt.parse(kwargs)

    if opt.device is not None:
        opt.device = torch.device(opt.device)
    elif opt.gpus:
        opt.device = torch.device(0)
    else:
        opt.device = torch.device('cpu')

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    model = AGAH(opt.bit,
                 opt.tag_dim,
                 opt.num_label,
                 opt.emb_dim,
                 lambd=opt.lambd,
                 pretrain_model=pretrain_model).to(opt.device)

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    load_model(model, path)
    # map_location places the tensor on opt.device directly, so a feature map
    # saved from a GPU run can still be loaded on a CPU-only machine.
    FEATURE_MAP = torch.load(os.path.join(path, 'feature_map.pth'),
                             map_location=opt.device)

    model.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    x_query_data = Dataset(opt, images, tags, labels, test='image.query')
    x_db_data = Dataset(opt, images, tags, labels, test='image.db')
    y_query_data = Dataset(opt, images, tags, labels, test='text.query')
    y_db_data = Dataset(opt, images, tags, labels, test='text.db')

    x_query_dataloader = DataLoader(x_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    x_db_dataloader = DataLoader(x_db_data, opt.batch_size, shuffle=False)
    y_query_dataloader = DataLoader(y_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    y_db_dataloader = DataLoader(y_db_data, opt.batch_size, shuffle=False)

    qBX = generate_img_code(model, x_query_dataloader, opt.query_size,
                            FEATURE_MAP)
    qBY = generate_txt_code(model, y_query_dataloader, opt.query_size,
                            FEATURE_MAP)
    rBX = generate_img_code(model, x_db_dataloader, opt.db_size, FEATURE_MAP)
    rBY = generate_txt_code(model, y_db_dataloader, opt.db_size, FEATURE_MAP)

    query_labels, db_labels = x_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    p_i2t, r_i2t = pr_curve(qBX, rBY, query_labels, db_labels)
    p_t2i, r_t2i = pr_curve(qBY, rBX, query_labels, db_labels)

    K = [1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    pk_i2t = p_topK(qBX, rBY, query_labels, db_labels, K)
    pk_t2i = p_topK(qBY, rBX, query_labels, db_labels, K)

    curves = {
        'P_i2t.npy': p_i2t,
        'R_i2t.npy': r_i2t,
        'P_t2i.npy': p_t2i,
        'R_t2i.npy': r_t2i,
        'P_at_K_i2t.npy': pk_i2t,
        'P_at_K_t2i.npy': pk_t2i,
    }
    for filename, values in curves.items():
        # ``.cpu()`` is a no-op for CPU tensors but required before
        # ``.numpy()`` when the metric helpers return tensors on a CUDA device.
        np.save(os.path.join(path, filename), values.cpu().numpy())

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))
Exemplo n.º 4
0
def train(**kwargs):
    """Train an AGAH model with adversarial (WGAN-GP style) feature alignment.

    Options are read from the module-level ``opt`` after
    ``opt.parse(kwargs)``. Each batch first updates the text and image
    discriminators (critic scores on real vs. fake features plus a gradient
    penalty), then updates the model with triplet, cosine-magnitude
    (``loss3``), classification, code-map, quantization and adversarial
    losses. After each epoch ``CODE_MAP`` and ``FEATURE_MAP`` are refreshed
    from the codes/features gathered during the epoch.

    With ``opt.valid`` the model and ``FEATURE_MAP`` are saved whenever
    neither mAP direction is worse than its best so far; otherwise the final
    model is saved and evaluated once. Validation timings and mAP histories
    are pickled to ``checkpoints/<dataset>_<bit>/result.pkl``.
    """
    opt.parse(kwargs)

    if opt.vis_env:
        vis = Visualizer(opt.vis_env, port=opt.vis_port)

    # ``==`` rather than ``is``: identity against a string literal only works
    # by accident of interning (and is a SyntaxWarning on Python 3.8+).
    if opt.device is None or opt.device == 'cpu':
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)

    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data,
                                  batch_size=opt.batch_size,
                                  shuffle=True)

    # valid or test data
    x_query_data = Dataset(opt, images, tags, labels, test='image.query')
    x_db_data = Dataset(opt, images, tags, labels, test='image.db')
    y_query_data = Dataset(opt, images, tags, labels, test='text.query')
    y_db_data = Dataset(opt, images, tags, labels, test='text.db')

    x_query_dataloader = DataLoader(x_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    x_db_dataloader = DataLoader(x_db_data, opt.batch_size, shuffle=False)
    y_query_dataloader = DataLoader(y_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    y_db_dataloader = DataLoader(y_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = x_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    # The pretrained model is skipped when a full checkpoint path is given.
    # Defaulting to None means a run with neither path set no longer dies
    # with an UnboundLocalError.
    pretrain_model = None
    if not opt.load_model_path and opt.pretrain_model_path:
        pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    model = AGAH(opt.bit,
                 opt.tag_dim,
                 opt.num_label,
                 opt.emb_dim,
                 lambd=opt.lambd,
                 pretrain_model=pretrain_model).to(opt.device)

    # NOTE(review): called even when ``opt.load_model_path`` is empty;
    # presumably ``load_model`` ignores a falsy path - confirm.
    load_model(model, opt.load_model_path)

    # img_module trains at opt.lr; the other groups use the default opt.lr * 10.
    optimizer = Adamax([{
        'params': model.img_module.parameters(),
        'lr': opt.lr
    }, {
        'params': model.txt_module.parameters()
    }, {
        'params': model.hash_module.parameters()
    }, {
        'params': model.classifier.parameters()
    }],
                       lr=opt.lr * 10,
                       weight_decay=0.0005)

    optimizer_dis = {
        'img':
        Adamax(model.img_discriminator.parameters(),
               lr=opt.lr * 10,
               betas=(0.5, 0.9),
               weight_decay=0.0001),
        'txt':
        Adamax(model.txt_discriminator.parameters(),
               lr=opt.lr * 10,
               betas=(0.5, 0.9),
               weight_decay=0.0001)
    }

    criterion_tri_cos = TripletAllLoss(dis_metric='cos', reduction='sum')
    criterion_bce = nn.BCELoss(reduction='sum')

    loss = []

    max_mapi2t = 0.
    max_mapt2i = 0.

    # Per-sample buffers filled batch by batch and consumed at epoch end.
    FEATURE_I = torch.randn(opt.training_size, opt.emb_dim).to(opt.device)
    FEATURE_T = torch.randn(opt.training_size, opt.emb_dim).to(opt.device)

    U = torch.randn(opt.training_size, opt.bit).to(opt.device)
    V = torch.randn(opt.training_size, opt.bit).to(opt.device)

    # Per-label maps, refreshed after every epoch.
    FEATURE_MAP = torch.randn(opt.num_label, opt.emb_dim).to(opt.device)
    CODE_MAP = torch.sign(torch.randn(opt.num_label, opt.bit)).to(opt.device)

    train_labels = train_data.get_labels().to(opt.device)

    mapt2i_list = []
    mapi2t_list = []
    train_times = []

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        for ind, x, y, l in tqdm(train_dataloader):
            imgs = x.to(opt.device)
            tags = y.to(opt.device)
            labels = l.to(opt.device)

            batch_size = len(ind)

            h_x, h_y, f_x, f_y, x_class, y_class = model(
                imgs, tags, FEATURE_MAP)

            FEATURE_I[ind] = f_x.data
            FEATURE_T[ind] = f_y.data
            U[ind] = h_x.data
            V[ind] = h_y.data

            #####
            # train txt discriminator
            #####
            D_txt_real = model.dis_txt(f_y.detach())
            D_txt_real = -D_txt_real.mean()
            optimizer_dis['txt'].zero_grad()
            D_txt_real.backward()

            # train with fake
            D_txt_fake = model.dis_txt(f_x.detach())
            D_txt_fake = D_txt_fake.mean()
            D_txt_fake.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.emb_dim).to(opt.device)
            interpolates = alpha * f_y.detach() + (1 - alpha) * f_x.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_txt(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates,
                                      inputs=interpolates,
                                      grad_outputs=torch.ones_like(
                                          disc_interpolates),
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            txt_gradient_penalty = (
                (gradients.norm(2, dim=1) - 1)**2).mean() * 10
            txt_gradient_penalty.backward()

            optimizer_dis['txt'].step()

            #####
            # train img discriminator
            #####
            D_img_real = model.dis_img(f_x.detach())
            D_img_real = -D_img_real.mean()
            optimizer_dis['img'].zero_grad()
            D_img_real.backward()

            # train with fake
            D_img_fake = model.dis_img(f_y.detach())
            D_img_fake = D_img_fake.mean()
            D_img_fake.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.emb_dim).to(opt.device)
            interpolates = alpha * f_x.detach() + (1 - alpha) * f_y.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_img(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates,
                                      inputs=interpolates,
                                      grad_outputs=torch.ones_like(
                                          disc_interpolates),
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            img_gradient_penalty = (
                (gradients.norm(2, dim=1) - 1)**2).mean() * 10
            img_gradient_penalty.backward()

            optimizer_dis['img'].step()

            #####
            # train generators
            #####
            # update img network (to generate txt features)
            domain_output = model.dis_txt(f_x)
            loss_G_txt = -domain_output.mean()

            # update txt network (to generate img features)
            domain_output = model.dis_img(f_y)
            loss_G_img = -domain_output.mean()

            loss_adver = loss_G_txt + loss_G_img

            loss1 = criterion_tri_cos(h_x,
                                      labels,
                                      target=h_y,
                                      margin=opt.margin)
            loss2 = criterion_tri_cos(h_y,
                                      labels,
                                      target=h_x,
                                      margin=opt.margin)

            # Cosine similarity of |h| with the all-ones vector, per sample.
            theta1 = F.cosine_similarity(torch.abs(h_x), torch.ones_like(h_x))
            theta2 = F.cosine_similarity(torch.abs(h_y), torch.ones_like(h_y))
            loss3 = torch.sum(1 / (1 + torch.exp(theta1))) + torch.sum(
                1 / (1 + torch.exp(theta2)))

            loss_class = criterion_bce(x_class, labels) + criterion_bce(
                y_class, labels)

            # Inner products with the per-label codes should approach
            # +bit for present labels and -bit for absent ones.
            theta_code_x = h_x.mm(CODE_MAP.t())  # size: (batch, num_label)
            theta_code_y = h_y.mm(CODE_MAP.t())
            loss_code_map = torch.sum(torch.pow(theta_code_x - opt.bit * (labels * 2 - 1), 2)) + \
                            torch.sum(torch.pow(theta_code_y - opt.bit * (labels * 2 - 1), 2))

            loss_quant = torch.sum(torch.pow(
                h_x - torch.sign(h_x), 2)) + torch.sum(
                    torch.pow(h_y - torch.sign(h_y), 2))

            # err = loss1 + loss2 + loss3 + 0.5 * loss_class + 0.5 * (loss_f1 + loss_f2)
            err = loss1 + loss2 + opt.alpha * loss3 + opt.beta * loss_class + opt.gamma * loss_code_map + \
                  opt.eta * loss_quant + opt.mu * loss_adver

            optimizer.zero_grad()
            err.backward()
            optimizer.step()

            loss.append(err.item())

        CODE_MAP = update_code_map(U, V, CODE_MAP, train_labels)
        FEATURE_MAP = update_feature_map(FEATURE_I, FEATURE_T, train_labels)

        # Reports the loss of the epoch's last batch.
        print('...epoch: %3d, loss: %3.3f' % (epoch + 1, loss[-1]))
        delta_t = time.time() - t1

        if opt.vis_env:
            vis.plot('loss', loss[-1])

        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i = valid(model, x_query_dataloader, x_db_dataloader,
                                   y_query_dataloader, y_db_dataloader,
                                   query_labels, db_labels, FEATURE_MAP)
            print(
                '...epoch: %3d, valid MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f'
                % (epoch + 1, mapi2t, mapt2i))

            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            train_times.append(delta_t)

            if opt.vis_env:
                d = {'mapi2t': mapi2t, 'mapt2i': mapt2i}
                vis.plot_many(d)

            if mapt2i >= max_mapt2i and mapi2t >= max_mapi2t:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                save_model(model)
                path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
                # torch.save needs no CUDA device context; the former
                # ``torch.cuda.device(opt.device)`` wrapper raised
                # ValueError whenever training ran on the CPU.
                torch.save(FEATURE_MAP,
                           os.path.join(path, 'feature_map.pth'))

        # NOTE(review): ``epoch % 100 == 0`` also fires right after the first
        # epoch (epoch 0); ``(epoch + 1) % 100 == 0`` may have been intended.
        # Kept as-is to preserve existing training behavior.
        if epoch % 100 == 0:
            for params in optimizer.param_groups:
                params['lr'] = max(params['lr'] * 0.6, 1e-6)

    if not opt.valid:
        save_model(model)

    print('...training procedure finish')
    if opt.valid:
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' %
              (max_mapi2t, max_mapt2i))
    else:
        mapi2t, mapt2i = valid(model, x_query_dataloader, x_db_dataloader,
                               y_query_dataloader, y_db_dataloader,
                               query_labels, db_labels, FEATURE_MAP)
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' %
              (mapi2t, mapt2i))

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    with open(os.path.join(path, 'result.pkl'), 'wb') as f:
        pickle.dump([train_times, mapi2t_list, mapt2i_list], f)