def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    max_iter,
    device,
    topk,
):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        max_iter(int): Number of iterations.
        device(torch.device): GPU or CPU.
        topk(int): Compute mAP over the top k retrieved results.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    query_data, query_targets = query_data.to(device), query_targets.to(device)
    retrieval_data, retrieval_targets = retrieval_data.to(device), retrieval_targets.to(device)

    # Initialize the rotation with a random orthogonal matrix
    R = torch.randn(code_length, code_length).to(device)
    [U, _, _] = torch.svd(R)
    R = U[:, :code_length]

    # PCA
    pca = PCA(n_components=code_length)
    V = torch.from_numpy(pca.fit_transform(train_data.numpy())).to(device)

    # Training: alternate between binarizing the rotated data and solving
    # the orthogonal Procrustes problem for the rotation
    for i in range(max_iter):
        V_tilde = V @ R
        B = V_tilde.sign()
        # torch.svd returns (U, S, V) with B.t() @ V == U @ diag(S) @ VT.t(),
        # so the Procrustes solution minimizing ||B - V @ R|| is VT @ U.t()
        [U, _, VT] = torch.svd(B.t() @ V)
        R = VT @ U.t()

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, R, pca)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, R, pca)

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'pca': pca,
        'rotation_matrix': R,
        'P': P,
        'R': Recall,
        'map': mAP,
    }

    return checkpoint
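# `generate_code` is referenced above but not shown in this section. A minimal
# sketch, assuming it simply mirrors the training-time projection (PCA
# transform, rotation, sign); the repo's actual helper may differ:
def generate_code(data, code_length, R, pca):
    """
    Sketch: generate hash code by PCA projection, rotation, and binarization.

    Args
        data(torch.Tensor): Data on CPU.
        code_length(int): Hash code length.
        R(torch.Tensor): Rotation matrix.
        pca(sklearn.decomposition.PCA): Fitted PCA.

    Returns
        code(torch.Tensor): Hash code.
    """
    V = torch.from_numpy(pca.transform(data.numpy())).to(R.device)
    return (V @ R).sign()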
def train(
    train_data,
    train_targets,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    num_anchor,
    max_iter,
    lamda,
    nu,
    sigma,
    device,
    topk,
):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        train_targets(torch.Tensor): Training targets.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        num_anchor(int): Number of anchors.
        max_iter(int): Number of iterations.
        lamda, nu, sigma(float): Hyper-parameters.
        device(torch.device): GPU or CPU.
        topk(int): Compute mAP using top k retrieval result.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    n = train_data.shape[0]
    L = code_length
    m = num_anchor
    t = max_iter
    X = train_data.t()
    Y = train_targets.t()
    B = torch.randn(L, n).sign()

    # Permute data
    perm_index = torch.randperm(n)
    X = X[:, perm_index]
    Y = Y[:, perm_index]

    # Randomly select num_anchor samples from the training data
    anchor = X[:, :m]

    # Map training data via RBF kernel
    phi_x = torch.from_numpy(rbf_kernel(X.numpy().T, anchor.numpy().T, sigma)).t()

    # Training
    B = B.to(device)
    Y = Y.to(device)
    phi_x = phi_x.to(device)
    for it in range(t):
        # G-Step
        W = torch.pinverse(B @ B.t() + lamda * torch.eye(code_length, device=device)) @ B @ Y.t()

        # F-Step
        P = torch.pinverse(phi_x @ phi_x.t()) @ phi_x @ B.t()
        F_X = P.t() @ phi_x

        # B-Step
        B = solve_dcc(B, W, Y, F_X, nu)

    # Evaluate
    query_code = generate_code(query_data.t(), anchor, P, sigma)
    retrieval_code = generate_code(retrieval_data.t(), anchor, P, sigma)

    # Compute map
    mAP = mean_average_precision(
        query_code.t().to(device),
        retrieval_code.t().to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
        device,
        topk,
    )

    # PR curve
    Precision, R = pr_curve(
        query_code.t().to(device),
        retrieval_code.t().to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
        device,
    )

    # Save checkpoint
    checkpoint = {
        'tB': B,
        'tL': train_targets,
        'qB': query_code,
        'qL': query_targets,
        'rB': retrieval_code,
        'rL': retrieval_targets,
        'anchor': anchor,
        'projection': P,
        'P': Precision,
        'R': R,
        'map': mAP,
    }

    return checkpoint
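# As above, `generate_code` is assumed rather than shown. A sketch consistent
# with the training code: map the data against the anchors with the same RBF
# kernel (assumed to be sklearn.metrics.pairwise.rbf_kernel, as in train),
# apply the learned projection P, and binarize. Shapes follow train(), where
# `data` is (features, n) and codes are (code_length, n):
def generate_code(data, anchor, P, sigma):
    """
    Sketch: generate hash code from the kernel features and projection P.

    Returns
        code(torch.Tensor): (code_length, n) hash code.
    """
    phi = torch.from_numpy(rbf_kernel(data.numpy().T, anchor.numpy().T, sigma)).t()
    return (P.t() @ phi.to(P.device)).sign()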
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    alpha,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        alpha(float): Hyper-parameter.
        topk(int): Compute mAP over the top k retrieved results.
        evaluate_interval(int): Interval of evaluation.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Load model
    model = load_model(arch, code_length).to(device)

    # Create criterion, optimizer, scheduler
    criterion = HashNetLoss(alpha)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    # This implementation does not use the "scaled tanh" activation: its
    # parameters are hard to tune, and it can even decrease performance.
    # See https://github.com/thuml/HashNet/issues/29 for discussion.
    for it in range(max_iter):
        tic = time.time()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # Create similarity matrix
            S = (targets @ targets.t() > 0).float()
            outputs = model(data)

            loss = criterion(outputs, S)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()
        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
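# `HashNetLoss` is constructed above but not defined in this section. A
# minimal sketch of the weighted pairwise-likelihood loss described in the
# HashNet paper (Cao et al., ICCV 2017), with per-pair weights countering the
# similar/dissimilar imbalance; the repo's actual loss may differ in detail:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HashNetLoss(nn.Module):
    def __init__(self, alpha):
        super().__init__()
        self.alpha = alpha

    def forward(self, outputs, S):
        # theta_ij = alpha * <u_i, u_j>
        theta = self.alpha * outputs @ outputs.t()

        # Weight each pair inversely to the frequency of its class
        # (similar vs. dissimilar) within the batch.
        w = torch.empty_like(S)
        n_pos = S.sum().clamp(min=1.)
        n_neg = (1 - S).sum().clamp(min=1.)
        w[S > 0] = S.numel() / n_pos
        w[S <= 0] = S.numel() / n_neg

        # Numerically stable log(1 + exp(theta)) - s * theta via softplus
        return (w * (F.softplus(theta) - S * theta)).mean()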
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    topk,
    evaluate_interval,
    anchor_num,
    proportion,
):
    # ADMM penalty and step-size hyper-parameters
    rho1 = 1e-2   # ρ1
    rho2 = 1e-2   # ρ2
    rho3 = 1e-3   # µ1
    rho4 = 1e-3   # µ2
    gamma = 1e-3  # γ

    # Gather the whole training set to size the ADMM variables
    with torch.no_grad():
        data_mo = torch.tensor([]).to(device)
        for data, _, _ in train_dataloader:
            data = data.to(device)
            data_mo = torch.cat((data_mo, data), 0)
            torch.cuda.empty_cache()

    # Number of training samples (rows are samples, so size(0), not size(1))
    n = data_mo.size(0)
    Y1 = torch.rand(n, code_length).to(device)
    Y2 = torch.rand(n, code_length).to(device)
    B = torch.rand(n, code_length).to(device)

    # Load model
    model = load_model(arch, code_length).to(device)

    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    for it in range(max_iter):
        # timer
        tic = time.time()

        # ADMM uses anchors in the first step but drops them later
        '''
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)
            for data, _, _ in train_dataloader:
                data = data.to(device)
                output_mo_temp = model_mo(data)
                output_mo = torch.cat((output_mo, output_mo_temp), 0)
                torch.cuda.empty_cache()
            anchor = get_anchor(output_mo, anchor_num, device)  # compute anchor
        '''
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)
            for data, _, _ in train_dataloader:
                data = data.to(device)
                # output_B is the hash-layer output, output_A the features before it
                output_B, output_A = model(data)
                output_mo = torch.cat((output_mo, output_A), 0)
                torch.cuda.empty_cache()

            # Affinity matrix from pairwise distances, rescaled to [-1, 1]
            # (a sketch of `euclidean_dist` follows this function)
            dist = euclidean_dist(output_mo, output_mo)
            dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
            A = (2 / (torch.max(dist) - torch.min(dist))) * dist - 1

            # Z1-update: project B + Y1/ρ1 onto the box [-1, 1]
            Z1 = B + 1 / rho1 * Y1
            Z1[Z1 > 1] = 1
            Z1[Z1 < -1] = -1

            # Z2-update: project B + Y2/ρ2 onto the sphere of radius sqrt(n * L)
            Z2 = B + 1 / rho2 * Y2
            norm_B = torch.norm(Z2)
            Z2 = (n * code_length) ** 0.5 * Z2 / norm_B

            # Dual updates
            Y1 = Y1 + gamma * rho1 * (B - Z1)
            Y2 = Y2 + gamma * rho2 * (B - Z2)

            # Expose the current iterates to Baim_func through module-level
            # globals (Baim_func is assumed to read these)
            global global_A, global_Z1, global_Z2, global_Y1, global_Y2
            global_A = A.cpu().numpy()
            global_Z1 = Z1.cpu().numpy()
            global_Z2 = Z2.cpu().numpy()
            global_Y1 = Y1.cpu().numpy()
            global_Y2 = Y2.cpu().numpy()

            # B-update via L-BFGS-B. fmin_l_bfgs_b works on flat float64 arrays
            # and returns (x, f, d); Baim_func is assumed to return both the
            # objective value and its gradient.
            B0 = B.cpu().numpy().astype(np.float64).flatten()
            B_opt, _, _ = scipy.optimize.fmin_l_bfgs_b(Baim_func, B0)
            B = torch.from_numpy(B_opt.reshape(n, code_length)).float().to(device)

        # self-supervised deep learning
        model.train()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # output_B for hash code, output_A for the result without the hash layer
            output_B, output_A = model(data)

            # Fit the network outputs to the ADMM codes of this batch
            # (rows of B selected by the batch indices)
            loss = criterion(output_B, B[index, :])
            running_loss += loss.item()

            loss.backward()
            optimizer.step()
        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.
            # Checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
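# `euclidean_dist` (used to build the affinity matrix above) is not shown in
# this section. A common implementation via the expansion
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, assumed here; a squared variant
# (dropping the final sqrt) is also common:
def euclidean_dist(x, y):
    """Pairwise Euclidean distances between the rows of x and the rows of y."""
    xx = (x ** 2).sum(dim=1, keepdim=True)       # (n, 1)
    yy = (y ** 2).sum(dim=1, keepdim=True).t()   # (1, m)
    d2 = (xx + yy - 2 * x @ y.t()).clamp(min=0)  # guard tiny negative values
    return d2.sqrt()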
def train(
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    device,
    topk,
):
    """
    Training model.

    Args
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): One-hot query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): One-hot retrieval targets.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        topk(int): Compute mAP over the top k retrieved results.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    query_data, retrieval_data = query_data.to(device), retrieval_data.to(device)
    query_targets, retrieval_targets = query_targets.to(device), retrieval_targets.to(device)

    # Generate random projection matrix
    W = torch.randn(query_data.shape[1], code_length).to(device)

    # Generate query and retrieval code
    query_code = (query_data @ W).sign()
    retrieval_code = (retrieval_data @ W).sign()

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, R = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'W': W,
        'P': P,
        'R': R,
        'map': mAP,
    }
    torch.save(checkpoint, 'checkpoints/code_{}_map_{:.4f}.pt'.format(code_length, mAP))

    return checkpoint
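# A minimal usage sketch for the LSH routine above, on random data. The
# feature dimension, class count, and sample sizes are illustrative
# assumptions, and the metric helpers (mean_average_precision, pr_curve) are
# expected to come from the rest of the repo; the `checkpoints/` directory is
# created because train() saves into it:
if __name__ == '__main__':
    import os
    import torch

    os.makedirs('checkpoints', exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_classes, dim = 10, 512
    query_data = torch.randn(100, dim)
    query_targets = torch.eye(num_classes)[torch.randint(0, num_classes, (100,))]
    retrieval_data = torch.randn(1000, dim)
    retrieval_targets = torch.eye(num_classes)[torch.randint(0, num_classes, (1000,))]

    checkpoint = train(
        query_data, query_targets,
        retrieval_data, retrieval_targets,
        code_length=32, device=device, topk=100,
    )
    print('mAP: {:.4f}'.format(checkpoint['map']))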
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    topk,
    evaluate_interval,
    anchor_num,
    proportion,
):
    #print("using device")
    #print(torch.cuda.current_device())
    #print(torch.cuda.get_device_name(torch.cuda.current_device()))

    # Load model and its momentum counterpart
    model = load_model(arch, code_length).to(device)
    model_mo = load_model_mo(arch).to(device)

    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    for it in range(max_iter):
        # timer
        tic = time.time()

        # Harvest prototypes/anchors. This full pass over the training set is
        # memory-heavy and sometimes gets killed; a cheaper scheme may be needed.
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)
            for data, _, _ in train_dataloader:
                data = data.to(device)
                output_mo_temp = model_mo(data)
                output_mo = torch.cat((output_mo, output_mo_temp), 0)
                torch.cuda.empty_cache()
            anchor = get_anchor(output_mo, anchor_num, device)  # compute anchor (see the sketch after this function)

        # self-supervised deep learning
        model.train()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # output
            output_B = model(data)
            output_mo_batch = model_mo(data)

            # Alternative anchor-based similarity (kept for reference):
            #sample_anchor_distance = torch.sqrt(torch.sum((output_mo_batch[:, None, :] - anchor) ** 2, dim=2)).to(device)
            #sample_anchor_dist_normalize = F.normalize(sample_anchor_distance, p=2, dim=1).to(device)
            #S = sample_anchor_dist_normalize @ sample_anchor_dist_normalize.t()
            #loss = criterion(output_B, S)
            #running_loss = running_loss + loss.item()
            #loss.backward(retain_graph=True)

            # Prototype/anchor-based similarity, rescaled to [-1, 1]
            with torch.no_grad():
                dist = torch.sum((output_mo_batch[:, None, :] - anchor.to(device)) ** 2, dim=2)
                k = dist.size(1)
                dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
                Z_su = torch.ones(k, 1).to(device)
                Z_sum = torch.sqrt(dist.mm(Z_su)) + 1e-12
                Z_simi = torch.div(dist, Z_sum).to(device)
                S = Z_simi.mm(Z_simi.t())
                S = (2 / (torch.max(S) - torch.min(S))) * S - 1

            loss = criterion(output_B, S)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # Momentum update of the key encoder, e.g. proportion = 0.999
                for param_q, param_k in zip(model.parameters(), model_mo.parameters()):
                    param_k.data = param_k.data * proportion + param_q.data * (1. - proportion)

        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.
            # Checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
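# `get_anchor` above is not defined in this section. The simplest stand-in
# consistent with its call site is to sample `anchor_num` feature vectors as
# prototypes; the repo's actual helper may instead cluster the features
# (e.g. with k-means), so treat this purely as a placeholder sketch:
def get_anchor(features, anchor_num, device):
    """Sketch: pick anchor_num rows of `features` at random as anchors."""
    index = torch.randperm(features.size(0), device=features.device)[:anchor_num]
    return features[index].to(device)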
def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    device,
    topk,
):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        topk(int): Compute mAP over the top k retrieved results.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # PCA
    pca = PCA(n_components=code_length)
    X = pca.fit_transform(train_data.numpy())

    # Fit uniform distribution
    eps = np.finfo(float).eps
    mn = X.min(0) - eps
    mx = X.max(0) + eps

    # Enumerate eigenfunctions
    R = mx - mn
    max_mode = np.ceil((code_length + 1) * R / R.max()).astype(int)  # np.int is removed in recent NumPy
    n_modes = max_mode.sum() - len(max_mode) + 1
    modes = np.ones([n_modes, code_length])
    m = 0
    for i in range(code_length):
        modes[m + 1:m + max_mode[i], i] = np.arange(1, max_mode[i]) + 1
        m = m + max_mode[i] - 1
    modes -= 1
    omega0 = np.pi / R
    omegas = modes * omega0.reshape(1, -1).repeat(n_modes, 0)

    # Keep the code_length modes with the smallest frequencies,
    # skipping the trivial constant (all-zero) mode
    eig_val = -(omegas ** 2).sum(1)
    ii = (-eig_val).argsort()
    modes = modes[ii[1:code_length + 1], :]

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, pca, mn, R, modes).to(device)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, pca, mn, R, modes).to(device)
    query_targets = query_targets.to(device)
    retrieval_targets = retrieval_targets.to(device)

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'P': P,
        'R': Recall,
        'map': mAP,
    }

    return checkpoint
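# `generate_code` here is assumed to follow the classic Spectral Hashing
# out-of-sample extension (Weiss et al., NIPS 2008): evaluate the selected
# separable eigenfunctions sin(pi/2 + omega * x) on the PCA-projected,
# shifted data and binarize. A sketch under that assumption:
def generate_code(data, code_length, pca, mn, R, modes):
    """
    Sketch: generate Spectral Hashing code for out-of-sample data.

    Args mirror train(): `mn` and `R` are the per-dimension minimum and range
    of the PCA-projected training data; `modes` are the selected mode vectors.
    """
    X = pca.transform(data.numpy()) - mn
    omegas = modes * (np.pi / R)
    U = np.zeros((X.shape[0], code_length))
    for i in range(code_length):
        # Product of 1-D eigenfunctions across dimensions for mode i
        U[:, i] = np.prod(np.sin(X * omegas[i, :] + np.pi / 2), axis=1)
    return torch.from_numpy(np.sign(U)).float()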