def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    mu, nu, eta,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loaders.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Maximum number of iterations.
        mu, nu, eta(float): Hyper-parameters.
        topk(int): Compute mAP using top k retrieval results.
        evaluate_interval(int): Evaluation interval.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Construct network, criterion, optimizer, scheduler
    model = load_model(arch, code_length).to(device)
    criterion = DSDHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # Initialization
    N = len(train_dataloader.dataset)
    B = torch.randn(code_length, N).sign().to(device)
    U = torch.zeros(code_length, N).to(device)
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
    S = (train_targets @ train_targets.t() > 0).float()
    Y = train_targets.t()

    best_map = 0.
    iter_time = time.time()
    for it in range(max_iter):
        model.train()

        # CNN-step
        for data, targets, index in train_dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()

            U_batch = model(data).t()
            U[:, index] = U_batch.data
            loss = criterion(U_batch, U, S[:, index], B[:, index])

            loss.backward()
            optimizer.step()
        scheduler.step()

        # W-step
        W = torch.inverse(B @ B.t() + nu / mu * torch.eye(code_length, device=device)) @ B @ Y.t()

        # B-step
        B = solve_dcc(W, Y, U, B, eta, mu)

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time
            epoch_loss = calc_loss(U, S, Y, W, B, mu, nu, eta)

            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute mAP
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1, max_iter, epoch_loss, mAP, iter_time))

            # Save checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }

            iter_time = time.time()

    return checkpoint
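# Note: `solve_dcc` is called in the B-step above but not defined in this file. The following
# is a minimal sketch of a discrete cyclic coordinate (DCC) update consistent with the call
# `solve_dcc(W, Y, U, B, eta, mu)`, assuming the usual DSDH B-subproblem
# min_B mu*||Y - W.t() @ B||^2 + eta*||B - U||^2 with B in {-1, +1}. It is an assumption for
# illustration, not the repository's exact implementation.
def solve_dcc(W, Y, U, B, eta, mu):
    """Hypothetical DCC sketch: update B bit by bit with the other bits fixed.

    Shapes follow the caller: W (code_length x classes), Y (classes x N),
    U and B (code_length x N).
    """
    P = mu * W @ Y + eta * U    # linear term of the objective w.r.t. B
    Q = mu * W @ W.t()          # quadratic coupling between bits
    for k in range(B.shape[0]):
        # contribution of all bits except bit k
        residual = Q[k, :] @ B - Q[k, k] * B[k, :]
        B[k, :] = (P[k, :] - residual).sign()
    return B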
def train(
    train_data,
    train_targets,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    num_anchor,
    max_iter,
    lamda,
    nu,
    sigma,
    device,
    topk,
):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        train_targets(torch.Tensor): Training targets.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        num_anchor(int): Number of anchors.
        max_iter(int): Number of iterations.
        lamda, nu, sigma(float): Hyper-parameters.
        device(torch.device): GPU or CPU.
        topk(int): Compute mAP using top k retrieval results.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    n = train_data.shape[0]
    L = code_length
    m = num_anchor
    t = max_iter
    X = train_data.t()
    Y = train_targets.t()
    B = torch.randn(L, n).sign()

    # Permute data
    perm_index = torch.randperm(n)
    X = X[:, perm_index]
    Y = Y[:, perm_index]

    # Randomly select num_anchor samples from the training data
    anchor = X[:, :m]

    # Map training data via RBF kernel
    phi_x = torch.from_numpy(rbf_kernel(X.numpy().T, anchor.numpy().T, sigma)).t()

    # Training
    B = B.to(device)
    Y = Y.to(device)
    phi_x = phi_x.to(device)
    for it in range(t):
        # G-Step
        W = torch.pinverse(B @ B.t() + lamda * torch.eye(code_length, device=device)) @ B @ Y.t()

        # F-Step
        P = torch.pinverse(phi_x @ phi_x.t()) @ phi_x @ B.t()
        F_X = P.t() @ phi_x

        # B-Step
        B = solve_dcc(B, W, Y, F_X, nu)

    # Evaluate
    query_code = generate_code(query_data.t(), anchor, P, sigma)
    retrieval_code = generate_code(retrieval_data.t(), anchor, P, sigma)

    # Compute mAP
    mAP = mean_average_precision(
        query_code.t().to(device),
        retrieval_code.t().to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
        device,
        topk,
    )

    # PR curve
    Precision, R = pr_curve(
        query_code.t().to(device),
        retrieval_code.t().to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
        device,
    )

    # Save checkpoint
    checkpoint = {
        'tB': B,
        'tL': train_targets,
        'qB': query_code,
        'qL': query_targets,
        'rB': retrieval_code,
        'rL': retrieval_targets,
        'anchor': anchor,
        'projection': P,
        'P': Precision,
        'R': R,
        'map': mAP,
    }

    return checkpoint
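# Note: `generate_code` above handles the out-of-sample extension but is not defined in this
# file. A plausible sketch, assuming new data is mapped through the same RBF anchors used in
# training, projected with P, and binarized; the rbf_kernel signature mirrors the training
# call above. This is an assumption for illustration, not the repository's exact code.
def generate_code(data, anchor, P, sigma):
    """Hypothetical out-of-sample sketch.

    data: (features x num_samples), anchor: (features x num_anchor),
    P: (num_anchor x code_length). Returns a (code_length x num_samples) code,
    matching the transposes applied by the caller.
    """
    phi = torch.from_numpy(rbf_kernel(data.numpy().T, anchor.numpy().T, sigma)).t()
    # move P to the CPU and match dtype so the product works regardless of training device
    return (P.cpu().to(phi.dtype).t() @ phi).sign()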
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    args,
    # args.device,
    # args.lr,
    # args.max_iter,
    # args.max_epoch,
    # args.num_samples,
    # args.batch_size,
    # args.root,
    # args.dataset,
    # args.gamma,
    # args.topk,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        code_length(int): Hash code length.
        args.device(torch.device): GPU or CPU.
        args.lr(float): Learning rate.
        args.max_iter(int): Number of iterations.
        args.max_epoch(int): Number of epochs.
        args.num_samples(int): Number of sampled training data points.
        args.batch_size(int): Batch size.
        args.root(str): Path of dataset.
        args.dataset(str): Dataset name.
        args.gamma(float): Hyper-parameter.
        args.topk(int): Compute top k mAP.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    # model = alexnet.load_model(code_length).to(args.device)
    model = resnet.resnet50(pretrained=args.pretrain, num_classes=code_length).to(args.device)
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=args.momen,
                              nesterov=args.nesterov)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_step)
    criterion = ADSH_Loss(code_length, args.gamma)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(args.num_samples, code_length).to(args.device)
    B = torch.randn(num_retrieval, code_length).to(args.device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(args.device)
    cnn_losses, hash_losses, quan_losses = AverageMeter(), AverageMeter(), AverageMeter()
    start = time.time()
    best_mAP = 0

    for it in range(args.max_iter):
        iter_start = time.time()

        # Sample training data for CNN learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, args.num_samples, args.batch_size, args.root, args.dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(args.device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(args.max_epoch):
            cnn_losses.reset()
            hash_losses.reset()
            quan_losses.reset()
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(args.device), targets.to(args.device), index.to(args.device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss, hash_loss, quan_loss = criterion(F, B, S[index, :], sample_index[index])
                cnn_losses.update(cnn_loss.item())
                hash_losses.update(hash_loss.item())
                quan_losses.update(quan_loss.item())

                cnn_loss.backward()
                optimizer.step()
            logger.info('[epoch:{}/{}][cnn_loss:{:.6f}][hash_loss:{:.6f}][quan_loss:{:.6f}]'.format(
                epoch + 1, args.max_epoch, cnn_losses.avg, hash_losses.avg, quan_losses.avg))
        scheduler.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(args.device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, args.gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, args.gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, args.max_iter, iter_loss, time.time()-iter_start))
        logger.info('[iter:{}/{}][loss:{:.6f}][iter_time:{:.2f}]'.format(
            it + 1, args.max_iter, iter_loss, time.time() - iter_start))

        # Evaluate
        if (it + 1) % 1 == 0:
            query_code = generate_code(model, query_dataloader, code_length, args.device)
            mAP = evaluate.mean_average_precision(
                query_code.to(args.device),
                B,
                query_dataloader.dataset.get_onehot_targets().to(args.device),
                retrieval_targets,
                args.device,
                args.topk,
            )
            if mAP > best_mAP:
                best_mAP = mAP

                # Save checkpoints
                ret_path = os.path.join('checkpoints', args.info, str(code_length))
                # ret_path = 'checkpoints/' + args.info
                if not os.path.exists(ret_path):
                    os.makedirs(ret_path)
                torch.save(query_code.cpu(), os.path.join(ret_path, 'query_code.t'))
                torch.save(B.cpu(), os.path.join(ret_path, 'database_code.t'))
                # call the accessor; saving the bound method itself would not store the targets
                torch.save(query_dataloader.dataset.get_onehot_targets(), os.path.join(ret_path, 'query_targets.t'))
                torch.save(retrieval_targets.cpu(), os.path.join(ret_path, 'database_targets.t'))
                torch.save(model.cpu(), os.path.join(ret_path, 'model.t'))
                model = model.to(args.device)
            logger.info('[iter:{}/{}][code_length:{}][mAP:{:.5f}][best_mAP:{:.5f}]'.format(
                it + 1, args.max_iter, code_length, mAP, best_mAP))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    return best_mAP
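# Note: the `solve_dcc` used for the B-update above is not defined in this file. A minimal
# sketch of the ADSH-style bit-wise closed-form update, assuming the objective
# ||code_length*S - U @ B.t()||^2 + gamma*||B - expand_U||^2 (hedged: this mirrors the ADSH
# formulation, not necessarily this repository's exact implementation).
def solve_dcc(B, U, expand_U, S, code_length, gamma):
    """Hypothetical discrete cyclic coordinate descent sketch.

    B: (num_retrieval x code_length), U: (num_samples x code_length),
    expand_U: U scattered into the database positions, S: (num_samples x num_retrieval).
    """
    Q = (code_length * S).t() @ U + gamma * expand_U    # (num_retrieval x code_length)
    for bit in range(code_length):
        q = Q[:, bit]
        u = U[:, bit]
        B_prime = torch.cat((B[:, :bit], B[:, bit + 1:]), dim=1)
        U_prime = torch.cat((U[:, :bit], U[:, bit + 1:]), dim=1)
        # closed-form update of one bit column with the remaining bits fixed
        B[:, bit] = (q - B_prime @ U_prime.t() @ u).sign()
    return B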
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    parameters,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        parameters(dict): Hyper-parameters (eta, mu, gamma, varphi).
        topk(int): Compute top k mAP.

    Returns
        mAP(float): Mean Average Precision.
    """
    # parameters = {'eta': 2, 'mu': 0.2, 'gamma': 1, 'varphi': 200}
    eta = parameters['eta']
    mu = parameters['mu']
    gamma = parameters['gamma']
    varphi = parameters['varphi']

    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = DADSH2_Loss(code_length, eta, mu, gamma, varphi, device)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)      # U (m x l, l: code_length)
    B = torch.randn(num_retrieval, code_length).to(device)    # V (n x l, l: code_length)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)  # Y2 (n x c, c: classes)

    start = time.time()
    for it in range(max_iter):
        iter_start = time.time()

        # Sample training data for CNN learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)  # Y1 (m x c, c: classes)
        S = (train_targets @ retrieval_targets.t() > 0).float()                   # S (m x n)
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)  # output (m1 x l)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], sample_index[index])

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, varphi)

        # Total loss
        # iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, max_iter, iter_loss, time.time()-iter_start))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    # Save checkpoints
    torch.save(query_code.cpu(), os.path.join('checkpoints', 'query_code.t'))
    torch.save(B.cpu(), os.path.join('checkpoints', 'database_code.t'))
    # call the accessor; saving the bound method itself would not store the targets
    torch.save(query_dataloader.dataset.get_onehot_targets(), os.path.join('checkpoints', 'query_targets.t'))
    torch.save(retrieval_targets.cpu(), os.path.join('checkpoints', 'database_targets.t'))
    torch.save(model.cpu(), os.path.join('checkpoints', 'model.t'))

    return mAP
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    eta,
    lr,
    max_iter,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        eta(float): Hyper-parameter.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        topk(int): Calculate mAP of top k.
        evaluate_interval(int): Evaluation interval.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Create model, criterion, optimizer, scheduler
    model = load_model(arch, code_length).to(device)
    criterion = DPSHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # Initialization
    N = len(train_dataloader.dataset)
    U = torch.zeros(N, code_length).to(device)
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)

    # Training
    best_map = 0.0
    iter_time = time.time()
    for it in range(max_iter):
        model.train()
        running_loss = 0.
        for data, targets, index in train_dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()

            S = (targets @ train_targets.t() > 0).float()
            U_cnn = model(data)
            U[index, :] = U_cnn.data
            loss = criterion(U_cnn, U, S)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time

            # Generate hash code and one-hot targets
            query_code = generate_code(model, query_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute mAP
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Save checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }

            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss,
                mAP,
                iter_time,
            ))
            iter_time = time.time()

    return checkpoint
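# Note: `DPSHLoss` is not defined in this file. Below is a minimal, numerically stable sketch
# of the DPSH objective (pairwise negative log-likelihood plus an eta-weighted quantization
# term), consistent with the call `criterion(U_cnn, U, S)`; torch.nn is assumed imported as
# nn. The exact weighting in the repository may differ.
class DPSHLoss(nn.Module):
    """Hypothetical sketch of the DPSH loss."""

    def __init__(self, eta):
        super().__init__()
        self.eta = eta

    def forward(self, U_batch, U_all, S):
        # inner products between the current batch and all training outputs
        theta = U_batch @ U_all.t() / 2
        # -log p(S | theta), using softplus for numerical stability
        likelihood_loss = (torch.nn.functional.softplus(theta) - S * theta).mean()
        # push continuous outputs toward binary codes
        quantization_loss = (U_batch - U_batch.sign()).pow(2).mean()
        return likelihood_loss + self.eta * quantization_loss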
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    alpha,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        alpha(float): Hyper-parameter.
        topk(int): Compute top k mAP.
        evaluate_interval(int): Interval of evaluation.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Load model
    model = load_model(arch, code_length).to(device)

    # Create criterion, optimizer, scheduler
    criterion = HashNetLoss(alpha)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    # This implementation does not use the "scaled tanh" continuation.
    # It is hard to tune and sometimes even decreases performance.
    # Refer to https://github.com/thuml/HashNet/issues/29
    for it in range(max_iter):
        tic = time.time()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # Create similarity matrix
            S = (targets @ targets.t() > 0).float()
            outputs = model(data)

            loss = criterion(outputs, S)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()
        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute mAP
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute PR curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
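# Note: `HashNetLoss` is not defined here. A sketch of HashNet's weighted maximum-likelihood
# pairwise loss with similarity scale alpha (the scaled-tanh continuation is omitted, as the
# comment above explains). The class-imbalance weighting below is one common choice and an
# assumption, not necessarily this repository's; torch.nn is assumed imported as nn.
class HashNetLoss(nn.Module):
    """Hypothetical sketch of a weighted pairwise likelihood loss."""

    def __init__(self, alpha):
        super().__init__()
        self.alpha = alpha

    def forward(self, H, S):
        theta = self.alpha * (H @ H.t())
        # up-weight the rarer kind of pair (similar vs. dissimilar)
        mask_pos = (S > 0).float()
        mask_neg = (S <= 0).float()
        weight = mask_pos * (S.numel() / max(mask_pos.sum().item(), 1.0)) \
            + mask_neg * (S.numel() / max(mask_neg.sum().item(), 1.0))
        # weighted, numerically stable log(1 + exp(theta)) - s * theta
        loss = weight * (torch.nn.functional.softplus(theta) - S * theta)
        return loss.mean()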
def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    max_iter,
    device,
    topk,
):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        max_iter(int): Number of iterations.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data points mAP.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    query_data, query_targets, retrieval_data, retrieval_targets = (
        query_data.to(device),
        query_targets.to(device),
        retrieval_data.to(device),
        retrieval_targets.to(device),
    )
    R = torch.randn(code_length, code_length).to(device)
    [U, _, _] = torch.svd(R)
    R = U[:, :code_length]

    # PCA
    pca = PCA(n_components=code_length)
    V = torch.from_numpy(pca.fit_transform(train_data.numpy())).to(device)

    # Training
    for i in range(max_iter):
        V_tilde = V @ R
        B = V_tilde.sign()
        [U, _, VT] = torch.svd(B.t() @ V)
        R = (VT.t() @ U.t())

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, R, pca)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, R, pca)

    # Compute mAP
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'pca': pca,
        'rotation_matrix': R,
        'P': P,
        'R': Recall,
        'map': mAP,
    }

    return checkpoint
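# Note: `generate_code` above is not defined in this file. A hypothetical out-of-sample
# sketch for ITQ: project new data with the fitted PCA, rotate with R, and binarize
# (code_length is implied by the number of PCA components). An illustration only, not the
# repository's exact helper.
def generate_code(data, code_length, R, pca):
    """Hypothetical sketch: PCA projection, rotation, sign binarization."""
    V = torch.from_numpy(pca.transform(data.numpy())).to(dtype=R.dtype, device=R.device)
    return (V @ R).sign()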
def train(
    query_dataloader,
    train_dataloader,
    retrieval_dataloader,
    code_length,
    args,
    # args.device,
    # args.lr,
    # args.max_iter,
    # args.max_epoch,
    # args.num_samples,
    # args.batch_size,
    # args.root,
    # args.dataset,
    # args.gamma,
    # args.topk,
):
    """
    Training model.

    Args
        query_dataloader, train_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        code_length(int): Hash code length.
        args.device(torch.device): GPU or CPU.
        args.lr(float): Learning rate.
        args.max_iter(int): Number of iterations.
        args.max_epoch(int): Number of epochs.
        args.num_samples(int): Number of sampled training data points.
        args.batch_size(int): Batch size.
        args.root(str): Path of dataset.
        args.dataset(str): Dataset name.
        args.gamma(float): Hyper-parameter.
        args.topk(int): Compute top k mAP.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    # model = alexnet.load_model(code_length).to(device)
    # model = resnet.resnet50(pretrained=True, num_classes=code_length).to(device)
    num_classes, att_size, feat_size = args.num_classes, 4, 2048
    model = exchnet.exchnet(code_length=code_length,
                            num_classes=num_classes,
                            att_size=att_size,
                            feat_size=feat_size,
                            device=args.device,
                            pretrained=args.pretrain).to(args.device)
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=args.momen,
                              nesterov=args.nesterov)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_step)
    criterion = Exch_Loss(code_length, args.device, lambd_sp=1.0, lambd_ch=1.0)
    criterion.quanting = args.quan_loss

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(args.num_samples, code_length).to(args.device)
    B = torch.randn(num_retrieval, code_length).to(args.device)
    # B = torch.zeros(num_retrieval, code_length).to(args.device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(args.device)
    C = torch.zeros((num_classes, att_size, feat_size)).to(args.device)
    start = time.time()
    best_mAP = 0

    for it in range(args.max_iter):
        iter_start = time.time()

        # Sample training data for CNN learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, args.num_samples, args.batch_size, args.root, args.dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(args.device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        cnn_losses, hash_losses, quan_losses, sp_losses, ch_losses, align_losses = \
            AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()

        # Training CNN model
        for epoch in range(args.max_epoch):
            cnn_losses.reset()
            hash_losses.reset()
            quan_losses.reset()
            sp_losses.reset()
            ch_losses.reset()
            align_losses.reset()
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(args.device), targets.to(args.device), index.to(args.device)
                optimizer.zero_grad()

                F, sp_v, ch_v, avg_local_f = model(data, targets)
                U[index, :] = F.data
                batch_anchor_local_f = C[torch.argmax(targets, dim=1)]
                # print(index)
                cnn_loss, hash_loss, quan_loss, sp_loss, ch_loss, align_loss = criterion(
                    F, B, S[index, :], sample_index[index], sp_v, ch_v, avg_local_f, batch_anchor_local_f)
                cnn_losses.update(cnn_loss.item())
                hash_losses.update(hash_loss.item())
                quan_losses.update(quan_loss.item())
                sp_losses.update(sp_loss.item())
                ch_losses.update(ch_loss.item())
                align_losses.update(align_loss.item())
                # print(ch_v)

                cnn_loss.backward()
                optimizer.step()
            logger.info(
                '[epoch:{}/{}][cnn_loss:{:.6f}][h_loss:{:.6f}][q_loss:{:.6f}][s_loss:{:.4f}][c_loss:{:.4f}][a_loss:{:.4f}]'
                .format(epoch + 1, args.max_epoch, cnn_losses.avg, hash_losses.avg,
                        quan_losses.avg, sp_losses.avg, ch_losses.avg, align_losses.avg))
        scheduler.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(args.device)
        expand_U[sample_index, :] = U
        if args.quan_loss:
            B = solve_dcc_adsh(B, U, expand_U, S, code_length, args.gamma)
        else:
            B = solve_dcc_exch(B, U, expand_U, S, code_length, args.gamma)

        # Update C (class-wise anchor local features)
        if (it + 1) >= args.align_step:
            model.exchanging = True
            criterion.aligning = True
            model.eval()
            with torch.no_grad():
                C = torch.zeros((num_classes, att_size, feat_size)).to(args.device)
                feat_cnt = torch.zeros((num_classes, 1, 1)).to(args.device)
                for batch, (data, targets, index) in enumerate(retrieval_dataloader):
                    data, targets, index = data.to(args.device), targets.to(args.device), index.to(args.device)
                    _, _, _, avg_local_f = model(data, targets)
                    class_idx = targets.argmax(dim=1)
                    for i in range(targets.shape[0]):
                        C[class_idx[i]] += avg_local_f[i]
                        feat_cnt[class_idx[i]] += 1
                C /= feat_cnt
                model.anchor_local_f = C
            model.train()

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, args.gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, args.max_iter, iter_loss, time.time()-iter_start))
        logger.info('[iter:{}/{}][loss:{:.6f}][iter_time:{:.2f}]'.format(
            it + 1, args.max_iter, iter_loss, time.time() - iter_start))

        # Evaluate
        if (it + 1) % 1 == 0:
            query_code = generate_code(model, query_dataloader, code_length, args.device)
            mAP = evaluate.mean_average_precision(
                query_code.to(args.device),
                B,
                query_dataloader.dataset.get_onehot_targets().to(args.device),
                retrieval_targets,
                args.device,
                args.topk,
            )
            if mAP > best_mAP:
                best_mAP = mAP

                # Save checkpoints
                ret_path = os.path.join('checkpoints', args.info, str(code_length))
                if not os.path.exists(ret_path):
                    os.makedirs(ret_path)
                torch.save(query_code.cpu(), os.path.join(ret_path, 'query_code.t'))
                torch.save(B.cpu(), os.path.join(ret_path, 'database_code.t'))
                # call the accessor; saving the bound method itself would not store the targets
                torch.save(query_dataloader.dataset.get_onehot_targets(), os.path.join(ret_path, 'query_targets.t'))
                torch.save(retrieval_targets.cpu(), os.path.join(ret_path, 'database_targets.t'))
                torch.save(model.cpu(), os.path.join(ret_path, 'model.t'))
                model = model.to(args.device)
            logger.info('[iter:{}/{}][code_length:{}][mAP:{:.5f}][best_mAP:{:.5f}]'.format(
                it + 1, args.max_iter, code_length, mAP, best_mAP))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    return best_mAP
def train(train_dataloader,
          query_dataloader,
          retrieval_dataloader,
          arch,
          code_length,
          device,
          lr,
          max_iter,
          topk,
          evaluate_interval,
          anchor_num,
          proportion,
          ):
    # ADMM hyper-parameters
    rho1 = 1e-2   # rho_1
    rho2 = 1e-2   # rho_2
    rho3 = 1e-3   # mu_1
    rho4 = 1e-3   # mu_2
    gamma = 1e-3  # gamma

    # Collect all training data once
    with torch.no_grad():
        data_mo = torch.tensor([]).to(device)
        for data, _, _ in train_dataloader:
            data = data.to(device)
            data_mo = torch.cat((data_mo, data), 0)
            torch.cuda.empty_cache()

    n = data_mo.size(0)  # number of training samples (size(1) would count channels)
    Y1 = torch.rand(n, code_length).to(device)
    Y2 = torch.rand(n, code_length).to(device)
    B = torch.rand(n, code_length).to(device)

    # Load model
    model = load_model(arch, code_length).to(device)

    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    for it in range(max_iter):
        # timer
        tic = time.time()

        # ADMM uses anchors in the first step but drops them later
        '''
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)
            for data, _, _ in train_dataloader:
                data = data.to(device)
                output_mo_temp = model_mo(data)
                output_mo = torch.cat((output_mo, output_mo_temp), 0)
                torch.cuda.empty_cache()
            anchor = get_anchor(output_mo, anchor_num, device)  # compute anchors
        '''
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)
            for data, _, _ in train_dataloader:
                data = data.to(device)
                output_B, output_A = model(data)
                output_mo = torch.cat((output_mo, output_A), 0)
                torch.cuda.empty_cache()

            # Affinity matrix from pairwise distances, rescaled to [-1, 1]
            dist = euclidean_dist(output_mo, output_mo)
            dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
            A = (2 / (torch.max(dist) - torch.min(dist))) * dist - 1
            # NOTE: these module-level arrays are presumably read by Baim_func; if they are
            # defined at module scope, a `global` declaration is needed here.
            global_A = A.cpu().numpy()

            # ADMM auxiliary variables: Z1 is B + Y1/rho1 clamped to [-1, 1]
            Z1 = B + 1 / rho1 * Y1
            Z1[Z1 > 1] = 1
            Z1[Z1 < -1] = -1
            # Z2 is B + Y2/rho2 rescaled onto the sphere of radius sqrt(n * code_length)
            Z2 = B + 1 / rho2 * Y2
            norm_B = torch.norm(Z2)
            Z2 = (n * code_length) ** 0.5 * Z2 / norm_B

            # Dual updates
            Y1 = Y1 + gamma * rho1 * (B - Z1)
            Y2 = Y2 + gamma * rho2 * (B - Z2)

            global_Z1 = Z1.cpu().numpy()
            global_Z2 = Z2.cpu().numpy()
            global_Y1 = Y1.cpu().numpy()
            global_Y2 = Y2.cpu().numpy()

            # B-step: L-BFGS-B on the relaxed objective.
            # fmin_l_bfgs_b returns (x, f, info) with x flattened, so unpack and reshape.
            B0 = B.cpu().numpy().ravel()
            B_opt, _, _ = scipy.optimize.fmin_l_bfgs_b(Baim_func, B0)
            B = torch.from_numpy(B_opt.reshape(n, code_length)).float().to(device)

        # Self-supervised deep learning
        model.train()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # output_B is the hash-code branch, output_A the features before the hash layer
            output_B, output_A = model(data)
            loss = criterion(output_B, B)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()
        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute mAP
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute PR curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
def increment(
    query_dataloader,
    unseen_dataloader,
    retrieval_dataloader,
    old_B,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    gamma,
    mu,
    topk,
):
    """
    Increment model.

    Args
        query_dataloader, unseen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        old_B(torch.Tensor): Old binary hash codes.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma, mu(float): Hyper-parameters.
        topk(int): Top k mAP.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    model = alexnet.load_model(code_length)
    model.to(device)
    model.train()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = DIHN_Loss(code_length, gamma, mu)
    lr_scheduler = ExponentialLR(optimizer, 0.91)

    num_unseen = len(unseen_dataloader.dataset)
    num_seen = len(old_B)
    U = torch.zeros(num_samples, code_length).to(device)
    old_B = old_B.to(device)
    new_B = torch.randn(num_unseen, code_length).sign().to(device)
    B = torch.cat((old_B, new_B), dim=0).to(device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for CNN learning
        train_dataloader, sample_index, unseen_sample_in_unseen_index, unseen_sample_in_sample_index = \
            sample_dataloader(retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B (only the codes of unseen data are re-learned; old_B stays fixed)
        expand_U = torch.zeros(num_unseen, code_length).to(device)
        expand_U[unseen_sample_in_unseen_index, :] = U[unseen_sample_in_sample_index, :]
        new_B = solve_dcc(new_B, U, expand_U, S[:, unseen_dataloader.dataset.UNSEEN_INDEX], code_length, gamma)
        B = torch.cat((old_B, new_B), dim=0).to(device)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma, mu)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(
            it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('[DIHN time:{:.2f}]'.format(time.time() - total_time))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    return mAP
def train(
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    device,
    topk,
):
    """
    Training model.

    Args
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): One-hot query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): One-hot retrieval targets.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data mAP.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    query_data, retrieval_data, query_targets, retrieval_targets = (
        query_data.to(device),
        retrieval_data.to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
    )

    # Generate random projection matrix
    W = torch.randn(query_data.shape[1], code_length).to(device)

    # Generate query and retrieval code
    query_code = (query_data @ W).sign()
    retrieval_code = (retrieval_data @ W).sign()

    # Compute mAP
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, R = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'W': W,
        'P': P,
        'R': R,
        'map': mAP,
    }
    torch.save(checkpoint, 'checkpoints/code_{}_map_{:.4f}.pt'.format(code_length, mAP))

    return checkpoint
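# Usage sketch: since LSH needs no learning beyond drawing W, the whole pipeline can be
# exercised with random tensors. All shapes and labels below are made up for illustration,
# and a 'checkpoints' directory is assumed to exist for the torch.save call above.
if __name__ == '__main__':
    device = torch.device('cpu')
    retrieval_data = torch.randn(1000, 512)                            # 1000 database items, 512-d features
    query_data = torch.randn(100, 512)                                 # 100 queries
    retrieval_targets = torch.eye(10)[torch.randint(0, 10, (1000,))]   # one-hot labels, 10 classes
    query_targets = torch.eye(10)[torch.randint(0, 10, (100,))]

    checkpoint = train(query_data, query_targets, retrieval_data, retrieval_targets,
                       code_length=32, device=device, topk=100)
    print('mAP: {:.4f}'.format(checkpoint['map']))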
def train(train_dataloader,
          query_dataloader,
          retrieval_dataloader,
          arch,
          code_length,
          device,
          lr,
          max_iter,
          topk,
          evaluate_interval,
          anchor_num,
          proportion,
          ):
    # print("using device")
    # print(torch.cuda.current_device())
    # print(torch.cuda.get_device_name(torch.cuda.current_device()))

    # Load model
    model = load_model(arch, code_length).to(device)
    model_mo = load_model_mo(arch).to(device)

    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    for it in range(max_iter):
        # timer
        tic = time.time()

        # Harvest prototypes/anchors (this pass is sometimes killed for memory; consider another way)
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)
            for data, _, _ in train_dataloader:
                data = data.to(device)
                output_mo_temp = model_mo(data)
                output_mo = torch.cat((output_mo, output_mo_temp), 0)
                torch.cuda.empty_cache()
            anchor = get_anchor(output_mo, anchor_num, device)  # compute anchors

        # Self-supervised deep learning
        model.train()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # output
            output_B = model(data)
            output_mo_batch = model_mo(data)

            # Prototype/anchor based similarity
            # sample_anchor_distance = torch.sqrt(torch.sum((output_mo_batch[:, None, :] - anchor) ** 2, dim=2)).to(device)
            # sample_anchor_dist_normalize = F.normalize(sample_anchor_distance, p=2, dim=1).to(device)
            # S = sample_anchor_dist_normalize @ sample_anchor_dist_normalize.t()

            # loss
            # loss = criterion(output_B, S)
            # running_loss = running_loss + loss.item()
            # loss.backward(retain_graph=True)

            with torch.no_grad():
                dist = torch.sum((output_mo_batch[:, None, :] - anchor.to(device)) ** 2, dim=2)
                k = dist.size(1)
                dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
                Z_su = torch.ones(k, 1).to(device)
                Z_sum = torch.sqrt(dist.mm(Z_su)) + 1e-12
                Z_simi = torch.div(dist, Z_sum).to(device)
                S = Z_simi.mm(Z_simi.t())
                S = (2 / (torch.max(S) - torch.min(S))) * S - 1

            loss = criterion(output_B, S)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # Momentum update of the key encoder
                for param_q, param_k in zip(model.parameters(), model_mo.parameters()):
                    param_k.data = param_k.data * proportion + param_q.data * (1. - proportion)  # proportion = 0.999 for update

        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute mAP
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute PR curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
def train(
    query_dataloader,
    seen_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    gamma,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, seen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma(float): Hyper-parameter.
        topk(int): Top k mAP.

    Returns
        None
    """
    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = ADSH_Loss(code_length, gamma)
    lr_scheduler = ExponentialLR(optimizer, 0.9)

    num_seen = len(seen_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)
    B = torch.randn(num_seen, code_length).sign().to(device)
    seen_targets = seen_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for CNN learning
        train_dataloader, sample_index, _, _ = sample_dataloader(
            seen_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ seen_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(
            it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('Training adsh finished, time:{:.2f}'.format(time.time() - total_time))

    # Save checkpoints
    torch.save(B.cpu(), os.path.join('checkpoints', 'old_B.t'))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_dataloader.dataset.get_onehot_targets().to(device),
        device,
        topk,
    )
    logger.info('[ADSH map:{:.4f}]'.format(mAP))
def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    device,
    topk,
):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data points mAP.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # PCA
    pca = PCA(n_components=code_length)
    X = pca.fit_transform(train_data.numpy())

    # Fit uniform distribution
    eps = np.finfo(float).eps
    mn = X.min(0) - eps
    mx = X.max(0) + eps

    # Enumerate eigenfunctions
    R = mx - mn
    # use the built-in int: np.int was removed from recent NumPy releases
    max_mode = np.ceil((code_length + 1) * R / R.max()).astype(int)
    n_modes = max_mode.sum() - len(max_mode) + 1
    modes = np.ones([n_modes, code_length])
    m = 0
    for i in range(code_length):
        modes[m + 1:m + max_mode[i], i] = np.arange(1, max_mode[i]) + 1
        m = m + max_mode[i] - 1
    modes -= 1
    omega0 = np.pi / R
    omegas = modes * omega0.reshape(1, -1).repeat(n_modes, 0)
    eig_val = -(omegas ** 2).sum(1)
    ii = (-eig_val).argsort()
    modes = modes[ii[1:code_length + 1], :]

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, pca, mn, R, modes).to(device)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, pca, mn, R, modes).to(device)
    query_targets = query_targets.to(device)
    retrieval_targets = retrieval_targets.to(device)

    # Compute mAP
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'P': P,
        'R': Recall,
        'map': mAP,
    }

    return checkpoint
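# Worked example (hypothetical numbers) of the eigenfunction enumeration above: with
# code_length = 2 and PCA ranges R = [4.0, 2.0], max_mode = ceil(3 * R / R.max()) = [3, 2],
# so n_modes = (3 + 2) - 2 + 1 = 4 candidate modes are enumerated. After sorting by
# eigenvalue, the code_length lowest-frequency modes are kept, skipping the first
# (near-constant) one.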