def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    args,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        code_length(int): Hash code length.
        args(argparse.Namespace): Runtime arguments; uses args.device (torch.device,
            GPU or CPU), args.lr (learning rate), args.max_iter (number of iterations),
            args.max_epoch (number of epochs), args.num_samples (number of sampled
            training data points), args.batch_size, args.root (path of dataset),
            args.dataset (dataset name), args.gamma (hyper-parameter), and
            args.topk (top-k for mAP computation).

    Returns
        best_mAP(float): Best mean average precision.
    """
    # Initialization
    # model = alexnet.load_model(code_length).to(args.device)
    model = resnet.resnet50(pretrained=args.pretrain,
                            num_classes=code_length).to(args.device)
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=args.momen,
                              nesterov=args.nesterov)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_step)

    criterion = ADSH_Loss(code_length, args.gamma)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(args.num_samples, code_length).to(args.device)
    B = torch.randn(num_retrieval, code_length).to(args.device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(args.device)
    cnn_losses, hash_losses, quan_losses = AverageMeter(), AverageMeter(), AverageMeter()
    start = time.time()
    best_mAP = 0

    for it in range(args.max_iter):
        iter_start = time.time()

        # Sample training data for CNN learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, args.num_samples, args.batch_size, args.root, args.dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(args.device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soften similarity matrix to help convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(args.max_epoch):
            cnn_losses.reset()
            hash_losses.reset()
            quan_losses.reset()
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(args.device), targets.to(args.device), index.to(args.device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss, hash_loss, quan_loss = criterion(
                    F, B, S[index, :], sample_index[index])
                cnn_losses.update(cnn_loss.item())
                hash_losses.update(hash_loss.item())
                quan_losses.update(quan_loss.item())

                cnn_loss.backward()
                optimizer.step()
            logger.info(
                '[epoch:{}/{}][cnn_loss:{:.6f}][hash_loss:{:.6f}][quan_loss:{:.6f}]'
                .format(epoch + 1, args.max_epoch, cnn_losses.avg,
                        hash_losses.avg, quan_losses.avg))
        scheduler.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(args.device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, args.gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, args.gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, args.max_iter, iter_loss, time.time()-iter_start))
        logger.info('[iter:{}/{}][loss:{:.6f}][iter_time:{:.2f}]'.format(
            it + 1, args.max_iter, iter_loss, time.time() - iter_start))

        # Evaluate
        if (it + 1) % 1 == 0:
            query_code = generate_code(model, query_dataloader, code_length, args.device)
            mAP = evaluate.mean_average_precision(
                query_code.to(args.device),
                B,
                query_dataloader.dataset.get_onehot_targets().to(args.device),
                retrieval_targets,
                args.device,
                args.topk,
            )
            if mAP > best_mAP:
                best_mAP = mAP

                # Save checkpoints
                ret_path = os.path.join('checkpoints', args.info, str(code_length))
                if not os.path.exists(ret_path):
                    os.makedirs(ret_path)
                torch.save(query_code.cpu(), os.path.join(ret_path, 'query_code.t'))
                torch.save(B.cpu(), os.path.join(ret_path, 'database_code.t'))
                # Note: call get_onehot_targets(); saving the bound method itself was a bug.
                torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
                           os.path.join(ret_path, 'query_targets.t'))
                torch.save(retrieval_targets.cpu(), os.path.join(ret_path, 'database_targets.t'))
                torch.save(model.cpu(), os.path.join(ret_path, 'model.t'))
                model = model.to(args.device)
            logger.info(
                '[iter:{}/{}][code_length:{}][mAP:{:.5f}][best_mAP:{:.5f}]'
                .format(it + 1, args.max_iter, code_length, mAP, best_mAP))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    return best_mAP
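# solve_dcc() is referenced above but defined elsewhere in the repo. Below is
# a minimal sketch of the DCC (discrete cyclic coordinate descent) step that
# ADSH uses to update B bit by bit while the network outputs U are fixed; it
# follows the ADSH update rule and is an assumption about the helper, not its
# verbatim source.
def solve_dcc(B, U, expand_U, S, code_length, gamma):
    """Solve the B-subproblem one bit at a time with U held fixed."""
    Q = (code_length * S).t() @ U + gamma * expand_U
    for bit in range(code_length):
        q = Q[:, bit]
        u = U[:, bit]
        # All bits except the one currently being updated
        B_prime = torch.cat((B[:, :bit], B[:, bit + 1:]), dim=1)
        U_prime = torch.cat((U[:, :bit], U[:, bit + 1:]), dim=1)
        B[:, bit] = (q - B_prime @ U_prime.t() @ u).sign()
    return B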
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    parameters,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        parameters(dict): Hyper-parameters {'eta', 'mu', 'gamma', 'varphi'}.
        topk(int): Top-k for mAP computation.

    Returns
        mAP(float): Mean average precision.
    """
    # parameters = {'eta': 2, 'mu': 0.2, 'gamma': 1, 'varphi': 200}
    eta = parameters['eta']
    mu = parameters['mu']
    gamma = parameters['gamma']
    varphi = parameters['varphi']

    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = DADSH2_Loss(code_length, eta, mu, gamma, varphi, device)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)    # U (m*l, l: code_length)
    B = torch.randn(num_retrieval, code_length).to(device)  # V (n*l, l: code_length)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)  # Y2 (n*c, c: classes)

    start = time.time()
    for it in range(max_iter):
        iter_start = time.time()

        # Sample training data for CNN learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)  # Y1 (m*c, c: classes)
        S = (train_targets @ retrieval_targets.t() > 0).float()  # S (m*n)
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soften similarity matrix to help convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)  # output (m1*l)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], sample_index[index])

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, varphi)

        # Total loss
        # iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, max_iter, iter_loss, time.time()-iter_start))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    # Save checkpoints
    torch.save(query_code.cpu(), os.path.join('checkpoints', 'query_code.t'))
    torch.save(B.cpu(), os.path.join('checkpoints', 'database_code.t'))
    # Note: call get_onehot_targets(); saving the bound method itself was a bug.
    torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
               os.path.join('checkpoints', 'query_targets.t'))
    torch.save(retrieval_targets.cpu(), os.path.join('checkpoints', 'database_targets.t'))
    torch.save(model.cpu(), os.path.join('checkpoints', 'model.t'))

    return mAP
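# generate_code() is referenced by every training loop here but defined
# elsewhere. A minimal sketch, assuming the usual ADSH-style out-of-sample
# extension (forward pass plus sign thresholding) and a model whose forward
# takes only images; the ExchNet variant below would additionally need its
# targets argument handled.
def generate_code(model, dataloader, code_length, device):
    """Hash every item in dataloader with the trained model."""
    model.eval()
    with torch.no_grad():
        N = len(dataloader.dataset)
        code = torch.zeros([N, code_length])
        for data, _, index in dataloader:
            data = data.to(device)
            outputs = model(data)
            code[index, :] = outputs.sign().cpu()
    model.train()
    return code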
def increment(
    query_dataloader,
    unseen_dataloader,
    retrieval_dataloader,
    old_B,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    gamma,
    mu,
    topk,
):
    """
    Increment model.

    Args
        query_dataloader, unseen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        old_B(torch.Tensor): Old binary hash codes.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma, mu(float): Hyper-parameters.
        topk(int): Top-k for mAP computation.

    Returns
        mAP(float): Mean average precision.
    """
    # Initialization
    model = alexnet.load_model(code_length)
    model.to(device)
    model.train()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = DIHN_Loss(code_length, gamma, mu)
    lr_scheduler = ExponentialLR(optimizer, 0.91)

    num_unseen = len(unseen_dataloader.dataset)
    num_seen = len(old_B)
    U = torch.zeros(num_samples, code_length).to(device)
    old_B = old_B.to(device)
    new_B = torch.randn(num_unseen, code_length).sign().to(device)
    B = torch.cat((old_B, new_B), dim=0).to(device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for CNN learning
        train_dataloader, sample_index, unseen_sample_in_unseen_index, unseen_sample_in_sample_index = \
            sample_dataloader(retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soften similarity matrix to help convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B: only the codes of unseen data are re-solved; old_B stays fixed
        expand_U = torch.zeros(num_unseen, code_length).to(device)
        expand_U[unseen_sample_in_unseen_index, :] = U[unseen_sample_in_sample_index, :]
        new_B = solve_dcc(new_B, U, expand_U,
                          S[:, unseen_dataloader.dataset.UNSEEN_INDEX],
                          code_length, gamma)
        B = torch.cat((old_B, new_B), dim=0).to(device)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma, mu)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(
            it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('[DIHN time:{:.2f}]'.format(time.time() - total_time))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    return mAP
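# calc_loss() is defined elsewhere. A plausible sketch of the monitored
# objective, assuming the standard ADSH reconstruction term plus a
# gamma-weighted quantization term on the sampled points; exactly how mu
# enters the DIHN objective is not visible from this file, so it is accepted
# but unused here (an assumption, not the repo's definition).
def calc_loss(U, B, S, code_length, omega, gamma, mu=None):
    """Approximate objective value, normalized by the number of pairs."""
    hash_loss = ((code_length * S - U @ B.t()) ** 2).sum()
    quantization_loss = ((U - B[omega, :]) ** 2).sum()
    loss = (hash_loss + gamma * quantization_loss) / (U.shape[0] * B.shape[0])
    return loss.item()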
def train(
    query_dataloader,
    train_dataloader,
    retrieval_dataloader,
    code_length,
    args,
):
    """
    Training model.

    Args
        query_dataloader, train_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        code_length(int): Hash code length.
        args(argparse.Namespace): Runtime arguments; uses args.device (torch.device,
            GPU or CPU), args.lr (learning rate), args.max_iter (number of iterations),
            args.max_epoch (number of epochs), args.num_samples (number of sampled
            training data points), args.batch_size, args.root (path of dataset),
            args.dataset (dataset name), args.gamma (hyper-parameter), and
            args.topk (top-k for mAP computation).

    Returns
        best_mAP(float): Best mean average precision.
    """
    # Initialization
    # model = alexnet.load_model(code_length).to(args.device)
    # model = resnet.resnet50(pretrained=True, num_classes=code_length).to(args.device)
    num_classes, att_size, feat_size = args.num_classes, 4, 2048
    model = exchnet.exchnet(code_length=code_length,
                            num_classes=num_classes,
                            att_size=att_size,
                            feat_size=feat_size,
                            device=args.device,
                            pretrained=args.pretrain).to(args.device)
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=args.momen,
                              nesterov=args.nesterov)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_step)

    criterion = Exch_Loss(code_length, args.device, lambd_sp=1.0, lambd_ch=1.0)
    criterion.quanting = args.quan_loss

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(args.num_samples, code_length).to(args.device)
    B = torch.randn(num_retrieval, code_length).to(args.device)
    # B = torch.zeros(num_retrieval, code_length).to(args.device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(args.device)
    C = torch.zeros((num_classes, att_size, feat_size)).to(args.device)

    start = time.time()
    best_mAP = 0

    for it in range(args.max_iter):
        iter_start = time.time()

        # Sample training data for CNN learning (shadows the train_dataloader argument)
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, args.num_samples, args.batch_size, args.root, args.dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(args.device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soften similarity matrix to help convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        cnn_losses, hash_losses, quan_losses, sp_losses, ch_losses, align_losses = \
            AverageMeter(), AverageMeter(), AverageMeter(), \
            AverageMeter(), AverageMeter(), AverageMeter()

        # Training CNN model
        for epoch in range(args.max_epoch):
            cnn_losses.reset()
            hash_losses.reset()
            quan_losses.reset()
            sp_losses.reset()
            ch_losses.reset()
            align_losses.reset()
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(args.device), targets.to(args.device), index.to(args.device)
                optimizer.zero_grad()

                F, sp_v, ch_v, avg_local_f = model(data, targets)
                U[index, :] = F.data
                batch_anchor_local_f = C[torch.argmax(targets, dim=1)]
                cnn_loss, hash_loss, quan_loss, sp_loss, ch_loss, align_loss = criterion(
                    F, B, S[index, :], sample_index[index],
                    sp_v, ch_v, avg_local_f, batch_anchor_local_f)
                cnn_losses.update(cnn_loss.item())
                hash_losses.update(hash_loss.item())
                quan_losses.update(quan_loss.item())
                sp_losses.update(sp_loss.item())
                ch_losses.update(ch_loss.item())
                align_losses.update(align_loss.item())

                cnn_loss.backward()
                optimizer.step()
            logger.info(
                '[epoch:{}/{}][cnn_loss:{:.6f}][h_loss:{:.6f}][q_loss:{:.6f}][s_loss:{:.4f}][c_loss:{:.4f}][a_loss:{:.4f}]'
                .format(epoch + 1, args.max_epoch, cnn_losses.avg,
                        hash_losses.avg, quan_losses.avg, sp_losses.avg,
                        ch_losses.avg, align_losses.avg))
        scheduler.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(args.device)
        expand_U[sample_index, :] = U
        if args.quan_loss:
            B = solve_dcc_adsh(B, U, expand_U, S, code_length, args.gamma)
        else:
            B = solve_dcc_exch(B, U, expand_U, S, code_length, args.gamma)

        # Update C: class-wise anchor local features used by the alignment loss
        if (it + 1) >= args.align_step:
            model.exchanging = True
            criterion.aligning = True
            model.eval()
            with torch.no_grad():
                C = torch.zeros((num_classes, att_size, feat_size)).to(args.device)
                feat_cnt = torch.zeros((num_classes, 1, 1)).to(args.device)
                for batch, (data, targets, index) in enumerate(retrieval_dataloader):
                    data, targets, index = data.to(args.device), targets.to(args.device), index.to(args.device)
                    _, _, _, avg_local_f = model(data, targets)
                    class_idx = targets.argmax(dim=1)
                    for i in range(targets.shape[0]):
                        C[class_idx[i]] += avg_local_f[i]
                        feat_cnt[class_idx[i]] += 1
                C /= feat_cnt
                model.anchor_local_f = C
            model.train()

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, args.gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, args.max_iter, iter_loss, time.time()-iter_start))
        logger.info('[iter:{}/{}][loss:{:.6f}][iter_time:{:.2f}]'.format(
            it + 1, args.max_iter, iter_loss, time.time() - iter_start))

        # Evaluate
        if (it + 1) % 1 == 0:
            query_code = generate_code(model, query_dataloader, code_length, args.device)
            mAP = evaluate.mean_average_precision(
                query_code.to(args.device),
                B,
                query_dataloader.dataset.get_onehot_targets().to(args.device),
                retrieval_targets,
                args.device,
                args.topk,
            )
            if mAP > best_mAP:
                best_mAP = mAP

                # Save checkpoints
                ret_path = os.path.join('checkpoints', args.info, str(code_length))
                if not os.path.exists(ret_path):
                    os.makedirs(ret_path)
                torch.save(query_code.cpu(), os.path.join(ret_path, 'query_code.t'))
                torch.save(B.cpu(), os.path.join(ret_path, 'database_code.t'))
                # Note: call get_onehot_targets(); saving the bound method itself was a bug.
                torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
                           os.path.join(ret_path, 'query_targets.t'))
                torch.save(retrieval_targets.cpu(), os.path.join(ret_path, 'database_targets.t'))
                torch.save(model.cpu(), os.path.join(ret_path, 'model.t'))
                model = model.to(args.device)
            logger.info(
                '[iter:{}/{}][code_length:{}][mAP:{:.5f}][best_mAP:{:.5f}]'
                .format(it + 1, args.max_iter, code_length, mAP, best_mAP))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    return best_mAP
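# AverageMeter is used for loss bookkeeping above but defined elsewhere.
# A minimal sketch of the conventional helper, matching the
# reset()/update()/avg interface the training loops rely on (an assumption
# about the repo's implementation).
class AverageMeter(object):
    """Track the current value, running sum, count, and mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count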
def train(
    query_dataloader,
    seen_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    gamma,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, seen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loaders.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma(float): Hyper-parameter.
        topk(int): Top-k for mAP computation.

    Returns
        None
    """
    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = ADSH_Loss(code_length, gamma)
    lr_scheduler = ExponentialLR(optimizer, 0.9)

    num_seen = len(seen_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)
    B = torch.randn(num_seen, code_length).sign().to(device)
    seen_targets = seen_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for CNN learning
        train_dataloader, sample_index, _, _ = sample_dataloader(
            seen_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ seen_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soften similarity matrix to help convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(
            it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('Training adsh finished, time:{:.2f}'.format(time.time() - total_time))

    # Save checkpoints
    torch.save(B.cpu(), os.path.join('checkpoints', 'old_B.t'))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_dataloader.dataset.get_onehot_targets().to(device),
        device,
        topk,
    )
    logger.info('[ADSH map:{:.4f}]'.format(mAP))
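# A minimal sketch of how the two stages above might be wired together in a
# run script (the call sites are illustrative assumptions, not repo API):
# the ADSH-style train() learns codes for the seen database and stores them
# as checkpoints/old_B.t, which increment() then extends to unseen data.
#
#   train(query_dataloader, seen_dataloader, retrieval_dataloader, code_length,
#         device, lr, max_iter, max_epoch, num_samples, batch_size, root,
#         dataset, gamma, topk)
#   old_B = torch.load(os.path.join('checkpoints', 'old_B.t'))
#   mAP = increment(query_dataloader, unseen_dataloader, retrieval_dataloader,
#                   old_B, code_length, device, lr, max_iter, max_epoch,
#                   num_samples, batch_size, root, dataset, gamma, mu, topk)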