def test(**kwargs):
    """Evaluate a trained GEN model.

    Loads the checkpointed generator, produces binary hash codes for the
    query/database splits of both modalities, and prints the cross-modal
    MAP scores (image->text and text->image).
    """
    opt.parse(kwargs)

    # Device resolution: explicit setting > first GPU > CPU.
    if opt.device is not None:
        opt.device = torch.device(opt.device)
    elif opt.gpus:
        opt.device = torch.device(0)
    else:
        opt.device = torch.device('cpu')

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)
    generator = GEN(opt.dropout, opt.image_dim, opt.text_dim, opt.hidden_dim,
                    opt.bit, pretrain_model=pretrain_model).to(opt.device)

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    load_model(generator, path)
    generator.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    # One dataset and one non-shuffled loader per (modality, split) pair.
    split_names = ('image.query', 'image.db', 'text.query', 'text.db')
    split_data = {}
    for name in split_names:
        split_data[name] = Dataset(opt, images, tags, labels, test=name)
    split_loader = {}
    for name in split_names:
        split_loader[name] = DataLoader(split_data[name], opt.batch_size, shuffle=False)

    qBX = generate_img_code(generator, split_loader['image.query'], opt.query_size)
    qBY = generate_txt_code(generator, split_loader['text.query'], opt.query_size)
    rBX = generate_img_code(generator, split_loader['image.db'], opt.db_size)
    rBY = generate_txt_code(generator, split_loader['text.db'], opt.db_size)

    query_labels, db_labels = split_data['image.query'].get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))
def train(**kwargs):
    """Adversarial training loop for the GEN/DIS cross-modal hashing model.

    Per batch: (1) trains the feature-level and hash-level discriminators
    with a log-sigmoid GAN loss plus a WGAN-style gradient penalty, then
    (2) updates the generator with triplet, quantization and adversarial
    losses.  After each epoch the binary code matrices B_i / B_t are
    refreshed in closed form from the label matrix L and the accumulated
    hash buffers H_i / H_t.  Saves the best model (by mean MAP) and dumps
    timing / MAP history to ``result.pkl``.
    """
    opt.parse(kwargs)

    if opt.vis_env:
        vis = Visualizer(opt.vis_env, port=opt.vis_port)

    # BUG FIX: the original used `opt.device is 'cpu'`; identity comparison
    # against a string literal is implementation-dependent — use equality.
    if opt.device is None or opt.device == 'cpu':
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)
    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)
    L = train_data.get_labels()
    L = L.to(opt.device)

    # test
    i_query_data = Dataset(opt, images, tags, labels, test='image.query')
    i_db_data = Dataset(opt, images, tags, labels, test='image.db')
    t_query_data = Dataset(opt, images, tags, labels, test='text.query')
    t_db_data = Dataset(opt, images, tags, labels, test='text.db')

    i_query_dataloader = DataLoader(i_query_data, opt.batch_size, shuffle=False)
    i_db_dataloader = DataLoader(i_db_data, opt.batch_size, shuffle=False)
    t_query_dataloader = DataLoader(t_query_data, opt.batch_size, shuffle=False)
    t_db_dataloader = DataLoader(t_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = i_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    generator = GEN(opt.dropout, opt.image_dim, opt.text_dim, opt.hidden_dim,
                    opt.bit, opt.num_label,
                    pretrain_model=pretrain_model).to(opt.device)
    discriminator = DIS(opt.hidden_dim // 4, opt.hidden_dim // 8, opt.bit).to(opt.device)

    optimizer = Adam([
        # {'params': generator.cnn_f.parameters()},  # froze parameters of cnn_f
        {'params': generator.image_module.parameters()},
        {'params': generator.text_module.parameters()},
        {'params': generator.hash_module.parameters()},
    ], lr=opt.lr, weight_decay=0.0005)

    optimizer_dis = {
        'feature': Adam(discriminator.feature_dis.parameters(), lr=opt.lr,
                        betas=(0.5, 0.9), weight_decay=0.0001),
        'hash': Adam(discriminator.hash_dis.parameters(), lr=opt.lr,
                     betas=(0.5, 0.9), weight_decay=0.0001),
    }

    tri_loss = TripletLoss(opt, reduction='sum')

    loss = []
    max_mapi2t = 0.
    max_mapt2i = 0.
    max_average = 0.
    mapt2i_list = []
    mapi2t_list = []
    train_times = []

    B_i = torch.randn(opt.training_size, opt.bit).sign().to(opt.device)
    B_t = B_i
    H_i = torch.zeros(opt.training_size, opt.bit).to(opt.device)
    H_t = torch.zeros(opt.training_size, opt.bit).to(opt.device)

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        # BUG FIX: accumulate the epoch loss as a plain float.  The original
        # `e_loss = err + e_loss` summed live tensors, keeping every batch's
        # autograd graph alive until the end of the epoch.
        e_loss = 0.
        for i, (ind, img, txt, label) in tqdm(enumerate(train_dataloader)):
            imgs = img.to(opt.device)
            txt = txt.to(opt.device)
            labels = label.to(opt.device)
            batch_size = len(ind)

            h_i, h_t, f_i, f_t = generator(imgs, txt)
            # Detached snapshots for the epoch-end closed-form code update.
            H_i[ind, :] = h_i.data
            H_t[ind, :] = h_t.data
            h_t_detach = generator.generate_txt_code(txt)

            #####
            # train feature discriminator
            #####
            # train with real (image features)
            D_real_feature = discriminator.dis_feature(f_i.detach())
            D_real_feature = -opt.gamma * torch.log(torch.sigmoid(D_real_feature)).mean()
            optimizer_dis['feature'].zero_grad()
            D_real_feature.backward()

            # train with fake (text features)
            D_fake_feature = discriminator.dis_feature(f_t.detach())
            D_fake_feature = -opt.gamma * torch.log(
                torch.ones(batch_size).to(opt.device) - torch.sigmoid(D_fake_feature)).mean()
            D_fake_feature.backward()

            # train with gradient penalty on interpolated features
            alpha = torch.rand(batch_size, opt.hidden_dim // 4).to(opt.device)
            interpolates = alpha * f_i.detach() + (1 - alpha) * f_t.detach()
            interpolates.requires_grad_()
            disc_interpolates = discriminator.dis_feature(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones(disc_interpolates.size()).to(opt.device),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            feature_gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            feature_gradient_penalty.backward()
            optimizer_dis['feature'].step()

            #####
            # train hash discriminator
            #####
            D_real_hash = discriminator.dis_hash(h_i.detach())
            D_real_hash = -opt.gamma * torch.log(torch.sigmoid(D_real_hash)).mean()
            optimizer_dis['hash'].zero_grad()
            D_real_hash.backward()

            # train with fake
            D_fake_hash = discriminator.dis_hash(h_t.detach())
            D_fake_hash = -opt.gamma * torch.log(
                torch.ones(batch_size).to(opt.device) - torch.sigmoid(D_fake_hash)).mean()
            D_fake_hash.backward()

            # train with gradient penalty on interpolated hash codes
            alpha = torch.rand(batch_size, opt.bit).to(opt.device)
            interpolates = alpha * h_i.detach() + (1 - alpha) * h_t.detach()
            interpolates.requires_grad_()
            disc_interpolates = discriminator.dis_hash(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones(disc_interpolates.size()).to(opt.device),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            hash_gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            hash_gradient_penalty.backward()
            optimizer_dis['hash'].step()

            # generator adversarial losses (not detached from the graph)
            loss_G_txt_feature = -torch.log(torch.sigmoid(discriminator.dis_feature(f_t))).mean()
            loss_adver_feature = loss_G_txt_feature
            loss_G_txt_hash = -torch.log(torch.sigmoid(discriminator.dis_hash(h_t_detach))).mean()
            loss_adver_hash = loss_G_txt_hash

            # symmetric cross-modal triplet losses
            tri_i2t = tri_loss(h_i, labels, target=h_t, margin=opt.margin)
            tri_t2i = tri_loss(h_t, labels, target=h_i, margin=opt.margin)
            weighted_cos_tri = tri_i2t + tri_t2i

            # quantization losses against the current binary codes
            i_ql = torch.sum(torch.pow(B_i[ind, :] - h_i, 2))
            t_ql = torch.sum(torch.pow(B_t[ind, :] - h_t, 2))
            loss_quant = i_ql + t_ql

            err = opt.alpha * weighted_cos_tri + \
                opt.beta * loss_quant + opt.gamma * (loss_adver_feature + loss_adver_hash)

            optimizer.zero_grad()
            err.backward()
            optimizer.step()

            e_loss = err.item() + e_loss

        # Closed-form refresh of the binary codes from labels and hash buffers.
        P_i = torch.inverse(
            L.t() @ L + opt.lamb * torch.eye(opt.num_label, device=opt.device)) @ L.t() @ B_i
        P_t = torch.inverse(
            L.t() @ L + opt.lamb * torch.eye(opt.num_label, device=opt.device)) @ L.t() @ B_t
        B_i = (L @ P_i + opt.mu * H_i).sign()
        B_t = (L @ P_t + opt.mu * H_t).sign()

        loss.append(e_loss)
        print('...epoch: %3d, loss: %3.3f' % (epoch + 1, loss[-1]))
        delta_t = time.time() - t1

        if opt.vis_env:
            vis.plot('loss', loss[-1])

        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i = valid(generator, i_query_dataloader, i_db_dataloader,
                                   t_query_dataloader, t_db_dataloader,
                                   query_labels, db_labels)
            print('...epoch: %3d, valid MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (epoch + 1, mapi2t, mapt2i))
            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            train_times.append(delta_t)
            # keep the checkpoint with the best mean MAP
            if 0.5 * (mapi2t + mapt2i) > max_average:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                max_average = 0.5 * (mapi2t + mapt2i)
                save_model(generator)
            if opt.vis_env:
                vis.plot('mapi2t', mapi2t)
                vis.plot('mapt2i', mapt2i)

        # learning-rate decay every 100 epochs, floored at 1e-6
        if epoch % 100 == 0:
            for params in optimizer.param_groups:
                params['lr'] = max(params['lr'] * 0.8, 1e-6)

    if not opt.valid:
        save_model(generator)

    print('...training procedure finish')
    if opt.valid:
        print(' max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (max_mapi2t, max_mapt2i))
    else:
        mapi2t, mapt2i = valid(generator, i_query_dataloader, i_db_dataloader,
                               t_query_dataloader, t_db_dataloader,
                               query_labels, db_labels)
        print(' max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    with open(os.path.join(path, 'result.pkl'), 'wb') as f:
        pickle.dump([train_times, mapi2t_list, mapt2i_list], f)
def test(**kwargs):
    """Evaluate a trained AGAH model.

    Restores the checkpointed model and its saved FEATURE_MAP, generates
    hash codes for both modalities, saves PR-curve and precision@K arrays
    as .npy files, and prints the cross-modal MAP scores.
    """
    opt.parse(kwargs)

    # Device resolution: explicit setting > first GPU > CPU.
    if opt.device is not None:
        opt.device = torch.device(opt.device)
    elif opt.gpus:
        opt.device = torch.device(0)
    else:
        opt.device = torch.device('cpu')

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)
    model = AGAH(opt.bit, opt.tag_dim, opt.num_label, opt.emb_dim,
                 lambd=opt.lambd, pretrain_model=pretrain_model).to(opt.device)

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    load_model(model, path)
    FEATURE_MAP = torch.load(os.path.join(path, 'feature_map.pth')).to(opt.device)
    model.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    # One dataset and one non-shuffled loader per (modality, split) pair.
    split_names = ('image.query', 'image.db', 'text.query', 'text.db')
    split_data = {}
    for name in split_names:
        split_data[name] = Dataset(opt, images, tags, labels, test=name)
    split_loader = {}
    for name in split_names:
        split_loader[name] = DataLoader(split_data[name], opt.batch_size, shuffle=False)

    qBX = generate_img_code(model, split_loader['image.query'], opt.query_size, FEATURE_MAP)
    qBY = generate_txt_code(model, split_loader['text.query'], opt.query_size, FEATURE_MAP)
    rBX = generate_img_code(model, split_loader['image.db'], opt.db_size, FEATURE_MAP)
    rBY = generate_txt_code(model, split_loader['text.db'], opt.db_size, FEATURE_MAP)

    query_labels, db_labels = split_data['image.query'].get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    p_i2t, r_i2t = pr_curve(qBX, rBY, query_labels, db_labels)
    p_t2i, r_t2i = pr_curve(qBY, rBX, query_labels, db_labels)

    K = [1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    pk_i2t = p_topK(qBX, rBY, query_labels, db_labels, K)
    pk_t2i = p_topK(qBY, rBX, query_labels, db_labels, K)

    # Persist every metric array under the checkpoint directory.
    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    outputs = (('P_i2t.npy', p_i2t), ('R_i2t.npy', r_i2t),
               ('P_t2i.npy', p_t2i), ('R_t2i.npy', r_t2i),
               ('P_at_K_i2t.npy', pk_i2t), ('P_at_K_t2i.npy', pk_t2i))
    for fname, tensor in outputs:
        np.save(os.path.join(path, fname), tensor.numpy())

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))
def train(**kwargs):
    """Adversarial training loop for the AGAH cross-modal hashing model.

    Per batch: (1) trains the per-modality feature discriminators in WGAN
    style with a gradient penalty, then (2) updates the generator networks
    with triplet, hash-balance, classification, code-map, quantization and
    adversarial losses.  CODE_MAP and FEATURE_MAP are refreshed after every
    epoch; the best model (by both MAPs improving) is checkpointed together
    with FEATURE_MAP, and MAP/timing history is dumped to ``result.pkl``.
    """
    opt.parse(kwargs)

    if opt.vis_env:
        vis = Visualizer(opt.vis_env, port=opt.vis_port)

    # BUG FIX: the original used `opt.device is 'cpu'`; identity comparison
    # against a string literal is implementation-dependent — use equality.
    if opt.device is None or opt.device == 'cpu':
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)
    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)

    # valid or test data
    x_query_data = Dataset(opt, images, tags, labels, test='image.query')
    x_db_data = Dataset(opt, images, tags, labels, test='image.db')
    y_query_data = Dataset(opt, images, tags, labels, test='text.query')
    y_db_data = Dataset(opt, images, tags, labels, test='text.db')

    x_query_dataloader = DataLoader(x_query_data, opt.batch_size, shuffle=False)
    x_db_dataloader = DataLoader(x_db_data, opt.batch_size, shuffle=False)
    y_query_dataloader = DataLoader(y_query_data, opt.batch_size, shuffle=False)
    y_db_dataloader = DataLoader(y_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = x_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    # BUG FIX: `pretrain_model` was left unbound (NameError) when neither
    # load_model_path nor pretrain_model_path was configured.
    pretrain_model = None
    if not opt.load_model_path and opt.pretrain_model_path:
        pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    model = AGAH(opt.bit, opt.tag_dim, opt.num_label, opt.emb_dim,
                 lambd=opt.lambd, pretrain_model=pretrain_model).to(opt.device)

    # BUG FIX: only restore a checkpoint when a path was actually supplied;
    # the original called load_model unconditionally.
    if opt.load_model_path:
        load_model(model, opt.load_model_path)

    optimizer = Adamax([
        {'params': model.img_module.parameters(), 'lr': opt.lr},
        {'params': model.txt_module.parameters()},
        {'params': model.hash_module.parameters()},
        {'params': model.classifier.parameters()},
    ], lr=opt.lr * 10, weight_decay=0.0005)

    optimizer_dis = {
        'img': Adamax(model.img_discriminator.parameters(), lr=opt.lr * 10,
                      betas=(0.5, 0.9), weight_decay=0.0001),
        'txt': Adamax(model.txt_discriminator.parameters(), lr=opt.lr * 10,
                      betas=(0.5, 0.9), weight_decay=0.0001),
    }

    criterion_tri_cos = TripletAllLoss(dis_metric='cos', reduction='sum')
    criterion_bce = nn.BCELoss(reduction='sum')

    loss = []
    max_mapi2t = 0.
    max_mapt2i = 0.

    # Epoch-long buffers of per-sample features / hash outputs.
    FEATURE_I = torch.randn(opt.training_size, opt.emb_dim).to(opt.device)
    FEATURE_T = torch.randn(opt.training_size, opt.emb_dim).to(opt.device)
    U = torch.randn(opt.training_size, opt.bit).to(opt.device)
    V = torch.randn(opt.training_size, opt.bit).to(opt.device)

    FEATURE_MAP = torch.randn(opt.num_label, opt.emb_dim).to(opt.device)
    CODE_MAP = torch.sign(torch.randn(opt.num_label, opt.bit)).to(opt.device)

    train_labels = train_data.get_labels().to(opt.device)

    mapt2i_list = []
    mapi2t_list = []
    train_times = []

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        for i, (ind, x, y, l) in tqdm(enumerate(train_dataloader)):
            imgs = x.to(opt.device)
            tags = y.to(opt.device)
            labels = l.to(opt.device)
            batch_size = len(ind)

            h_x, h_y, f_x, f_y, x_class, y_class = model(imgs, tags, FEATURE_MAP)

            FEATURE_I[ind] = f_x.data
            FEATURE_T[ind] = f_y.data
            U[ind] = h_x.data
            V[ind] = h_y.data

            #####
            # train txt discriminator
            #####
            # train with real (txt features)
            D_txt_real = model.dis_txt(f_y.detach())
            D_txt_real = -D_txt_real.mean()
            optimizer_dis['txt'].zero_grad()
            D_txt_real.backward()

            # train with fake (img features)
            D_txt_fake = model.dis_txt(f_x.detach())
            D_txt_fake = D_txt_fake.mean()
            D_txt_fake.backward()

            # train with gradient penalty on interpolated features
            alpha = torch.rand(batch_size, opt.emb_dim).to(opt.device)
            interpolates = alpha * f_y.detach() + (1 - alpha) * f_x.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_txt(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones(disc_interpolates.size()).to(opt.device),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            txt_gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            txt_gradient_penalty.backward()
            optimizer_dis['txt'].step()

            #####
            # train img discriminator
            #####
            D_img_real = model.dis_img(f_x.detach())
            D_img_real = -D_img_real.mean()
            optimizer_dis['img'].zero_grad()
            D_img_real.backward()

            # train with fake
            D_img_fake = model.dis_img(f_y.detach())
            D_img_fake = D_img_fake.mean()
            D_img_fake.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.emb_dim).to(opt.device)
            interpolates = alpha * f_x.detach() + (1 - alpha) * f_y.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_img(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones(disc_interpolates.size()).to(opt.device),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            img_gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            img_gradient_penalty.backward()
            optimizer_dis['img'].step()

            #####
            # train generators
            #####
            # update img network (to generate txt features)
            domain_output = model.dis_txt(f_x)
            loss_G_txt = -domain_output.mean()

            # update txt network (to generate img features)
            domain_output = model.dis_img(f_y)
            loss_G_img = -domain_output.mean()

            loss_adver = loss_G_txt + loss_G_img

            # symmetric cross-modal triplet losses
            loss1 = criterion_tri_cos(h_x, labels, target=h_y, margin=opt.margin)
            loss2 = criterion_tri_cos(h_y, labels, target=h_x, margin=opt.margin)

            # push hash outputs towards +-1 (bit balance)
            theta1 = F.cosine_similarity(torch.abs(h_x), torch.ones_like(h_x).to(opt.device))
            theta2 = F.cosine_similarity(torch.abs(h_y), torch.ones_like(h_y).to(opt.device))
            loss3 = torch.sum(1 / (1 + torch.exp(theta1))) + torch.sum(1 / (1 + torch.exp(theta2)))

            loss_class = criterion_bce(x_class, labels) + criterion_bce(y_class, labels)

            # align hash outputs with the per-label code map
            theta_code_x = h_x.mm(CODE_MAP.t())  # size: (batch, num_label)
            theta_code_y = h_y.mm(CODE_MAP.t())
            loss_code_map = torch.sum(torch.pow(theta_code_x - opt.bit * (labels * 2 - 1), 2)) + \
                torch.sum(torch.pow(theta_code_y - opt.bit * (labels * 2 - 1), 2))

            loss_quant = torch.sum(torch.pow(h_x - torch.sign(h_x), 2)) + \
                torch.sum(torch.pow(h_y - torch.sign(h_y), 2))

            err = loss1 + loss2 + opt.alpha * loss3 + opt.beta * loss_class + opt.gamma * loss_code_map + \
                opt.eta * loss_quant + opt.mu * loss_adver

            optimizer.zero_grad()
            err.backward()
            optimizer.step()

            loss.append(err.item())

        # epoch-end refresh of the per-label code and feature maps
        CODE_MAP = update_code_map(U, V, CODE_MAP, train_labels)
        FEATURE_MAP = update_feature_map(FEATURE_I, FEATURE_T, train_labels)

        print('...epoch: %3d, loss: %3.3f' % (epoch + 1, loss[-1]))
        delta_t = time.time() - t1

        if opt.vis_env:
            vis.plot('loss', loss[-1])

        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i = valid(model, x_query_dataloader, x_db_dataloader,
                                   y_query_dataloader, y_db_dataloader,
                                   query_labels, db_labels, FEATURE_MAP)
            print('...epoch: %3d, valid MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (epoch + 1, mapi2t, mapt2i))

            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            train_times.append(delta_t)

            if opt.vis_env:
                d = {'mapi2t': mapi2t, 'mapt2i': mapt2i}
                vis.plot_many(d)

            # checkpoint only when both MAPs improve; FEATURE_MAP is needed
            # at test time and is saved alongside the model
            if mapt2i >= max_mapt2i and mapi2t >= max_mapi2t:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                save_model(model)
                path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
                with torch.cuda.device(opt.device):
                    torch.save(FEATURE_MAP, os.path.join(path, 'feature_map.pth'))

        # learning-rate decay every 100 epochs, floored at 1e-6
        if epoch % 100 == 0:
            for params in optimizer.param_groups:
                params['lr'] = max(params['lr'] * 0.6, 1e-6)

    if not opt.valid:
        save_model(model)

    print('...training procedure finish')
    if opt.valid:
        print(' max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (max_mapi2t, max_mapt2i))
    else:
        mapi2t, mapt2i = valid(model, x_query_dataloader, x_db_dataloader,
                               y_query_dataloader, y_db_dataloader,
                               query_labels, db_labels, FEATURE_MAP)
        print(' max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    with open(os.path.join(path, 'result.pkl'), 'wb') as f:
        pickle.dump([train_times, mapi2t_list, mapt2i_list], f)
def test(**kwargs):
    """Evaluate a trained CPAH model.

    Computes PR curves, precision@K and MAP for all four retrieval
    directions (i->t, t->i, i->i, t->t) and pickles the results under the
    checkpoint directory.
    """
    opt.parse(kwargs)

    # Device resolution: explicit setting > first GPU > CPU.
    if opt.device is not None:
        opt.device = torch.device(opt.device)
    elif opt.gpus:
        opt.device = torch.device(0)
    else:
        opt.device = torch.device('cpu')

    with torch.no_grad():
        model = CPAH(opt.image_dim, opt.text_dim, opt.hidden_dim, opt.bit,
                     opt.num_label).to(opt.device)

        path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit) + str(opt.proc)
        load_model(model, path)
        model.eval()

        images, tags, labels = load_data(opt.data_path, opt.dataset)

        # One dataset and one non-shuffled loader per (modality, split) pair.
        split_names = ('image.query', 'image.db', 'text.query', 'text.db')
        split_data = {name: Dataset(opt, images, tags, labels, test=name)
                      for name in split_names}
        split_loader = {name: DataLoader(split_data[name], opt.batch_size, shuffle=False)
                        for name in split_names}

        qBX = generate_img_code(model, split_loader['image.query'], opt.query_size)
        qBY = generate_txt_code(model, split_loader['text.query'], opt.query_size)
        rBX = generate_img_code(model, split_loader['image.db'], opt.db_size)
        rBY = generate_txt_code(model, split_loader['text.db'], opt.db_size)

        query_labels, db_labels = split_data['image.query'].get_labels()
        query_labels = query_labels.to(opt.device)
        db_labels = db_labels.to(opt.device)

        p_i2t, r_i2t = pr_curve(qBX, rBY, query_labels, db_labels, tqdm_label='I2T')
        p_t2i, r_t2i = pr_curve(qBY, rBX, query_labels, db_labels, tqdm_label='T2I')
        p_i2i, r_i2i = pr_curve(qBX, rBX, query_labels, db_labels, tqdm_label='I2I')
        p_t2t, r_t2t = pr_curve(qBY, rBY, query_labels, db_labels, tqdm_label='T2T')

        K = [1, 10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000,
             2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]
        pk_i2t = p_top_k(qBX, rBY, query_labels, db_labels, K, tqdm_label='I2T')
        pk_t2i = p_top_k(qBY, rBX, query_labels, db_labels, K, tqdm_label='T2I')
        pk_i2i = p_top_k(qBX, rBX, query_labels, db_labels, K, tqdm_label='I2I')
        pk_t2t = p_top_k(qBY, rBY, query_labels, db_labels, K, tqdm_label='T2T')

        mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
        mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
        mapi2i = calc_map_k(qBX, rBX, query_labels, db_labels)
        mapt2t = calc_map_k(qBY, rBY, query_labels, db_labels)

        pr_dict = {'pi2t': p_i2t.cpu().numpy(), 'ri2t': r_i2t.cpu().numpy(),
                   'pt2i': p_t2i.cpu().numpy(), 'rt2i': r_t2i.cpu().numpy(),
                   'pi2i': p_i2i.cpu().numpy(), 'ri2i': r_i2i.cpu().numpy(),
                   'pt2t': p_t2t.cpu().numpy(), 'rt2t': r_t2t.cpu().numpy()}

        pk_dict = {'k': K,
                   'pki2t': pk_i2t.cpu().numpy(),
                   'pkt2i': pk_t2i.cpu().numpy(),
                   'pki2i': pk_i2i.cpu().numpy(),
                   'pkt2t': pk_t2t.cpu().numpy()}

        map_dict = {'mapi2t': float(mapi2t.cpu().numpy()),
                    'mapt2i': float(mapt2i.cpu().numpy()),
                    'mapi2i': float(mapi2i.cpu().numpy()),
                    'mapt2t': float(mapt2t.cpu().numpy())}

        print(' Test MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(mapi2t, mapt2i, mapi2i, mapt2t))

        path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit) + str(opt.proc)
        write_pickle(os.path.join(path, 'pr_dict.pkl'), pr_dict)
        write_pickle(os.path.join(path, 'pk_dict.pkl'), pk_dict)
        write_pickle(os.path.join(path, 'map_dict.pkl'), map_dict)
def train(**kwargs):
    """Training loop for the CPAH cross-modal hashing model.

    Per batch: (1) trains the consistency-feature discriminator with real
    (image) vs. fake (text) inputs plus a WGAN gradient penalty (CPAH paper
    eq. 5), then (2) updates the generator with adversarial, consistency-
    classification, classification, pairwise and quantization losses
    (eqs. 6-11).  The joint binary code matrix B is refreshed once per
    epoch from the accumulated hash buffers; the best model (by mean
    i<->t MAP) is checkpointed and results are pickled to ``res_dict.pkl``.
    """
    since = time.time()
    opt.parse(kwargs)

    if (opt.device is None) or (opt.device == 'cpu'):
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)
    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)
    L = train_data.get_labels()
    L = L.to(opt.device)

    # test
    i_query_data = Dataset(opt, images, tags, labels, test='image.query')
    i_db_data = Dataset(opt, images, tags, labels, test='image.db')
    t_query_data = Dataset(opt, images, tags, labels, test='text.query')
    t_db_data = Dataset(opt, images, tags, labels, test='text.db')

    i_query_dataloader = DataLoader(i_query_data, opt.batch_size, shuffle=False)
    i_db_dataloader = DataLoader(i_db_data, opt.batch_size, shuffle=False)
    t_query_dataloader = DataLoader(t_query_data, opt.batch_size, shuffle=False)
    t_db_dataloader = DataLoader(t_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = i_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    model = CPAH(opt.image_dim, opt.text_dim, opt.hidden_dim, opt.bit,
                 opt.num_label).to(opt.device)

    optimizer_gen = Adam([
        {'params': model.image_module.parameters()},
        {'params': model.text_module.parameters()},
        {'params': model.hash_module.parameters()},
        {'params': model.mask_module.parameters()},
        {'params': model.consistency_dis.parameters()},
        {'params': model.classifier.parameters()},
    ], lr=opt.lr, weight_decay=0.0005)

    optimizer_dis = Adam(model.feature_dis.parameters(), lr=opt.lr,
                         betas=(0.5, 0.9), weight_decay=0.0001)

    loss_bce = torch.nn.BCELoss(reduction='sum')
    loss_ce = torch.nn.CrossEntropyLoss(reduction='sum')

    loss = []
    losses = []

    max_mapi2t = 0.
    max_mapt2i = 0.
    max_mapi2i = 0.
    max_mapt2t = 0.
    max_average = 0.

    mapt2i_list = []
    mapi2t_list = []
    mapi2i_list = []
    mapt2t_list = []
    train_times = []

    B = torch.randn(opt.training_size, opt.bit).sign().to(opt.device)
    H_i = torch.zeros(opt.training_size, opt.bit).to(opt.device)
    H_t = torch.zeros(opt.training_size, opt.bit).to(opt.device)

    # NOTE: removed `torch.autograd.set_detect_anomaly(True)` — anomaly
    # detection is a debugging aid and is documented to slow training
    # considerably; it should not be enabled in production runs.

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        # BUG FIX: accumulate the epoch loss as a plain float.  The original
        # `e_loss = err + e_loss` summed live tensors, keeping every batch's
        # autograd graph alive until the end of the epoch.
        e_loss = 0.
        e_losses = {'adv': 0, 'class': 0, 'quant': 0, 'pairwise': 0}
        for i, (ind, img, txt, label) in enumerate(train_dataloader):
            imgs = img.to(opt.device)
            txt = txt.to(opt.device)
            labels = label.to(opt.device)
            batch_size = len(ind)

            h_img, h_txt, f_rc_img, f_rc_txt, f_rp_img, f_rp_txt = model(imgs, txt)
            # BUG FIX: store detached snapshots.  These buffers only feed the
            # epoch-end sign() update; storing attached tensors (as the
            # original did) leaked autograd graph memory across batches.
            H_i[ind, :] = h_img.detach()
            H_t[ind, :] = h_txt.detach()

            ###################################
            # train discriminator. CPAH paper: (5)
            ###################################
            # IMG - real, TXT - fake
            # train with real (IMG)
            optimizer_dis.zero_grad()
            d_real = model.dis_D(f_rc_img.detach())
            d_real = -torch.log(torch.sigmoid(d_real)).mean()
            d_real.backward()

            # train with fake (TXT)
            d_fake = model.dis_D(f_rc_txt.detach())
            d_fake = -torch.log(torch.ones(batch_size).to(opt.device) - torch.sigmoid(d_fake)).mean()
            d_fake.backward()

            # train with gradient penalty (GP): interpolate real and fake data
            alpha = torch.rand(batch_size, opt.hidden_dim).to(opt.device)
            interpolates = alpha * f_rc_img.detach() + (1 - alpha) * f_rc_txt.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_D(interpolates)
            # get gradients with respect to inputs
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=torch.ones(disc_interpolates.size()).to(opt.device),
                                      create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10  # 10 is GP hyperparameter
            gradient_penalty.backward()
            optimizer_dis.step()

            ###################################
            # train generator
            ###################################
            # adversarial loss, CPAH paper: (6) — don't detach from graph
            loss_adver = -torch.log(torch.sigmoid(model.dis_D(f_rc_txt))).mean()

            # consistency classification loss, CPAH paper: (7)
            # group labels: both consistency features -> 1, private img -> 0,
            # private txt -> 2
            f_r = torch.vstack([f_rc_img, f_rc_txt, f_rp_img, f_rp_txt])
            l_r = [1] * len(ind) * 2 + [0] * len(ind) + [2] * len(ind)
            l_r = torch.tensor(l_r).to(opt.device)
            loss_consistency_class = loss_ce(f_r, l_r)

            # classification loss, CPAH paper: (8)
            l_f_rc_img = model.dis_classify(f_rc_img, 'img')
            l_f_rc_txt = model.dis_classify(f_rc_txt, 'txt')
            loss_class = loss_bce(l_f_rc_img, labels) + loss_bce(l_f_rc_txt, labels)

            # pairwise loss, CPAH paper: (10)
            S = (labels.mm(labels.T) > 0).float()
            theta = 0.5 * h_img.mm(h_txt.T)
            e_theta = torch.exp(theta)
            loss_pairwise = -torch.sum(S * theta - torch.log(1 + e_theta))

            # quantization loss, CPAH paper: (11)
            loss_quant = torch.sum(torch.pow(B[ind, :] - h_img, 2)) + \
                torch.sum(torch.pow(B[ind, :] - h_txt, 2))

            err = 100 * loss_adver + opt.alpha * (loss_consistency_class + loss_class) + \
                loss_pairwise + opt.beta * loss_quant

            e_losses['adv'] += 100 * loss_adver.detach().cpu().numpy()
            e_losses['class'] += (opt.alpha * (loss_consistency_class + loss_class)).detach().cpu().numpy()
            e_losses['pairwise'] += loss_pairwise.detach().cpu().numpy()
            e_losses['quant'] += loss_quant.detach().cpu().numpy()

            optimizer_gen.zero_grad()
            err.backward()
            optimizer_gen.step()

            e_loss = err.item() + e_loss

        loss.append(e_loss)
        e_losses['sum'] = sum(e_losses.values())
        losses.append(e_losses)

        # epoch-end refresh of the joint binary code matrix
        B = (0.5 * (H_i + H_t)).sign()

        delta_t = time.time() - t1
        print('Epoch: {:4d}/{:4d}, time, {:3.3f}s, loss: {:15.3f},'.format(
            epoch + 1, opt.max_epoch, delta_t, loss[-1]) + 5 * ' ' + 'losses:', e_losses)

        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i, mapi2i, mapt2t = valid(model, i_query_dataloader, i_db_dataloader,
                                                   t_query_dataloader, t_db_dataloader,
                                                   query_labels, db_labels)
            print('Epoch: {:4d}/{:4d}, validation MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(
                epoch + 1, opt.max_epoch, mapi2t, mapt2i, mapi2i, mapt2t))

            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            mapi2i_list.append(mapi2i)
            mapt2t_list.append(mapt2t)
            train_times.append(delta_t)

            # keep the checkpoint (plus hash buffers and code map) with the
            # best mean cross-modal MAP
            if 0.5 * (mapi2t + mapt2i) > max_average:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                max_mapi2i = mapi2i
                max_mapt2t = mapt2t
                max_average = 0.5 * (mapi2t + mapt2i)
                save_model(model)
                path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit) + str(opt.proc)
                with torch.cuda.device(opt.device):
                    torch.save([H_i, H_t], os.path.join(path, 'hash_maps_i_t.pth'))
                with torch.cuda.device(opt.device):
                    torch.save(B, os.path.join(path, 'code_map.pth'))

        # decrease the lr to its one fifth every 30 epochs
        if epoch % 30 == 0:
            for params in optimizer_gen.param_groups:
                params['lr'] = max(params['lr'] * 0.2, 1e-6)

    if not opt.valid:
        save_model(model)

    time_elapsed = time.time() - since
    print('\n Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    if opt.valid:
        print(' Max MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(
            max_mapi2t, max_mapt2i, max_mapi2i, max_mapt2t))
    else:
        mapi2t, mapt2i, mapi2i, mapt2t = valid(model, i_query_dataloader, i_db_dataloader,
                                               t_query_dataloader, t_db_dataloader,
                                               query_labels, db_labels)
        print(' Max MAP: MAP(i->t) = {:3.4f}, MAP(t->i) = {:3.4f}, MAP(i->i) = {:3.4f}, MAP(t->t) = {:3.4f}'.format(
            mapi2t, mapt2i, mapi2i, mapt2t))

    res_dict = {'mapi2t': mapi2t_list,
                'mapt2i': mapt2i_list,
                'mapi2i': mapi2i_list,
                'mapt2t': mapt2t_list,
                'epoch_times': train_times,
                'losses': losses}
    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit) + str(opt.proc)
    write_pickle(os.path.join(path, 'res_dict.pkl'), res_dict)