def load_model(code_length, device):
    """
    Load trained ADSH model and database hash codes from checkpoint.

    Args
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.

    Returns
        B(torch.Tensor): Database binary hash codes.
        imgrecord(np.ndarray): Database image paths.
        model(torch.nn.Module): CNN model in eval mode.
    """
    path = './checkpoints/adsh_nuswide_checkpoints/ADSH_NUSWIDE_48bits.pt'
    M = torch.load(path, map_location=lambda storage, loc: storage)
    B = M['rB']
    model = alexnet.load_model(code_length).to(device)
    model.load_state_dict(M['model'])
    model.eval()

    # Load image records of the database set
    root = '../datasets/NUS-WIDE'
    img_txt = 'database_img.txt'
    img_txt_path = os.path.join(root, img_txt)
    with open(img_txt_path, 'r') as f:
        imgrecord = np.array([i.strip() for i in f])

    return B, imgrecord, model

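# A minimal retrieval sketch (hamming_rank is a hypothetical helper, not part
# of the repo): given a single ±1 query hash code and the database codes B
# returned by load_model() above, rank database entries by Hamming distance
# and return the image records of the top-k matches.
import torch

def hamming_rank(query_code, B, imgrecord, k=10):
    code_length = B.shape[1]
    # For ±1 codes, Hamming distance = 0.5 * (code_length - inner product).
    dist = 0.5 * (code_length - query_code @ B.t())
    topk_index = torch.argsort(dist)[:k]
    return imgrecord[topk_index.cpu().numpy()]
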
def load_model(arch, code_length):
    """
    Load CNN model.

    Args
        arch(str): CNN model name.
        code_length(int): Hash code length.

    Returns
        model(torch.nn.Module): CNN model.
    """
    if arch == 'alexnet':
        model = alexnet.load_model(code_length)
    elif arch == 'vgg16':
        model = vgg16.load_model(code_length)
    else:
        raise ValueError('Invalid cnn model name!')

    return model

def load_model(name, pretrained=True, num_classes=None):
    """
    Load model.

    Parameters
        name: str
            Model name.
        pretrained: bool
            True: load pre-trained model; False: load untrained model.
        num_classes: int
            Number of classes output by the CNN's last layer.

    Returns
        model: model
            Model.
    """
    if name == 'alexnet':
        return alexnet.load_model(pretrained=pretrained, num_classes=num_classes)
    else:
        raise ValueError('Invalid model name!')

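# Hypothetical usage of the two loaders above. The import paths are
# illustrative placeholders (not from the repo), assuming the two
# load_model variants live in separate modules.
from models.model_loader import load_model as load_hash_model       # the (arch, code_length) variant
from classification.model_loader import load_model as load_cls_model  # the (name, pretrained, num_classes) variant

hash_model = load_hash_model('alexnet', code_length=48)                 # 48-bit hashing backbone
cls_model = load_cls_model('alexnet', pretrained=True, num_classes=21)  # classifier head, 21 classes
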
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    parameters,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        parameters(dict): Hyper-parameters (eta, mu, gamma, varphi).
        topk(int): Compute mAP over the top k retrieved results.

    Returns
        mAP(float): Mean Average Precision.
    """
    # parameters = {'eta': 2, 'mu': 0.2, 'gamma': 1, 'varphi': 200}
    eta = parameters['eta']
    mu = parameters['mu']
    gamma = parameters['gamma']
    varphi = parameters['varphi']

    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = DADSH2_Loss(code_length, eta, mu, gamma, varphi, device)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)   # U (m*l, l: code_length)
    B = torch.randn(num_retrieval, code_length).to(device)  # V (n*l, l: code_length)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)  # Y2 (n*c, c: classes)

    start = time.time()
    for it in range(max_iter):
        iter_start = time.time()

        # Sample training data for cnn learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)  # Y1 (m*c, c: classes)
        S = (train_targets @ retrieval_targets.t() > 0).float()  # S (m*n)
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefits convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)  # output (m1*l)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], sample_index[index])

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, varphi)

        # Total loss
        # iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it + 1, max_iter, iter_loss, time.time() - iter_start))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    # Save checkpoints
    torch.save(query_code.cpu(), os.path.join('checkpoints', 'query_code.t'))
    torch.save(B.cpu(), os.path.join('checkpoints', 'database_code.t'))
    torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
               os.path.join('checkpoints', 'query_targets.t'))
    torch.save(retrieval_targets.cpu(), os.path.join('checkpoints', 'database_targets.t'))
    torch.save(model.cpu(), os.path.join('checkpoints', 'model.t'))

    return mAP

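# The training loops in this section call solve_dcc() without defining it.
# Below is a sketch of the standard bit-wise DCC (discrete cyclic coordinate)
# update from the ADSH paper, matching the call signature used above; the last
# argument is the quantization weight (varphi in the DADSH2 loop above, gamma
# in the DIHN/ADSH loops below).
import torch

def solve_dcc(B, U, expand_U, S, code_length, gamma):
    # Q collects the closed-form linear term of the objective; B is then
    # updated one bit (column) at a time with the remaining bits held fixed.
    Q = (code_length * S).t() @ U + gamma * expand_U
    for bit in range(code_length):
        q = Q[:, bit]
        u = U[:, bit]
        B_prime = torch.cat((B[:, :bit], B[:, bit + 1:]), dim=1)
        U_prime = torch.cat((U[:, :bit], U[:, bit + 1:]), dim=1)
        B[:, bit] = (q - B_prime @ U_prime.t() @ u).sign()
    return B
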
def increment(
    query_dataloader,
    unseen_dataloader,
    retrieval_dataloader,
    old_B,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    gamma,
    mu,
    topk,
):
    """
    Increment model.

    Args
        query_dataloader, unseen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        old_B(torch.Tensor): Old binary hash codes.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma, mu(float): Hyper-parameters.
        topk(int): Compute mAP over the top k retrieved results.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    model = alexnet.load_model(code_length)
    model.to(device)
    model.train()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = DIHN_Loss(code_length, gamma, mu)
    lr_scheduler = ExponentialLR(optimizer, 0.91)

    num_unseen = len(unseen_dataloader.dataset)
    num_seen = len(old_B)
    U = torch.zeros(num_samples, code_length).to(device)
    old_B = old_B.to(device)
    new_B = torch.randn(num_unseen, code_length).sign().to(device)
    B = torch.cat((old_B, new_B), dim=0).to(device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for cnn learning
        train_dataloader, sample_index, unseen_sample_in_unseen_index, unseen_sample_in_sample_index = \
            sample_dataloader(retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefits convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B: only codes of unseen data are re-learned; old codes stay fixed
        expand_U = torch.zeros(num_unseen, code_length).to(device)
        expand_U[unseen_sample_in_unseen_index, :] = U[unseen_sample_in_sample_index, :]
        new_B = solve_dcc(new_B, U, expand_U, S[:, unseen_dataloader.dataset.UNSEEN_INDEX], code_length, gamma)
        B = torch.cat((old_B, new_B), dim=0).to(device)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma, mu)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(
            it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('[DIHN time:{:.2f}]'.format(time.time() - total_time))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    return mAP

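# generate_code() is used by increment() above and both train() functions but
# is not defined in this section; a minimal sketch, assuming each dataloader
# yields (data, targets, index) triples as in the training loops: forward
# every batch and binarize the network outputs with sign().
import torch

def generate_code(model, dataloader, code_length, device):
    model.eval()
    with torch.no_grad():
        code = torch.zeros(len(dataloader.dataset), code_length)
        for data, _, index in dataloader:
            outputs = model(data.to(device))
            code[index, :] = outputs.sign().cpu()
    model.train()
    return code
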
def train(
    query_dataloader,
    seen_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    gamma,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, seen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma(float): Hyper-parameter.
        topk(int): Compute mAP over the top k retrieved results.

    Returns
        None
    """
    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = ADSH_Loss(code_length, gamma)
    lr_scheduler = ExponentialLR(optimizer, 0.9)

    num_seen = len(seen_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)
    B = torch.randn(num_seen, code_length).sign().to(device)
    seen_targets = seen_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for cnn learning
        train_dataloader, sample_index, _, _ = sample_dataloader(
            seen_dataloader, num_samples, batch_size, root, dataset)

        # Create similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ seen_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefits convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(
            it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('Training adsh finished, time:{:.2f}'.format(time.time() - total_time))

    # Save checkpoints
    torch.save(B.cpu(), os.path.join('checkpoints', 'old_B.t'))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_dataloader.dataset.get_onehot_targets().to(device),
        device,
        topk,
    )
    logger.info('[ADSH map:{:.4f}]'.format(mAP))

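# Runnable toy demonstration of the soft similarity rescaling shared by all
# three training loops above (the labels here are made up). With
# r = S.sum() / (1 - S).sum(), similar pairs stay at +1 while dissimilar
# pairs move from -1 to -(#similar / #dissimilar), shrinking toward 0 when
# dissimilar pairs dominate, which rebalances the two pair types in the loss.
import torch

train_targets = torch.tensor([[1., 0.], [0., 1.], [1., 0.]])                # m = 3 one-hot labels
retrieval_targets = torch.tensor([[1., 0.], [0., 1.], [0., 1.], [0., 1.]])  # n = 4 one-hot labels

S = (train_targets @ retrieval_targets.t() > 0).float()
S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))
r = S.sum() / (1 - S).sum()
S = S * (1 + r) - r
print(S)  # similar entries are 1.0, dissimilar entries are -5/7 ≈ -0.714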