コード例 #1
0
def test(**kwargs):
    """Evaluate a saved generator and print cross-modal retrieval MAP.

    Keyword arguments are handed to ``opt.parse``. The generator is built,
    ``load_model`` is called with ``checkpoints/<dataset>_<bit>``, hash codes
    are generated for the image/text query and database splits, and the
    image->text and text->image MAP values are printed.
    """
    opt.parse(kwargs)

    # Device priority: an explicit opt.device, then device index 0 when
    # opt.gpus is set, otherwise the CPU.
    if opt.device is not None:
        chosen_device = torch.device(opt.device)
    elif opt.gpus:
        chosen_device = torch.device(0)
    else:
        chosen_device = torch.device('cpu')
    opt.device = chosen_device

    backbone = load_pretrain_model(opt.pretrain_model_path)

    generator = GEN(
        opt.dropout,
        opt.image_dim,
        opt.text_dim,
        opt.hidden_dim,
        opt.bit,
        pretrain_model=backbone,
    )
    generator = generator.to(opt.device)

    load_model(generator, 'checkpoints/' + opt.dataset + '_' + str(opt.bit))
    generator.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    # Build the four evaluation splits (all datasets first, then all loaders).
    split_names = ('image.query', 'image.db', 'text.query', 'text.db')
    datasets = {
        name: Dataset(opt, images, tags, labels, test=name)
        for name in split_names
    }
    loaders = {
        name: DataLoader(datasets[name], opt.batch_size, shuffle=False)
        for name in split_names
    }

    # Hash codes: q* for queries, r* for the retrieval database;
    # X = image modality, Y = text modality.
    qBX = generate_img_code(generator, loaders['image.query'], opt.query_size)
    qBY = generate_txt_code(generator, loaders['text.query'], opt.query_size)
    rBX = generate_img_code(generator, loaders['image.db'], opt.db_size)
    rBY = generate_txt_code(generator, loaders['text.db'], opt.db_size)

    query_labels, db_labels = datasets['image.query'].get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))
コード例 #2
0
ファイル: main.py プロジェクト: bei21/DADH
def train(**kwargs):
    """Train the hashing generator against feature and hash discriminators.

    Keyword arguments are handed to ``opt.parse``. Each batch first updates
    the two discriminators (real = image branch, fake = text branch, plus a
    gradient penalty), then updates the generator with a weighted sum of
    triplet, quantization and adversarial losses. After every epoch the
    binary codes ``B_i``/``B_t`` are recomputed in closed form from the
    labels and the stored continuous codes. When ``opt.valid`` is set the
    generator is validated every ``opt.valid_freq`` epochs and saved whenever
    the mean MAP improves; otherwise it is saved once at the end. Per-epoch
    timings and MAP histories are pickled to
    ``checkpoints/<dataset>_<bit>/result.pkl``.
    """
    opt.parse(kwargs)

    if opt.vis_env:
        vis = Visualizer(opt.vis_env, port=opt.vis_port)

    # Compare by value: ``is 'cpu'`` tests object identity, which is not
    # guaranteed for equal strings (and is a SyntaxWarning on Python 3.8+).
    if opt.device is None or opt.device == 'cpu':
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)
    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)
    L = train_data.get_labels()
    L = L.to(opt.device)
    # test
    i_query_data = Dataset(opt, images, tags, labels, test='image.query')
    i_db_data = Dataset(opt, images, tags, labels, test='image.db')
    t_query_data = Dataset(opt, images, tags, labels, test='text.query')
    t_db_data = Dataset(opt, images, tags, labels, test='text.db')

    i_query_dataloader = DataLoader(i_query_data, opt.batch_size, shuffle=False)
    i_db_dataloader = DataLoader(i_db_data, opt.batch_size, shuffle=False)
    t_query_dataloader = DataLoader(t_query_data, opt.batch_size, shuffle=False)
    t_db_dataloader = DataLoader(t_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = i_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    generator = GEN(opt.dropout, opt.image_dim, opt.text_dim, opt.hidden_dim, opt.bit, opt.num_label, pretrain_model=pretrain_model).to(opt.device)

    discriminator = DIS(opt.hidden_dim//4, opt.hidden_dim//8, opt.bit).to(opt.device)

    # generator.cnn_f is deliberately left out of the optimizer (frozen).
    optimizer = Adam([
        {'params': generator.image_module.parameters()},
        {'params': generator.text_module.parameters()},
        {'params': generator.hash_module.parameters()}
    ], lr=opt.lr, weight_decay=0.0005)

    optimizer_dis = {
        'feature': Adam(discriminator.feature_dis.parameters(), lr=opt.lr, betas=(0.5, 0.9), weight_decay=0.0001),
        'hash': Adam(discriminator.hash_dis.parameters(), lr=opt.lr, betas=(0.5, 0.9), weight_decay=0.0001)
    }

    tri_loss = TripletLoss(opt, reduction='sum')

    loss = []

    max_mapi2t = 0.
    max_mapt2i = 0.
    max_average = 0.

    mapt2i_list = []
    mapi2t_list = []
    train_times = []

    # Binary codes (B_*) and the latest continuous codes (H_*) for every
    # training sample; both modalities start from the same random codes.
    B_i = torch.randn(opt.training_size, opt.bit).sign().to(opt.device)
    B_t = B_i
    H_i = torch.zeros(opt.training_size, opt.bit).to(opt.device)
    H_t = torch.zeros(opt.training_size, opt.bit).to(opt.device)

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        # Accumulate the epoch loss as a Python float so no autograd tensors
        # are kept alive across batches and an empty loader still yields 0.
        e_loss = 0.0
        for i, (ind, img, txt, label) in tqdm(enumerate(train_dataloader)):
            imgs = img.to(opt.device)
            txt = txt.to(opt.device)
            labels = label.to(opt.device)

            batch_size = len(ind)

            h_i, h_t, f_i, f_t = generator(imgs, txt)
            H_i[ind, :] = h_i.data
            H_t[ind, :] = h_t.data
            h_t_detach = generator.generate_txt_code(txt)

            #####
            # train feature discriminator
            #####
            D_real_feature = discriminator.dis_feature(f_i.detach())
            D_real_feature = -opt.gamma * torch.log(torch.sigmoid(D_real_feature)).mean()
            optimizer_dis['feature'].zero_grad()
            D_real_feature.backward()

            # train with fake
            D_fake_feature = discriminator.dis_feature(f_t.detach())
            D_fake_feature = -opt.gamma * torch.log(torch.ones(batch_size).to(opt.device) - torch.sigmoid(D_fake_feature)).mean()
            D_fake_feature.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.hidden_dim//4).to(opt.device)
            interpolates = alpha * f_i.detach() + (1 - alpha) * f_t.detach()
            interpolates.requires_grad_()
            disc_interpolates = discriminator.dis_feature(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones(disc_interpolates.size()).to(opt.device),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            feature_gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            feature_gradient_penalty.backward()

            optimizer_dis['feature'].step()

            #####
            # train hash discriminator
            #####
            D_real_hash = discriminator.dis_hash(h_i.detach())
            D_real_hash = -opt.gamma * torch.log(torch.sigmoid(D_real_hash)).mean()
            optimizer_dis['hash'].zero_grad()
            D_real_hash.backward()

            # train with fake
            D_fake_hash = discriminator.dis_hash(h_t.detach())
            D_fake_hash = -opt.gamma * torch.log(torch.ones(batch_size).to(opt.device) - torch.sigmoid(D_fake_hash)).mean()
            D_fake_hash.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.bit).to(opt.device)
            interpolates = alpha * h_i.detach() + (1 - alpha) * h_t.detach()
            interpolates.requires_grad_()
            disc_interpolates = discriminator.dis_hash(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones(disc_interpolates.size()).to(opt.device),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)

            hash_gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            hash_gradient_penalty.backward()

            optimizer_dis['hash'].step()

            #####
            # train generator
            #####
            # Adversarial terms: the text branch tries to be scored as "real".
            loss_G_txt_feature = -torch.log(torch.sigmoid(discriminator.dis_feature(f_t))).mean()
            loss_adver_feature = loss_G_txt_feature

            loss_G_txt_hash = -torch.log(torch.sigmoid(discriminator.dis_hash(h_t_detach))).mean()
            loss_adver_hash = loss_G_txt_hash

            tri_i2t = tri_loss(h_i, labels, target=h_t, margin=opt.margin)
            tri_t2i = tri_loss(h_t, labels, target=h_i, margin=opt.margin)
            weighted_cos_tri = tri_i2t + tri_t2i

            # Quantization loss: squared distance to the current binary codes.
            i_ql = torch.sum(torch.pow(B_i[ind, :] - h_i, 2))
            t_ql = torch.sum(torch.pow(B_t[ind, :] - h_t, 2))
            loss_quant = i_ql + t_ql
            err = opt.alpha * weighted_cos_tri + \
                  opt.beta * loss_quant + opt.gamma * (loss_adver_feature + loss_adver_hash)

            optimizer.zero_grad()
            err.backward()
            optimizer.step()

            e_loss += err.item()

        # Closed-form update of the binary codes: regularized least-squares
        # projection of B onto the labels, then sign(L P + mu * H).
        P_i = torch.inverse(
                L.t() @ L + opt.lamb * torch.eye(opt.num_label, device=opt.device)) @ L.t() @ B_i
        P_t = torch.inverse(
                L.t() @ L + opt.lamb * torch.eye(opt.num_label, device=opt.device)) @ L.t() @ B_t

        B_i = (L @ P_i + opt.mu * H_i).sign()
        B_t = (L @ P_t + opt.mu * H_t).sign()
        loss.append(e_loss)
        print('...epoch: %3d, loss: %3.3f' % (epoch + 1, loss[-1]))
        delta_t = time.time() - t1

        if opt.vis_env:
            vis.plot('loss', loss[-1])

        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i = valid(generator, i_query_dataloader, i_db_dataloader, t_query_dataloader, t_db_dataloader,
                                   query_labels, db_labels)
            print('...epoch: %3d, valid MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (epoch + 1, mapi2t, mapt2i))

            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            train_times.append(delta_t)

            # Keep the checkpoint with the best mean of the two MAP scores.
            if 0.5 * (mapi2t + mapt2i) > max_average:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                max_average = 0.5 * (mapi2t + mapt2i)
                save_model(generator)

            if opt.vis_env:
                vis.plot('mapi2t', mapi2t)
                vis.plot('mapt2i', mapt2i)

        # NOTE(review): this also fires at epoch 0, so the learning rate is
        # decayed once right after the first epoch -- confirm this is intended.
        if epoch % 100 == 0:
            for params in optimizer.param_groups:
                params['lr'] = max(params['lr'] * 0.8, 1e-6)

    if not opt.valid:
        save_model(generator)

    print('...training procedure finish')
    if opt.valid:
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (max_mapi2t, max_mapt2i))
    else:
        mapi2t, mapt2i = valid(generator, i_query_dataloader, i_db_dataloader, t_query_dataloader, t_db_dataloader,
                               query_labels, db_labels)
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    # The directory may not exist yet if no checkpoint was saved during
    # training (e.g. validation never ran), so create it before writing.
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, 'result.pkl'), 'wb') as f:
        pickle.dump([train_times, mapi2t_list, mapt2i_list], f)
コード例 #3
0
def test(**kwargs):
    """Evaluate a saved AGAH model and dump retrieval metrics.

    Keyword arguments are handed to ``opt.parse``. The model and its
    ``feature_map.pth`` tensor are read from ``checkpoints/<dataset>_<bit>``,
    hash codes are generated for the image/text query and database splits,
    precision-recall curves and precision@K are saved as ``.npy`` files in
    the same directory, and the two cross-modal MAP values are printed.
    """
    opt.parse(kwargs)

    # Device priority: an explicit opt.device, then device index 0 when
    # opt.gpus is set, otherwise the CPU.
    if opt.device is not None:
        chosen_device = torch.device(opt.device)
    elif opt.gpus:
        chosen_device = torch.device(0)
    else:
        chosen_device = torch.device('cpu')
    opt.device = chosen_device

    backbone = load_pretrain_model(opt.pretrain_model_path)

    model = AGAH(
        opt.bit,
        opt.tag_dim,
        opt.num_label,
        opt.emb_dim,
        lambd=opt.lambd,
        pretrain_model=backbone,
    )
    model = model.to(opt.device)

    ckpt_dir = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    load_model(model, ckpt_dir)
    feature_map_file = os.path.join(ckpt_dir, 'feature_map.pth')
    FEATURE_MAP = torch.load(feature_map_file).to(opt.device)

    model.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    # Build the four evaluation splits (all datasets first, then all loaders).
    split_names = ('image.query', 'image.db', 'text.query', 'text.db')
    datasets = {
        name: Dataset(opt, images, tags, labels, test=name)
        for name in split_names
    }
    loaders = {
        name: DataLoader(datasets[name], opt.batch_size, shuffle=False)
        for name in split_names
    }

    # Hash codes: q* for queries, r* for the retrieval database;
    # X = image modality, Y = text modality.
    qBX = generate_img_code(model, loaders['image.query'], opt.query_size,
                            FEATURE_MAP)
    qBY = generate_txt_code(model, loaders['text.query'], opt.query_size,
                            FEATURE_MAP)
    rBX = generate_img_code(model, loaders['image.db'], opt.db_size,
                            FEATURE_MAP)
    rBY = generate_txt_code(model, loaders['text.db'], opt.db_size,
                            FEATURE_MAP)

    query_labels, db_labels = datasets['image.query'].get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    p_i2t, r_i2t = pr_curve(qBX, rBY, query_labels, db_labels)
    p_t2i, r_t2i = pr_curve(qBY, rBX, query_labels, db_labels)

    K = [1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    pk_i2t = p_topK(qBX, rBY, query_labels, db_labels, K)
    pk_t2i = p_topK(qBY, rBX, query_labels, db_labels, K)

    # Persist every metric next to the checkpoint, in a fixed order.
    outputs = (
        ('P_i2t.npy', p_i2t),
        ('R_i2t.npy', r_i2t),
        ('P_t2i.npy', p_t2i),
        ('R_t2i.npy', r_t2i),
        ('P_at_K_i2t.npy', pk_i2t),
        ('P_at_K_t2i.npy', pk_t2i),
    )
    for file_name, metric in outputs:
        np.save(os.path.join(ckpt_dir, file_name), metric.numpy())

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))
コード例 #4
0
def train(**kwargs):
    """Train the AGAH model with adversarial, triplet and code-map losses.

    Keyword arguments are handed to ``opt.parse``. Each batch first updates
    the text and image discriminators (critic-style mean scores plus a
    gradient penalty), then updates the main networks with a weighted sum of
    triplet, bit-balance, classification, code-map, quantization and
    adversarial losses. ``CODE_MAP`` and ``FEATURE_MAP`` are refreshed once
    per epoch. When ``opt.valid`` is set the model is validated every
    ``opt.valid_freq`` epochs and the model plus ``FEATURE_MAP`` are saved
    whenever neither MAP score got worse; otherwise the model is saved once
    at the end. Per-epoch timings and MAP histories are pickled to
    ``checkpoints/<dataset>_<bit>/result.pkl``.
    """
    opt.parse(kwargs)

    if opt.vis_env:
        vis = Visualizer(opt.vis_env, port=opt.vis_port)

    # Compare by value: ``is 'cpu'`` tests object identity, which is not
    # guaranteed for equal strings (and is a SyntaxWarning on Python 3.8+).
    if opt.device is None or opt.device == 'cpu':
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)

    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data,
                                  batch_size=opt.batch_size,
                                  shuffle=True)

    # valid or test data
    x_query_data = Dataset(opt, images, tags, labels, test='image.query')
    x_db_data = Dataset(opt, images, tags, labels, test='image.db')
    y_query_data = Dataset(opt, images, tags, labels, test='text.query')
    y_db_data = Dataset(opt, images, tags, labels, test='text.db')

    x_query_dataloader = DataLoader(x_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    x_db_dataloader = DataLoader(x_db_data, opt.batch_size, shuffle=False)
    y_query_dataloader = DataLoader(y_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    y_db_dataloader = DataLoader(y_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = x_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    # Default to None so the name is always bound; pretrained weights are
    # only loaded when no full checkpoint is going to be restored.
    pretrain_model = None
    if not opt.load_model_path and opt.pretrain_model_path:
        pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    model = AGAH(opt.bit,
                 opt.tag_dim,
                 opt.num_label,
                 opt.emb_dim,
                 lambd=opt.lambd,
                 pretrain_model=pretrain_model).to(opt.device)

    load_model(model, opt.load_model_path)

    # The image module uses opt.lr; all other groups use the 10x default.
    optimizer = Adamax([{
        'params': model.img_module.parameters(),
        'lr': opt.lr
    }, {
        'params': model.txt_module.parameters()
    }, {
        'params': model.hash_module.parameters()
    }, {
        'params': model.classifier.parameters()
    }],
                       lr=opt.lr * 10,
                       weight_decay=0.0005)

    optimizer_dis = {
        'img':
        Adamax(model.img_discriminator.parameters(),
               lr=opt.lr * 10,
               betas=(0.5, 0.9),
               weight_decay=0.0001),
        'txt':
        Adamax(model.txt_discriminator.parameters(),
               lr=opt.lr * 10,
               betas=(0.5, 0.9),
               weight_decay=0.0001)
    }

    criterion_tri_cos = TripletAllLoss(dis_metric='cos', reduction='sum')
    criterion_bce = nn.BCELoss(reduction='sum')

    loss = []

    max_mapi2t = 0.
    max_mapt2i = 0.

    # Per-sample buffers holding the latest features and hash outputs.
    FEATURE_I = torch.randn(opt.training_size, opt.emb_dim).to(opt.device)
    FEATURE_T = torch.randn(opt.training_size, opt.emb_dim).to(opt.device)

    U = torch.randn(opt.training_size, opt.bit).to(opt.device)
    V = torch.randn(opt.training_size, opt.bit).to(opt.device)

    # Per-label feature and binary-code prototypes, refreshed every epoch.
    FEATURE_MAP = torch.randn(opt.num_label, opt.emb_dim).to(opt.device)
    CODE_MAP = torch.sign(torch.randn(opt.num_label, opt.bit)).to(opt.device)

    train_labels = train_data.get_labels().to(opt.device)

    mapt2i_list = []
    mapi2t_list = []
    train_times = []

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        for i, (ind, x, y, label) in tqdm(enumerate(train_dataloader)):
            imgs = x.to(opt.device)
            tags = y.to(opt.device)
            labels = label.to(opt.device)

            batch_size = len(ind)

            h_x, h_y, f_x, f_y, x_class, y_class = model(
                imgs, tags, FEATURE_MAP)

            FEATURE_I[ind] = f_x.data
            FEATURE_T[ind] = f_y.data
            U[ind] = h_x.data
            V[ind] = h_y.data

            #####
            # train txt discriminator
            #####
            D_txt_real = model.dis_txt(f_y.detach())
            D_txt_real = -D_txt_real.mean()
            optimizer_dis['txt'].zero_grad()
            D_txt_real.backward()

            # train with fake
            D_txt_fake = model.dis_txt(f_x.detach())
            D_txt_fake = D_txt_fake.mean()
            D_txt_fake.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.emb_dim).to(opt.device)
            interpolates = alpha * f_y.detach() + (1 - alpha) * f_x.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_txt(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates,
                                      inputs=interpolates,
                                      grad_outputs=torch.ones(
                                          disc_interpolates.size()).to(
                                              opt.device),
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            txt_gradient_penalty = (
                (gradients.norm(2, dim=1) - 1)**2).mean() * 10
            txt_gradient_penalty.backward()

            optimizer_dis['txt'].step()

            #####
            # train img discriminator
            #####
            D_img_real = model.dis_img(f_x.detach())
            D_img_real = -D_img_real.mean()
            optimizer_dis['img'].zero_grad()
            D_img_real.backward()

            # train with fake
            D_img_fake = model.dis_img(f_y.detach())
            D_img_fake = D_img_fake.mean()
            D_img_fake.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.emb_dim).to(opt.device)
            interpolates = alpha * f_x.detach() + (1 - alpha) * f_y.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_img(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates,
                                      inputs=interpolates,
                                      grad_outputs=torch.ones(
                                          disc_interpolates.size()).to(
                                              opt.device),
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            img_gradient_penalty = (
                (gradients.norm(2, dim=1) - 1)**2).mean() * 10
            img_gradient_penalty.backward()

            optimizer_dis['img'].step()

            #####
            # train generators
            #####
            # update img network (to generate txt features)
            domain_output = model.dis_txt(f_x)
            loss_G_txt = -domain_output.mean()

            # update txt network (to generate img features)
            domain_output = model.dis_img(f_y)
            loss_G_img = -domain_output.mean()

            loss_adver = loss_G_txt + loss_G_img

            loss1 = criterion_tri_cos(h_x,
                                      labels,
                                      target=h_y,
                                      margin=opt.margin)
            loss2 = criterion_tri_cos(h_y,
                                      labels,
                                      target=h_x,
                                      margin=opt.margin)

            # Cosine similarity between |h| and the all-ones vector.
            theta1 = F.cosine_similarity(torch.abs(h_x),
                                         torch.ones_like(h_x).to(opt.device))
            theta2 = F.cosine_similarity(torch.abs(h_y),
                                         torch.ones_like(h_y).to(opt.device))
            loss3 = torch.sum(1 / (1 + torch.exp(theta1))) + torch.sum(
                1 / (1 + torch.exp(theta2)))

            loss_class = criterion_bce(x_class, labels) + criterion_bce(
                y_class, labels)

            theta_code_x = h_x.mm(CODE_MAP.t())  # size: (batch, num_label)
            theta_code_y = h_y.mm(CODE_MAP.t())
            loss_code_map = torch.sum(torch.pow(theta_code_x - opt.bit * (labels * 2 - 1), 2)) + \
                            torch.sum(torch.pow(theta_code_y - opt.bit * (labels * 2 - 1), 2))

            loss_quant = torch.sum(torch.pow(
                h_x - torch.sign(h_x), 2)) + torch.sum(
                    torch.pow(h_y - torch.sign(h_y), 2))

            err = loss1 + loss2 + opt.alpha * loss3 + opt.beta * loss_class + opt.gamma * loss_code_map + \
                  opt.eta * loss_quant + opt.mu * loss_adver

            optimizer.zero_grad()
            err.backward()
            optimizer.step()

            loss.append(err.item())

        CODE_MAP = update_code_map(U, V, CODE_MAP, train_labels)
        FEATURE_MAP = update_feature_map(FEATURE_I, FEATURE_T, train_labels)

        # Reports the loss of the last batch of the epoch.
        print('...epoch: %3d, loss: %3.3f' % (epoch + 1, loss[-1]))
        delta_t = time.time() - t1

        if opt.vis_env:
            vis.plot('loss', loss[-1])

        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i = valid(model, x_query_dataloader, x_db_dataloader,
                                   y_query_dataloader, y_db_dataloader,
                                   query_labels, db_labels, FEATURE_MAP)
            print(
                '...epoch: %3d, valid MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f'
                % (epoch + 1, mapi2t, mapt2i))

            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            train_times.append(delta_t)

            if opt.vis_env:
                d = {'mapi2t': mapi2t, 'mapt2i': mapt2i}
                vis.plot_many(d)

            # Checkpoint only when neither direction got worse.
            if mapt2i >= max_mapt2i and mapi2t >= max_mapi2t:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                save_model(model)
                path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
                # torch.save needs no CUDA device context; wrapping it in
                # torch.cuda.device() raised ValueError when training on CPU.
                torch.save(FEATURE_MAP,
                           os.path.join(path, 'feature_map.pth'))

        # NOTE(review): this also fires at epoch 0, so the learning rate is
        # decayed once right after the first epoch -- confirm this is intended.
        if epoch % 100 == 0:
            for params in optimizer.param_groups:
                params['lr'] = max(params['lr'] * 0.6, 1e-6)

    if not opt.valid:
        save_model(model)

    print('...training procedure finish')
    if opt.valid:
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' %
              (max_mapi2t, max_mapt2i))
    else:
        mapi2t, mapt2i = valid(model, x_query_dataloader, x_db_dataloader,
                               y_query_dataloader, y_db_dataloader,
                               query_labels, db_labels, FEATURE_MAP)
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' %
              (mapi2t, mapt2i))

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    # The directory may not exist yet if no checkpoint was saved during
    # training (e.g. validation never ran), so create it before writing.
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, 'result.pkl'), 'wb') as f:
        pickle.dump([train_times, mapi2t_list, mapt2i_list], f)
コード例 #5
0
ファイル: main.py プロジェクト: WangBenHui/cpah
def test(**kwargs):
    """Evaluate a saved CPAH model and pickle its retrieval metrics.

    Keyword arguments are handed to ``opt.parse``. Under ``torch.no_grad()``
    the model is built, ``load_model`` is called with
    ``checkpoints/<dataset>_<bit><proc>``, and hash codes are generated for
    the image/text query and database splits. Precision-recall curves,
    precision@K and MAP are computed for the I2T, T2I, I2I and T2T
    directions, the MAP values are printed, and the three result
    dictionaries are written as pickles into the checkpoint directory.
    """
    opt.parse(kwargs)

    # Device priority: an explicit opt.device, then device index 0 when
    # opt.gpus is set, otherwise the CPU.
    if opt.device is not None:
        chosen_device = torch.device(opt.device)
    elif opt.gpus:
        chosen_device = torch.device(0)
    else:
        chosen_device = torch.device('cpu')
    opt.device = chosen_device

    with torch.no_grad():
        model = CPAH(opt.image_dim, opt.text_dim, opt.hidden_dim, opt.bit, opt.num_label)
        model = model.to(opt.device)

        ckpt_dir = 'checkpoints/' + opt.dataset + '_' + str(opt.bit) + str(opt.proc)
        load_model(model, ckpt_dir)

        model.eval()

        images, tags, labels = load_data(opt.data_path, opt.dataset)

        # Build the four evaluation splits (datasets first, then loaders).
        split_names = ('image.query', 'image.db', 'text.query', 'text.db')
        datasets = {
            name: Dataset(opt, images, tags, labels, test=name)
            for name in split_names
        }
        loaders = {
            name: DataLoader(datasets[name], opt.batch_size, shuffle=False)
            for name in split_names
        }

        # Hash codes: q* for queries, r* for the retrieval database;
        # X = image modality, Y = text modality.
        qBX = generate_img_code(model, loaders['image.query'], opt.query_size)
        qBY = generate_txt_code(model, loaders['text.query'], opt.query_size)
        rBX = generate_img_code(model, loaders['image.db'], opt.db_size)
        rBY = generate_txt_code(model, loaders['text.db'], opt.db_size)

        query_labels, db_labels = datasets['image.query'].get_labels()
        query_labels = query_labels.to(opt.device)
        db_labels = db_labels.to(opt.device)

        # (tag, query codes, database codes) for each retrieval direction,
        # in the order the metrics are computed and stored.
        directions = (
            ('i2t', qBX, rBY),
            ('t2i', qBY, rBX),
            ('i2i', qBX, rBX),
            ('t2t', qBY, rBY),
        )

        curves = {}
        for tag, query_codes, db_codes in directions:
            precision, recall = pr_curve(query_codes, db_codes, query_labels, db_labels,
                                         tqdm_label=tag.upper())
            curves[tag] = (precision, recall)

        K = [1, 10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]
        top_k = {}
        for tag, query_codes, db_codes in directions:
            top_k[tag] = p_top_k(query_codes, db_codes, query_labels, db_labels, K,
                                 tqdm_label=tag.upper())

        maps = {}
        for tag, query_codes, db_codes in directions:
            maps[tag] = calc_map_k(query_codes, db_codes, query_labels, db_labels)

        # Convert everything to NumPy / plain floats for pickling.
        pr_dict = {}
        for tag, _, _ in directions:
            precision, recall = curves[tag]
            pr_dict['p' + tag] = precision.cpu().numpy()
            pr_dict['r' + tag] = recall.cpu().numpy()

        pk_dict = {'k': K}
        for tag, _, _ in directions:
            pk_dict['pk' + tag] = top_k[tag].cpu().numpy()

        map_dict = {}
        for tag, _, _ in directions:
            map_dict['map' + tag] = float(maps[tag].cpu().numpy())

        print('   Test MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(
            maps['i2t'], maps['t2i'], maps['i2i'], maps['t2t']))

        write_pickle(os.path.join(ckpt_dir, 'pr_dict.pkl'), pr_dict)
        write_pickle(os.path.join(ckpt_dir, 'pk_dict.pkl'), pk_dict)
        write_pickle(os.path.join(ckpt_dir, 'map_dict.pkl'), map_dict)
コード例 #6
0
ファイル: main.py プロジェクト: WangBenHui/cpah
def train(**kwargs):
    """Train the CPAH model and record validation MAP, losses and epoch times.

    Keyword arguments are forwarded to ``opt.parse`` and override the global
    configuration object ``opt``.

    For every batch the feature discriminator is updated first (image features
    are treated as "real", text features as "fake", plus a gradient penalty),
    then the remaining modules are updated with the weighted sum of the
    adversarial, consistency/classification, pairwise and quantization losses.
    After each epoch the binary code matrix ``B`` is refreshed from the sign of
    the averaged image and text hash outputs.

    Side effects:
        * prints progress and MAP values to stdout;
        * calls ``save_model`` whenever the averaged cross-modal MAP improves
          (or once at the end when validation is disabled);
        * writes ``hash_maps_i_t.pth``, ``code_map.pth`` and ``res_dict.pkl``
          into ``checkpoints/<dataset>_<bit><proc>``.
    """
    since = time.time()
    opt.parse(kwargs)

    if (opt.device is None) or (opt.device == 'cpu'):
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    # Directory that receives the hash maps, codes and result pickle of this run.
    run_dir = 'checkpoints/' + opt.dataset + '_' + str(opt.bit) + str(opt.proc)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)
    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)

    # Query / database splits used for validation.
    i_query_data = Dataset(opt, images, tags, labels, test='image.query')
    i_db_data = Dataset(opt, images, tags, labels, test='image.db')
    t_query_data = Dataset(opt, images, tags, labels, test='text.query')
    t_db_data = Dataset(opt, images, tags, labels, test='text.db')

    i_query_dataloader = DataLoader(i_query_data, opt.batch_size, shuffle=False)
    i_db_dataloader = DataLoader(i_db_data, opt.batch_size, shuffle=False)
    t_query_dataloader = DataLoader(t_query_data, opt.batch_size, shuffle=False)
    t_db_dataloader = DataLoader(t_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = i_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    model = CPAH(opt.image_dim, opt.text_dim, opt.hidden_dim, opt.bit, opt.num_label).to(opt.device)

    # Everything except the feature discriminator is trained by optimizer_gen.
    optimizer_gen = Adam([
        {'params': model.image_module.parameters()},
        {'params': model.text_module.parameters()},
        {'params': model.hash_module.parameters()},
        {'params': model.mask_module.parameters()},
        {'params': model.consistency_dis.parameters()},
        {'params': model.classifier.parameters()},
    ], lr=opt.lr, weight_decay=0.0005)

    optimizer_dis = Adam(model.feature_dis.parameters(), lr=opt.lr, betas=(0.5, 0.9), weight_decay=0.0001)

    loss_bce = torch.nn.BCELoss(reduction='sum')
    loss_ce = torch.nn.CrossEntropyLoss(reduction='sum')

    # Numerically stable replacements for log(sigmoid(x)) and log(1 + exp(x)).
    log_sigmoid = torch.nn.functional.logsigmoid
    softplus = torch.nn.functional.softplus

    loss = []    # total loss per epoch
    losses = []  # per-component loss dict per epoch

    max_mapi2t = 0.
    max_mapt2i = 0.
    max_mapi2i = 0.
    max_mapt2t = 0.
    max_average = 0.

    mapt2i_list = []
    mapi2t_list = []
    mapi2i_list = []
    mapt2t_list = []
    train_times = []

    # Binary codes, randomly initialised and refreshed after every epoch.
    B = torch.randn(opt.training_size, opt.bit).sign().to(opt.device)

    # Latest hash outputs of every training sample for both modalities.
    H_i = torch.zeros(opt.training_size, opt.bit).to(opt.device)
    H_t = torch.zeros(opt.training_size, opt.bit).to(opt.device)

    # NOTE(review): anomaly detection is a debugging aid that adds autograd
    # overhead; consider switching it off for regular training runs.
    torch.autograd.set_detect_anomaly(True)

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        # Plain Python float: no autograd graph is kept alive across batches,
        # and an empty loader still yields a valid value.
        e_loss = 0.0
        e_losses = {'adv': 0, 'class': 0, 'quant': 0, 'pairwise': 0}
        for i, (ind, img, txt, label) in enumerate(train_dataloader):
            imgs = img.to(opt.device)
            txt = txt.to(opt.device)
            batch_labels = label.to(opt.device)

            batch_size = len(ind)

            h_img, h_txt, f_rc_img, f_rc_txt, f_rp_img, f_rp_txt = model(imgs, txt)

            # Store detached copies so the persistent buffers never become part
            # of (and keep alive) the per-batch autograd graphs.
            H_i[ind, :] = h_img.detach()
            H_t[ind, :] = h_txt.detach()

            ###################################
            # train discriminator. CPAH paper: (5)
            ###################################
            # IMG - real, TXT - fake
            optimizer_dis.zero_grad()

            # train with real (IMG): -log(sigmoid(D(img)))
            d_real = -log_sigmoid(model.dis_D(f_rc_img.detach())).mean()
            d_real.backward()

            # train with fake (TXT): -log(1 - sigmoid(D(txt))) == -logsigmoid(-D(txt))
            d_fake = -log_sigmoid(-model.dis_D(f_rc_txt.detach())).mean()
            d_fake.backward()

            # train with gradient penalty (GP) on points interpolated between
            # the real and fake features
            alpha = torch.rand(batch_size, opt.hidden_dim).to(opt.device)
            interpolates = alpha * f_rc_img.detach() + (1 - alpha) * f_rc_txt.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_D(interpolates)
            # gradients of the discriminator output with respect to its inputs
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones_like(disc_interpolates),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10  # 10 is GP hyperparameter
            gradient_penalty.backward()

            optimizer_dis.step()

            ###################################
            # train generator
            ###################################

            # adversarial loss, CPAH paper: (6); f_rc_txt stays attached to the graph
            loss_adver = -log_sigmoid(model.dis_D(f_rc_txt)).mean()

            # consistency classification loss, CPAH paper: (7)
            # NOTE(review): the cross-entropy is applied to the stacked features
            # directly, not to an output of model.consistency_dis - confirm
            # that this is intended.
            f_r = torch.vstack([f_rc_img, f_rc_txt, f_rp_img, f_rp_txt])
            l_r = [1] * batch_size * 2 + [0] * batch_size + [2] * batch_size  # labels
            l_r = torch.tensor(l_r).to(opt.device)
            loss_consistency_class = loss_ce(f_r, l_r)

            # classification loss, CPAH paper: (8)
            l_f_rc_img = model.dis_classify(f_rc_img, 'img')
            l_f_rc_txt = model.dis_classify(f_rc_txt, 'txt')
            loss_class = loss_bce(l_f_rc_img, batch_labels) + loss_bce(l_f_rc_txt, batch_labels)

            # pairwise loss, CPAH paper: (10)
            # S[a, b] is 1 when samples a and b share at least one label.
            S = (batch_labels.mm(batch_labels.T) > 0).float()
            theta = 0.5 * h_img.mm(h_txt.T)
            # softplus(theta) == log(1 + exp(theta)) without overflowing exp()
            loss_pairwise = -torch.sum(S * theta - softplus(theta))

            # quantization loss, CPAH paper: (11)
            loss_quant = torch.sum(torch.pow(B[ind, :] - h_img, 2)) + torch.sum(torch.pow(B[ind, :] - h_txt, 2))

            weighted_class = opt.alpha * (loss_consistency_class + loss_class)
            err = 100 * loss_adver + weighted_class + loss_pairwise + opt.beta * loss_quant

            e_losses['adv'] += 100 * loss_adver.item()
            e_losses['class'] += weighted_class.item()
            e_losses['pairwise'] += loss_pairwise.item()
            e_losses['quant'] += loss_quant.item()

            optimizer_gen.zero_grad()
            err.backward()
            optimizer_gen.step()

            e_loss += err.item()

        loss.append(e_loss)
        e_losses['sum'] = sum(e_losses.values())
        losses.append(e_losses)

        # Refresh the binary codes from the averaged hash outputs.
        B = (0.5 * (H_i + H_t)).sign()

        delta_t = time.time() - t1
        print('Epoch: {:4d}/{:4d}, time, {:3.3f}s, loss: {:15.3f},'.format(epoch + 1, opt.max_epoch, delta_t,
                                                                           loss[-1]) + 5 * ' ' + 'losses:', e_losses)
        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i, mapi2i, mapt2t = valid(model, i_query_dataloader, i_db_dataloader, t_query_dataloader,
                                                   t_db_dataloader, query_labels, db_labels)
            print(
                'Epoch: {:4d}/{:4d}, validation MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(
                    epoch + 1, opt.max_epoch, mapi2t, mapt2i, mapi2i, mapt2t))

            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            mapi2i_list.append(mapi2i)
            mapt2t_list.append(mapt2t)
            train_times.append(delta_t)

            # Model selection uses the mean of the two cross-modal MAP values.
            average = 0.5 * (mapi2t + mapt2i)
            if average > max_average:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                max_mapi2i = mapi2i
                max_mapt2t = mapt2t
                max_average = average
                save_model(model)
                # torch.save needs no CUDA device context; entering
                # torch.cuda.device() with a CPU device raises ValueError.
                os.makedirs(run_dir, exist_ok=True)
                torch.save([H_i, H_t], os.path.join(run_dir, 'hash_maps_i_t.pth'))
                torch.save(B, os.path.join(run_dir, 'code_map.pth'))

        # decrease the lr to its one fifth every 30 epochs (floor at 1e-6)
        # NOTE(review): the condition also holds for epoch 0, so the first
        # decay happens right after the first epoch - confirm this is intended.
        if epoch % 30 == 0:
            for params in optimizer_gen.param_groups:
                params['lr'] = max(params['lr'] * 0.2, 1e-6)

    if not opt.valid:
        save_model(model)

    time_elapsed = time.time() - since
    print('\n   Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    if opt.valid:
        print('   Max MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(
            max_mapi2t, max_mapt2i, max_mapi2i, max_mapt2t))
    else:
        mapi2t, mapt2i, mapi2i, mapt2t = valid(model, i_query_dataloader, i_db_dataloader, t_query_dataloader,
                                               t_db_dataloader, query_labels, db_labels)
        print('   Max MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(
            mapi2t, mapt2i, mapi2i, mapt2t))

    res_dict = {'mapi2t': mapi2t_list,
                'mapt2i': mapt2i_list,
                'mapi2i': mapi2i_list,
                'mapt2t': mapt2t_list,
                'epoch_times': train_times,
                'losses': losses}

    # Make sure the target directory exists even if no checkpoint was written.
    os.makedirs(run_dir, exist_ok=True)
    write_pickle(os.path.join(run_dir, 'res_dict.pkl'), res_dict)