def test(self, heatmap=True): """ Test the model on the held-out test data. """ # load the best checkpoint self.load_checkpoint(best=True) self.model.eval() with torch.no_grad(): losses = [] accuracies = [] confusion_matrix = torch.zeros(self.num_classes, self.num_classes) heat_points = [[[] for i in range(self.num_glimpses)] for j in range(self.num_digits + 1)] heat_mean = torch.zeros(self.num_digits, *self.im_size) var_means = torch.zeros(self.num_glimpses, len(self.test_loader)) loss_means = torch.zeros(self.num_glimpses, len(self.test_loader)) count_hits = torch.zeros(self.num_glimpses, 2) var_dist = torch.zeros(len(self.test_loader.dataset), self.num_glimpses, self.im_size[-1], self.im_size[-1]) err_dist = torch.zeros(len(self.test_loader.dataset), self.num_glimpses, self.im_size[-1], self.im_size[-1]) labels = torch.zeros(len(self.test_loader.dataset)) dist = PairwiseDistance(2) dists = torch.zeros(len(self.test_loader.dataset), self.num_glimpses) run_idx = 0 for i, (_, x, y) in enumerate(self.test_loader): if self.use_gpu: x = x.cuda() y = y.cuda() true = y.detach().argmax(1) # initialize location vector and hidden state self.batch_size = x.shape[0] s_t, _, l_t = self.reset() targets = torch.zeros(self.num_glimpses, 3, *x.shape[2:]) glimpses = torch.zeros(self.num_glimpses, 3, *x.shape[2:]) means = torch.zeros(self.num_glimpses, 3, *x.shape[2:]) vars = torch.zeros(self.num_glimpses, 3, *x.shape[2:]) locs = torch.zeros(self.batch_size, self.num_glimpses, 2) acc_loss = 0 sub_dir = self.plot_dir / f"{i:06d}_dir" if not sub_dir.exists() and i < 10: sub_dir.mkdir(parents=True) for t in range(self.num_glimpses): # forward pass through model locs[:, t] = l_t.clone().detach() s_t, l_pred, r_t, out_t, mu, logvar = self.model( x, l_t, s_t, samples=self.samples, classify=self.classify, ) draw = x[0].expand(3, -1, -1) extent = self.patch_size for patch in range(self.num_patches): draw = draw_glimpse( draw, l_t[0], extent=extent, ) extent *= self.glimpse_scale targets[t] = draw glimpses[t] = self.model.sensor.glimpse[0][0].expand( 3, -1, -1) means[t] = r_t[0][0].expand(3, -1, -1) vars[t] = convert_heatmap( r_t.var(0)[0].squeeze().cpu().numpy()) if t == 0: var_first = r_t.var(0).squeeze().cpu().numpy() var_last = r_t.var(0).squeeze().cpu().numpy() loc_prev = loc = denormalize(x.size(-1), l_t) l_t = get_loc_target(r_t.var(0), max=self.use_amax).type(l_t.type()) loc = denormalize(x.size(-1), l_t) dists[run_idx:run_idx + len(x), t] = dist.forward(loc_prev.float(), loc.float()).detach().cpu() var_dist[run_idx:run_idx + len(x), t] = r_t.var(0).detach().cpu().squeeze() temp_loss = torch.zeros(self.batch_size, *err_dist.shape[2:]) for s in range(self.samples): temp_loss += F.binary_cross_entropy( r_t[s], x, reduction='none').detach().cpu( ).squeeze() / self.samples err_dist[run_idx:run_idx + len(x), t] = temp_loss for k, l in enumerate(loc): if x.squeeze()[k][l[0], l[1]] > 0: count_hits[t, 0] += 1 else: count_hits[t, 1] += 1 imgs = [targets[t], glimpses[t], means[t], vars[t]] filename = sub_dir / f"{i:06d}_glimpse{t:01d}.png" if i < 10: torchvision.utils.save_image( imgs, filename, nrow=1, normalize=True, pad_value=1, padding=1, scale_each=True, ) var = r_t.var(0).reshape(self.batch_size, -1) var_means[t, i] = var.max(-1)[0].mean(0) loss_means[t, i] = temp_loss.mean() for j in range(self.num_digits): heat_points[j][t].extend( denormalize(x.size(-1), l_t[true == j]).tolist()) heat_points[-1][t].extend(loc.tolist()) labels[run_idx:run_idx + len(true)] = true run_idx += len(x) for s in range(self.samples): 
loss = F.mse_loss(r_t[s], x) / self.samples acc_loss += loss.item() # store if self.classify: if self.num_classes == 1: pred = out_t.detach().squeeze().round() target = (y.argmax(1) == self.target_class).float() else: pred = out_t.detach().argmax(1) target = true for p, tr in zip(pred.view(-1), target.view(-1)): confusion_matrix[tr.long(), p.long()] += 1 accuracies.append( torch.sum(pred == target).float().item() / len(y)) for j in range(self.num_digits): if len(x[true == j]): heat_mean[j] += x[true == j].mean(0).cpu() losses.append(acc_loss) imgs = [targets, glimpses, means, vars] for im in imgs: im = (im - im.min()) / (im.max() - im.min()) # store images + reconstructions of largest scale filename = self.plot_dir / f"{i:06d}.png" if i < 10: torchvision.utils.save_image( torch.cat(imgs, 0), filename, nrow=6, normalize=False, pad_value=1, padding=3, scale_each=False, ) pkl.dump(dists, open(self.file_dir / 'dists.pkl', 'wb')) pkl.dump(var_dist, open(self.file_dir / 'var_dist.pkl', 'wb')) pkl.dump(err_dist, open(self.file_dir / 'err_dist.pkl', 'wb')) pkl.dump(heat_points, open(self.file_dir / 'locs.pkl', 'wb')) pkl.dump(labels, open(self.file_dir / 'labels.pkl', 'wb')) pkl.dump(heat_mean, open(self.file_dir / 'mean.pkl', 'wb')) print(count_hits[:, 0] / count_hits.sum(1)) sn.set(font="serif", font_scale=2, context='paper', style='dark', rc={"lines.linewidth": 2.5}) errs = err_dist.view(len(self.test_loader.dataset), self.num_glimpses, -1).cpu().numpy() vars = var_dist.view(len(self.test_loader.dataset), self.num_glimpses, -1).cpu().numpy() f, ax = plt.subplots(2, 1, sharex=True, figsize=(9, 12)) ax[0].plot(range(self.num_glimpses), errs.mean((0, -1)), marker='o', markersize=10, c=PLOT_COLOR) ax[0].set_ylabel('Prediction Error') ax[1].plot(range(self.num_glimpses), vars.max(-1).mean(0), marker='o', markersize=10, c=PLOT_COLOR) ax[1].set_xlabel('Saccade number') ax[1].set_ylabel('Uncertainty') f.tight_layout() f.savefig(self.plot_dir / f"comb_var_err.pdf", bbox_inches='tight', pad_inches=0, dpi=600) print("#######################") if self.classify: print(confusion_matrix) print(confusion_matrix.diag() / confusion_matrix.sum(1)) plot_confusion_matrix(confusion_matrix, self.plot_dir / f"confusion_matrix.png") pkl.dump(confusion_matrix, open(self.file_dir / 'conf_matrix.pkl', 'wb')) #create heatmaps flatten = lambda l, i: [ item for sublist in l for item in sublist[i] ] for i in range(self.num_digits): img = array2img(heat_mean[i]) points = np.array(heat_points[i]) first = points[:3].reshape(-1, 2) heatmapper = Heatmapper( point_diameter=1, # the size of each point to be drawn point_strength= 0.3, # the strength, between 0 and 1, of each point to be drawn opacity=0.85, ) heatmap = heatmapper.heatmap_on_img(first, img) heatmap.save(self.plot_dir / f"heatmap_bef{i}.png") last = points[3:-1].reshape(-1, 2) heatmap = heatmapper.heatmap_on_img(last, img) heatmap.save(self.plot_dir / f"heatmap_aft{i}.png") for j in range(self.num_glimpses): heatmap = heatmapper.heatmap_on_img(heat_points[i][j], img) heatmap.save(self.plot_dir / f"heatmap_class{i}_glimpse{j}.png") self.logger.info( f"[*] Test loss: {np.mean(losses)}, Test accuracy: {np.mean(accuracies)}" ) return np.mean(losses)
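# A hedged sketch of the coordinate conversion that denormalize() above is assumed
# to perform: glimpse locations in [-1, 1] are mapped to pixel indices in
# [0, image_size). The helper name and exact rounding are illustrative assumptions,
# not taken from the original source.
import torch

def denormalize_sketch(image_size, coords):
    # coords: tensor of shape (batch, 2) with values in [-1, 1]
    pixel = 0.5 * (coords + 1.0) * image_size
    return pixel.long().clamp(0, image_size - 1)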
      end='')
#f, axarr = plt.subplots(1,3)
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

anc_fv = model(data['anc_img'].cuda())
pos_fv = model(data['pos_img'].cuda())
neg_fv = model(data['neg_img'].cuda())
#print("Batch Size :", anc_fv.shape[0])

pos_dis = l2_distance.forward(anc_fv, pos_fv)
neg_dis = l2_distance.forward(anc_fv, neg_fv)

all = (pos_dis < neg_dis).cpu().numpy().flatten()
losses = torch.clamp(pos_dis[all] - neg_dis[all] + margin, min=0.0)
loss = torch.mean(losses)
#print(losses)

print(" Len :", len(pos_dis[all]), end='')
print(" Loss :", loss.item())

optimizer.zero_grad()
loss.backward()
optimizer.step()
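# Minimal, self-contained sketch of the clamped margin loss computed in the loop
# above: pairs where the positive is already closer than the negative are selected
# with a boolean mask, and max(d(a,p) - d(a,n) + margin, 0) is averaged over them.
# The function name is illustrative; it is not part of the original script.
import torch

def masked_margin_loss(pos_dis, neg_dis, margin=0.2):
    mask = (pos_dis < neg_dis)
    losses = torch.clamp(pos_dis[mask] - neg_dis[mask] + margin, min=0.0)
    # guard against an empty selection, which would make mean() return NaN
    if losses.numel() == 0:
        return torch.zeros((), device=pos_dis.device)
    return losses.mean()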
def main(): dataroot = args.dataroot apd_dataroot = args.apd dataset_csv = args.dataset_csv apd_batch_size = args.apd_batch_size apd_validation_epoch_interval = args.apd_validation_epoch_interval model_architecture = args.model epochs = args.epochs training_triplets_path = args.training_triplets_path num_triplets_train = args.num_triplets_train resume_path = args.resume_path batch_size = args.batch_size num_workers = args.num_workers embedding_dimension = args.embedding_dim pretrained = args.pretrained optimizer = args.optimizer learning_rate = args.lr margin = args.margin start_epoch = 0 # Define image data pre-processing transforms # ToTensor() normalizes pixel values between [0, 1] # Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) normalizes pixel values between [-1, 1] # Size 182x182 RGB image -> Center crop size 160x160 RGB image for more model generalization data_transforms = transforms.Compose([ #transforms.RandomCrop(size=10), transforms.Resize(size=(50, 50)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # Size 160x160 RGB image apd_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) logging.info('prepare VGGNet dataloader...') # Set dataloaders train_dataloader = torch.utils.data.DataLoader(dataset=TripletFaceDataset( root_dir=dataroot, csv_name=dataset_csv, num_triplets=num_triplets_train, training_triplets_path=training_triplets_path, transform=data_transforms), batch_size=batch_size, num_workers=num_workers, shuffle=False) logging.info('prepare APD dataset') apd_dataroot = '../APD' apdDataset = APDDataset( directory=apd_dataroot, #pairs_path='negative_pairs.txt', #pairs_path='positive_pairs.txt', pairs_path='full.txt', transform=apd_transforms) print('apdDataset', apdDataset) logging.info('prepare apd loader') apd_dataloader = torch.utils.data.DataLoader(dataset=apdDataset, batch_size=apd_batch_size, num_workers=num_workers, shuffle=True) print('apd_dataloader', apd_dataloader) logging.info('prepare model') # Instantiate model if model_architecture == "resnet18": model = Resnet18Triplet(embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "resnet34": model = Resnet34Triplet(embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "resnet50": model = Resnet50Triplet(embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "resnet101": model = Resnet101Triplet(embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "inceptionresnetv2": model = InceptionResnetV2Triplet( embedding_dimension=embedding_dimension, pretrained=pretrained) print("Using {} model architecture.".format(model_architecture)) # Load model to GPU or multiple GPUs if available flag_train_gpu = torch.cuda.is_available() flag_train_multi_gpu = False if flag_train_gpu and torch.cuda.device_count() > 1: model = nn.DataParallel(model) model.cuda() flag_train_multi_gpu = True print('Using multi-gpu training.') elif flag_train_gpu and torch.cuda.device_count() == 1: model.cuda() print('Using single-gpu training.') logging.info('set optimizer') # Set optimizers if optimizer == "sgd": optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate) elif optimizer == "adagrad": optimizer_model = torch.optim.Adagrad(model.parameters(), lr=learning_rate) elif optimizer == "rmsprop": optimizer_model = 
torch.optim.RMSprop(model.parameters(), lr=learning_rate) elif optimizer == "adam": optimizer_model = torch.optim.Adam(model.parameters(), lr=learning_rate) # Optionally resume from a checkpoint if resume_path: if os.path.isfile(resume_path): print("\nLoading checkpoint {} ...".format(resume_path)) checkpoint = torch.load(resume_path) start_epoch = checkpoint['epoch'] # In order to load state dict for optimizers correctly, model has to be loaded to gpu first if flag_train_multi_gpu: model.module.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint['model_state_dict']) optimizer_model.load_state_dict( checkpoint['optimizer_model_state_dict']) print( "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n" .format(start_epoch, epochs - start_epoch)) else: print( "WARNING: No checkpoint found at {}!\nTraining from scratch.". format(resume_path)) # Start Training loop print( "\nTraining using triplet loss on {} triplets starting for {} epochs:\n" .format(num_triplets_train, epochs - start_epoch)) total_time_start = time.time() start_epoch = start_epoch end_epoch = start_epoch + epochs l2_distance = PairwiseDistance(2).cuda() BATCH_NUM = len(train_dataloader.dataset) / batch_size for epoch in range(start_epoch, end_epoch): epoch_time_start = time.time() #flag_validate_apd = (epoch + 1) % lfw_validation_epoch_interval == 0 or (epoch + 1) % epochs == 0 flag_validate_apd = True triplet_loss_sum = 0 num_valid_training_triplets = 0 # Training pass model.train() #progress_bar = enumerate(tqdm(train_dataloader)) progress_bar = enumerate(train_dataloader) for batch_idx, (batch_sample) in progress_bar: #break logging.info("epoch:{}/{} batch_idx:{}/{}".format( epoch, end_epoch, batch_idx, BATCH_NUM)) #print('batch_idx',batch_idx) anc_img = batch_sample['anc_img'].cuda() pos_img = batch_sample['pos_img'].cuda() neg_img = batch_sample['neg_img'].cuda() # Forward pass - compute embeddings anc_embedding, pos_embedding, neg_embedding = model( anc_img), model(pos_img), model(neg_img) # Forward pass - choose hard negatives only for training pos_dist = l2_distance.forward(anc_embedding, pos_embedding) neg_dist = l2_distance.forward(anc_embedding, neg_embedding) all = (neg_dist - pos_dist < margin).cpu().numpy().flatten() hard_triplets = np.where(all == 1) if len(hard_triplets[0]) == 0: continue anc_hard_embedding = anc_embedding[hard_triplets].cuda() pos_hard_embedding = pos_embedding[hard_triplets].cuda() neg_hard_embedding = neg_embedding[hard_triplets].cuda() # Calculate triplet loss triplet_loss = TripletLoss(margin=margin).forward( anchor=anc_hard_embedding, positive=pos_hard_embedding, negative=neg_hard_embedding).cuda() # Calculating loss triplet_loss_sum += triplet_loss.item() num_valid_training_triplets += len(anc_hard_embedding) # Backward pass optimizer_model.zero_grad() triplet_loss.backward() optimizer_model.step() #if batch_idx == 20: # break # Model only trains on hard negative triplets avg_triplet_loss = 0 if ( num_valid_training_triplets == 0) else triplet_loss_sum / num_valid_training_triplets epoch_time_end = time.time() # Print training statistics and add to log print( 'Epoch {}:\tAverage Triplet Loss: {:.4f}\tEpoch Time: {:.3f} hours\tNumber of valid training triplets in epoch: {}' .format(epoch + 1, avg_triplet_loss, (epoch_time_end - epoch_time_start) / 3600, num_valid_training_triplets)) with open('logs/{}_log_triplet.txt'.format(model_architecture), 'a') as f: val_list = [ epoch + 1, avg_triplet_loss, num_valid_training_triplets 
] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') try: # Plot Triplet losses plot plot_triplet_losses( log_dir="logs/{}_log_triplet.txt".format(model_architecture), epochs=epochs, figure_name="plots/triplet_losses_{}.png".format( model_architecture)) except Exception as e: print(e) # Evaluation pass on LFW dataset if flag_validate_apd: model.eval() with torch.no_grad(): distances, labels = [], [] print("Validating on APD! ...") progress_bar = enumerate(tqdm(apd_dataloader)) for batch_index, (data_a, data_b, label) in progress_bar: data_a, data_b, label = data_a.cuda(), data_b.cuda( ), label.cuda() #print('data_a', data_a.shape) #print('data_b', data_b.shape) #print('label', label) output_a, output_b = model(data_a), model(data_b) #print('output_a',output_a) #print('output_b',output_b) distance = l2_distance.forward( output_a, output_b) # Euclidean distance #print('distance',distance) distances.append(distance.cpu().detach().numpy()) labels.append(label.cpu().detach().numpy()) #if batch_index == 20: # break labels = np.array( [sublabel for label in labels for sublabel in label]) distances = np.array([ subdist for distance in distances for subdist in distance ]) print('len(labels)', len(labels)) print('len(distances)', len(distances)) true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \ tar, far = evaluate_lfw( distances=distances, labels=labels ) # Print statistics and add to log print( "Accuracy on APD: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\tROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}" .format(np.mean(accuracy), np.std(accuracy), np.mean(precision), np.std(precision), np.mean(recall), np.std(recall), roc_auc, np.mean(best_distances), np.std(best_distances), np.mean(tar), np.std(tar), np.mean(far))) with open( 'logs/lfw_{}_log_triplet.txt'.format( model_architecture), 'a') as f: val_list = [ epoch + 1, np.mean(accuracy), np.std(accuracy), np.mean(precision), np.std(precision), np.mean(recall), np.std(recall), roc_auc, np.mean(best_distances), np.std(best_distances), np.mean(tar) ] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') try: # Plot ROC curve plot_roc_lfw( false_positive_rate=false_positive_rate, true_positive_rate=true_positive_rate, figure_name= "plots/roc_plots_triplet/roc_{}_epoch_{}_triplet.png". 
format(model_architecture, epoch + 1), epochNum=epoch) # Plot LFW accuracies plot plot_accuracy_lfw( log_dir="logs/lfw_{}_log_triplet.txt".format( model_architecture), epochs=epochs, figure_name="plots/apd_accuracies_{}_triplet.png".format( model_architecture)) except Exception as e: print(e) # Save model checkpoint state = { 'epoch': epoch + 1, 'embedding_dimension': embedding_dimension, 'batch_size_training': batch_size, 'model_state_dict': model.state_dict(), 'model_architecture': model_architecture, 'optimizer_model_state_dict': optimizer_model.state_dict() } # For storing data parallel model's state dictionary without 'module' parameter if flag_train_multi_gpu: state['model_state_dict'] = model.module.state_dict() # For storing best euclidean distance threshold during LFW validation if flag_validate_apd: state['best_distance_threshold'] = np.mean(best_distances) # Save model checkpoint torch.save( state, 'Model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format( model_architecture, epoch + 1)) # Training loop end total_time_end = time.time() total_time_elapsed = total_time_end - total_time_start print("\nTraining finished: total time elapsed: {:.2f} hours.".format( total_time_elapsed / 3600))
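# Hedged sketch of the hard-negative triplet selection used in the training loop
# above: a triplet is kept only while it still violates the margin, i.e.
# d(a, n) - d(a, p) < margin. Argument and function names are illustrative.
import numpy as np
import torch

def select_hard_triplets(anc_emb, pos_emb, neg_emb, margin):
    pos_dist = torch.norm(anc_emb - pos_emb, p=2, dim=1)
    neg_dist = torch.norm(anc_emb - neg_emb, p=2, dim=1)
    violating = (neg_dist - pos_dist < margin).cpu().numpy().flatten()
    idx = np.where(violating == 1)
    return anc_emb[idx], pos_emb[idx], neg_emb[idx]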
def main():
    lfw_dataroot = args.lfw
    model_path = args.model_path
    far_target = args.far_target

    flag_gpu_available = torch.cuda.is_available()
    if flag_gpu_available:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # map_location lets the checkpoint load on CPU-only machines as well
    checkpoint = torch.load(model_path, map_location=device)
    model = Resnet18Triplet(
        embedding_dimension=checkpoint['embedding_dimension'])
    model.load_state_dict(checkpoint['model_state_dict'])

    lfw_transforms = transforms.Compose([
        transforms.Resize(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6068, 0.4517, 0.3800],
                             std=[0.2492, 0.2173, 0.2082])
    ])

    lfw_dataloader = torch.utils.data.DataLoader(
        dataset=LFWDataset(dir=lfw_dataroot,
                           pairs_path='../datasets/LFW_pairs.txt',
                           transform=lfw_transforms),
        batch_size=256,
        num_workers=2,
        shuffle=False)

    model.to(device)
    model = model.eval()

    with torch.no_grad():
        l2_distance = PairwiseDistance(p=2)
        distances, labels = [], []

        progress_bar = enumerate(tqdm(lfw_dataloader))
        for batch_index, (data_a, data_b, label) in progress_bar:
            data_a = data_a.to(device)
            data_b = data_b.to(device)

            output_a, output_b = model(data_a), model(data_b)
            distance = l2_distance.forward(output_a, output_b)  # Euclidean distance

            distances.append(distance.cpu().detach().numpy())
            labels.append(label.cpu().detach().numpy())

        labels = np.array([sublabel for label in labels for sublabel in label])
        distances = np.array(
            [subdist for distance in distances for subdist in distance])

        _, _, _, _, _, _, _, tar, far = evaluate_lfw(distances=distances,
                                                     labels=labels,
                                                     far_target=far_target)

        print("TAR: {:.4f}+-{:.4f} @ FAR: {:.4f}".format(
            np.mean(tar), np.std(tar), np.mean(far)))
def main(): lfw_dataroot = args.lfw model_path = args.model_path far_target = args.far_target batch_size = args.batch_size flag_gpu_available = torch.cuda.is_available() if flag_gpu_available: device = torch.device("cuda") print('Using GPU') else: device = torch.device("cpu") print('Using CPU') checkpoint = torch.load(model_path, map_location=device) model = Resnet18Triplet( embedding_dimension=checkpoint['embedding_dimension']) model.load_state_dict(checkpoint['model_state_dict']) # desiredFaceWidth, height and desiredLeftEye work together to produce centered aligned faces # desiredLeftEye (x,y) says how the face moves in the heightxwidth window. bigger y moves the face down # bigger x zooms-out the face and severe values migh flip the face upside down, you wan it in range [0.2;0.36] desiredFaceHeight = 448 # set non-equal width, height so the tf_resize stretch the faces, select bigger than 224 values so we don't lose too much of the image quality desiredFaceWidth = 352 desiredLeftEye = (0.28, 0.35) tf_align = transform_align( landmark_predictor_weight_path=landmark_predictor_weights, face_detector_path=face_detector_weights_path, desiredFaceWidth=desiredFaceWidth, desiredFaceHeight=desiredFaceHeight, desiredLeftEye=desiredLeftEye) lfw_transforms = transforms.Compose([ tf_align, transforms.Resize(size=(140, 140)), transforms.ToTensor(), transforms.Normalize(mean=[0.6071, 0.4609, 0.3944], std=[0.2457, 0.2175, 0.2129]) ]) lfw_dataloader = torch.utils.data.DataLoader( dataset=LFWDataset(dir=lfw_dataroot, pairs_path='../datasets/LFW_pairs.txt', transform=lfw_transforms), batch_size= batch_size, # default = 256; 160 - allows running under 2GB VRAM num_workers=2, shuffle=False) model.to(device) model = model.eval() with torch.no_grad(): l2_distance = PairwiseDistance(p=2) distances, labels = [], [] progress_bar = enumerate(tqdm(lfw_dataloader)) for batch_index, (data_a, data_b, label) in progress_bar: data_a = data_a.to(device) # data_a = data_a.cuda() data_b = data_b.to(device) # data_b = data_b.cuda() output_a, output_b = model(data_a), model(data_b) distance = l2_distance.forward(output_a, output_b) # Euclidean distance distances.append(distance.cpu().detach().numpy()) labels.append(label.cpu().detach().numpy()) labels = np.array([sublabel for label in labels for sublabel in label]) distances = np.array( [subdist for distance in distances for subdist in distance]) _, _, _, _, _, _, _, tar, far = evaluate_lfw(distances=distances, labels=labels, far_target=far_target) print("TAR: {:.4f}+-{:.4f} @ FAR: {:.4f}".format( np.mean(tar), np.std(tar), np.mean(far)))
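# Illustrative sketch (not the project's evaluate_lfw implementation) of how the
# TAR and FAR reported above can be computed at one fixed distance threshold from
# the flattened distance/label arrays; label == 1 marks a genuine (same-identity) pair.
import numpy as np

def tar_far_at_threshold(distances, labels, threshold):
    predict_same = distances < threshold
    actual_same = labels == 1
    tar = np.sum(predict_same & actual_same) / max(np.sum(actual_same), 1)
    far = np.sum(predict_same & ~actual_same) / max(np.sum(~actual_same), 1)
    return tar, far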
    print('No pretrained model found!')

l2_distance = PairwiseDistance(2)

with torch.no_grad():  # no gradients needed during evaluation
    distances, labels = [], []
    progress_bar = enumerate(tqdm(test_dataloader))
    for batch_index, (data_a, data_b, label) in progress_bar:
        #for batch_index, (data_a, data_b, label) in enumerate(test_dataloader):
        # data_a, data_b and label each hold one batch of tensors
        data_a = data_a.to(device)
        data_b = data_b.to(device)
        label = label.to(device)

        output_a, output_b = model(data_a), model(data_b)
        # NOTE: torch.norm() here is the norm of the whole batch tensor,
        # not a per-embedding L2 norm
        output_a = torch.div(output_a, torch.norm(output_a))
        output_b = torch.div(output_b, torch.norm(output_b))
        distance = l2_distance.forward(output_a, output_b)

        # lists of per-batch arrays
        labels.append(label.cpu().detach().numpy())
        distances.append(distance.cpu().detach().numpy())
        #if batch_index >= 3:
        #    break

    print("get all image's distance done")

    labels = np.array([sublabel for label in labels for sublabel in label])
    distances = np.array(
        [subdist for distance in distances for subdist in distance])

    true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \
        tar, far = evaluate_lfw(
            distances=distances,
            labels=labels,
            epoch='',
def main(): dataroot = args.dataroot lfw_dataroot = args.lfw dataset_csv = args.dataset_csv epochs = args.epochs iterations_per_epoch = args.iterations_per_epoch model_architecture = args.model_architecture pretrained = args.pretrained embedding_dimension = args.embedding_dimension num_human_identities_per_batch = args.num_human_identities_per_batch batch_size = args.batch_size lfw_batch_size = args.lfw_batch_size num_generate_triplets_processes = args.num_generate_triplets_processes resume_path = args.resume_path num_workers = args.num_workers optimizer = args.optimizer learning_rate = args.learning_rate margin = args.margin image_size = args.image_size use_semihard_negatives = args.use_semihard_negatives training_triplets_path = args.training_triplets_path flag_training_triplets_path = False start_epoch = 0 if training_triplets_path is not None: flag_training_triplets_path = True # Load triplets file for the first training epoch # Define image data pre-processing transforms # ToTensor() normalizes pixel values between [0, 1] # Normalize(mean=[0.6068, 0.4517, 0.3800], std=[0.2492, 0.2173, 0.2082]) normalizes pixel values to be mean # of zero and standard deviation of 1 according to the calculated VGGFace2 with tightly-cropped faces # dataset RGB channels' mean and std values by calculate_vggface2_rgb_mean_std.py in 'datasets' folder. data_transforms = transforms.Compose([ transforms.Resize(size=image_size), transforms.RandomHorizontalFlip(), transforms.RandomRotation(degrees=5), transforms.ToTensor(), transforms.Normalize(mean=[0.6068, 0.4517, 0.3800], std=[0.2492, 0.2173, 0.2082]) ]) lfw_transforms = transforms.Compose([ transforms.Resize(size=image_size), transforms.ToTensor(), transforms.Normalize(mean=[0.6068, 0.4517, 0.3800], std=[0.2492, 0.2173, 0.2082]) ]) lfw_dataloader = torch.utils.data.DataLoader(dataset=LFWDataset( dir=lfw_dataroot, pairs_path='datasets/LFW_pairs.txt', transform=lfw_transforms), batch_size=lfw_batch_size, num_workers=num_workers, shuffle=False) # Instantiate model model = set_model_architecture(model_architecture=model_architecture, pretrained=pretrained, embedding_dimension=embedding_dimension) # Load model to GPU or multiple GPUs if available model, flag_train_multi_gpu = set_model_gpu_mode(model) # Set optimizer optimizer_model = set_optimizer(optimizer=optimizer, model=model, learning_rate=learning_rate) # Resume from a model checkpoint if resume_path: if os.path.isfile(resume_path): print("Loading checkpoint {} ...".format(resume_path)) checkpoint = torch.load(resume_path) start_epoch = checkpoint['epoch'] + 1 optimizer_model.load_state_dict( checkpoint['optimizer_model_state_dict']) # In order to load state dict for optimizers correctly, model has to be loaded to gpu first if flag_train_multi_gpu: model.module.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint['model_state_dict']) print("Checkpoint loaded: start epoch from checkpoint = {}".format( start_epoch)) else: print( "WARNING: No checkpoint found at {}!\nTraining from scratch.". 
format(resume_path)) if use_semihard_negatives: print("Using Semi-Hard negative triplet selection!") else: print("Using Hard negative triplet selection!") start_epoch = start_epoch print("Training using triplet loss starting for {} epochs:\n".format( epochs - start_epoch)) for epoch in range(start_epoch, epochs): num_valid_training_triplets = 0 l2_distance = PairwiseDistance(p=2) _training_triplets_path = None if flag_training_triplets_path: _training_triplets_path = training_triplets_path flag_training_triplets_path = False # Only load triplets file for the first epoch # Re-instantiate training dataloader to generate a triplet list for this training epoch train_dataloader = torch.utils.data.DataLoader( dataset=TripletFaceDataset( root_dir=dataroot, csv_name=dataset_csv, num_triplets=iterations_per_epoch * batch_size, num_generate_triplets_processes=num_generate_triplets_processes, num_human_identities_per_batch=num_human_identities_per_batch, triplet_batch_size=batch_size, epoch=epoch, training_triplets_path=_training_triplets_path, transform=data_transforms), batch_size=batch_size, num_workers=num_workers, shuffle= False # Shuffling for triplets with set amount of human identities per batch is not required ) # Training pass model.train() progress_bar = enumerate(tqdm(train_dataloader)) for batch_idx, (batch_sample) in progress_bar: # Forward pass - compute embeddings anc_imgs = batch_sample['anc_img'] pos_imgs = batch_sample['pos_img'] neg_imgs = batch_sample['neg_img'] # Concatenate the input images into one tensor because doing multiple forward passes would create # weird GPU memory allocation behaviours later on during training which would cause GPU Out of Memory # issues all_imgs = torch.cat( (anc_imgs, pos_imgs, neg_imgs)) # Must be a tuple of Torch Tensors anc_embeddings, pos_embeddings, neg_embeddings, model = forward_pass( imgs=all_imgs, model=model, batch_size=batch_size) pos_dists = l2_distance.forward(anc_embeddings, pos_embeddings) neg_dists = l2_distance.forward(anc_embeddings, neg_embeddings) if use_semihard_negatives: # Semi-Hard Negative triplet selection # (negative_distance - positive_distance < margin) AND (positive_distance < negative_distance) # Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L295 first_condition = (neg_dists - pos_dists < margin).cpu().numpy().flatten() second_condition = (pos_dists < neg_dists).cpu().numpy().flatten() all = (np.logical_and(first_condition, second_condition)) valid_triplets = np.where(all == 1) else: # Hard Negative triplet selection # (negative_distance - positive_distance < margin) # Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L296 all = (neg_dists - pos_dists < margin).cpu().numpy().flatten() valid_triplets = np.where(all == 1) anc_valid_embeddings = anc_embeddings[valid_triplets] pos_valid_embeddings = pos_embeddings[valid_triplets] neg_valid_embeddings = neg_embeddings[valid_triplets] del anc_embeddings, pos_embeddings, neg_embeddings, pos_dists, neg_dists gc.collect() # Calculate triplet loss triplet_loss = TripletLoss(margin=margin).forward( anchor=anc_valid_embeddings, positive=pos_valid_embeddings, negative=neg_valid_embeddings) # Calculating number of triplets that met the triplet selection method during the epoch num_valid_training_triplets += len(anc_valid_embeddings) # Backward pass optimizer_model.zero_grad() triplet_loss.backward() optimizer_model.step() # Clear some memory at end of training iteration del triplet_loss, 
anc_valid_embeddings, pos_valid_embeddings, neg_valid_embeddings gc.collect() # Print training statistics for epoch and add to log print( 'Epoch {}:\tNumber of valid training triplets in epoch: {}'.format( epoch, num_valid_training_triplets)) with open('logs/{}_log_triplet.txt'.format(model_architecture), 'a') as f: val_list = [epoch, num_valid_training_triplets] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') # Evaluation pass on LFW dataset best_distances = validate_lfw(model=model, lfw_dataloader=lfw_dataloader, model_architecture=model_architecture, epoch=epoch, epochs=epochs) # Save model checkpoint state = { 'epoch': epoch, 'embedding_dimension': embedding_dimension, 'batch_size_training': batch_size, 'model_state_dict': model.state_dict(), 'model_architecture': model_architecture, 'optimizer_model_state_dict': optimizer_model.state_dict(), 'best_distance_threshold': np.mean(best_distances) } # For storing data parallel model's state dictionary without 'module' parameter if flag_train_multi_gpu: state['model_state_dict'] = model.module.state_dict() # Save model checkpoint torch.save( state, 'model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format( model_architecture, epoch))
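# Hedged sketch of the semi-hard negative condition applied above: keep a triplet
# when the positive is closer than the negative but the negative still lies inside
# the margin, i.e. d(a,p) < d(a,n) and d(a,n) - d(a,p) < margin. Names are illustrative.
import numpy as np

def semihard_indices(pos_dists, neg_dists, margin):
    pos = pos_dists.detach().cpu().numpy().flatten()
    neg = neg_dists.detach().cpu().numpy().flatten()
    keep = np.logical_and(neg - pos < margin, pos < neg)
    return np.where(keep)[0]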
for epoch in range(0, 5):
    triplet_loss_sum = 0
    num_valid_training_triplets = 0

    progress_bar = enumerate(tqdm(train_dataloader))
    for batch_idx, (batch_sample) in progress_bar:
        anc_image = batch_sample['anc_img'].view(-1, 3, 220, 220).cuda()
        pos_image = batch_sample['pos_img'].view(-1, 3, 220, 220).cuda()
        neg_image = batch_sample['neg_img'].view(-1, 3, 220, 220).cuda()

        anc_embedding = net(anc_image)
        pos_embedding = net(pos_image)
        neg_embedding = net(neg_image)

        pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
        neg_dist = l2_distance.forward(anc_embedding, neg_embedding)
        print(pos_dist)
        print(neg_dist)
        print('\n')

        allcd = (neg_dist - pos_dist < margin).cpu().numpy().flatten()
        hard_triplets = np.where(allcd == 1)

        anc_hard_embedding = anc_embedding[hard_triplets].cuda()
        pos_hard_embedding = pos_embedding[hard_triplets].cuda()
        neg_hard_embedding = neg_embedding[hard_triplets].cuda()

        triplet_loss = TripletLoss(margin=margin).forward(
            anchor=anc_hard_embedding,
            positive=pos_hard_embedding,
            negative=neg_hard_embedding
def validate_lfw(model, lfw_dataloader, model_architecture, epoch, epochs): model.eval() with torch.no_grad(): l2_distance = PairwiseDistance(p=2) distances, labels = [], [] print("Validating on LFW! ...") progress_bar = enumerate(tqdm(lfw_dataloader)) for batch_index, (data_a, data_b, label) in progress_bar: data_a = data_a.cuda() data_b = data_b.cuda() output_a, output_b = model(data_a), model(data_b) distance = l2_distance.forward(output_a, output_b) # Euclidean distance distances.append(distance.cpu().detach().numpy()) labels.append(label.cpu().detach().numpy()) labels = np.array([sublabel for label in labels for sublabel in label]) distances = np.array( [subdist for distance in distances for subdist in distance]) true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \ tar, far = evaluate_lfw( distances=distances, labels=labels, far_target=1e-3 ) # Print statistics and add to log print( "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\t" "ROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\t" "TAR: {:.4f}+-{:.4f} @ FAR: {:.4f}".format(np.mean(accuracy), np.std(accuracy), np.mean(precision), np.std(precision), np.mean(recall), np.std(recall), roc_auc, np.mean(best_distances), np.std(best_distances), np.mean(tar), np.std(tar), np.mean(far))) with open('logs/lfw_{}_log_triplet.txt'.format(model_architecture), 'a') as f: val_list = [ epoch, np.mean(accuracy), np.std(accuracy), np.mean(precision), np.std(precision), np.mean(recall), np.std(recall), roc_auc, np.mean(best_distances), np.std(best_distances), np.mean(tar) ] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') try: # Plot ROC curve plot_roc_lfw( false_positive_rate=false_positive_rate, true_positive_rate=true_positive_rate, figure_name="plots/roc_plots/roc_{}_epoch_{}_triplet.png".format( model_architecture, epoch)) # Plot LFW accuracies plot plot_accuracy_lfw( log_dir="logs/lfw_{}_log_triplet.txt".format(model_architecture), epochs=epochs, figure_name="plots/lfw_accuracies_{}_triplet.png".format( model_architecture)) except Exception as e: print(e) return best_distances
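# The validation loop above accumulates one NumPy array per batch and then flattens
# the lists with nested comprehensions; np.concatenate is an equivalent, more direct
# way to do it (toy values shown, variable names illustrative).
import numpy as np

batch_distances = [np.array([0.42, 1.30]), np.array([0.77])]
batch_labels = [np.array([1, 0]), np.array([1])]

distances = np.concatenate(batch_distances)  # array([0.42, 1.3 , 0.77])
labels = np.concatenate(batch_labels)        # array([1, 0, 1])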
def main(): dataroot = args.dataroot lfw_dataroot = args.lfw dataset_csv = args.dataset_csv lfw_batch_size = args.lfw_batch_size lfw_validation_epoch_interval = args.lfw_validation_epoch_interval model_architecture = args.model epochs = args.epochs num_triplets_train = args.num_triplets_train resume_path = args.resume_path batch_size = args.batch_size num_workers = args.num_workers embedding_dimension = args.embedding_dim pretrained = args.pretrained learning_rate = args.lr margin = args.margin start_epoch = 0 # Define image data pre-processing transforms # ToTensor() normalizes pixel values between [0, 1] # Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) normalizes pixel values between [-1, 1] # Size 182x182 RGB image -> Center crop size 160x160 RGB image for more model generalization data_transforms = transforms.Compose([ transforms.RandomCrop(size=160), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # Size 160x160 RGB image lfw_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # Set dataloaders train_dataloader = torch.utils.data.DataLoader(TripletFaceDataset( root_dir=dataroot, csv_name=dataset_csv, num_triplets=num_triplets_train, transform=data_transforms), batch_size=batch_size, num_workers=num_workers, shuffle=False) lfw_dataloader = torch.utils.data.DataLoader(LFWDataset( dir=lfw_dataroot, pairs_path='datasets/LFW_pairs.txt', transform=lfw_transforms), batch_size=lfw_batch_size, num_workers=num_workers, shuffle=False) # Instantiate model if model_architecture == "resnet34": model = Resnet34Triplet(embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "resnet50": model = Resnet50Triplet(embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "resnet101": model = Resnet101Triplet(embedding_dimension=embedding_dimension, pretrained=pretrained) print("Using {} model architecture.".format(model_architecture)) # Load model to GPU or multiple GPUs if available flag_train_gpu = torch.cuda.is_available() flag_train_multi_gpu = False if flag_train_gpu and torch.cuda.device_count() > 1: model = nn.DataParallel(model) model.cuda() flag_train_multi_gpu = True print('Using multi-gpu training.') elif flag_train_gpu and torch.cuda.device_count() == 1: model.cuda() print('Using single-gpu training.') # Set optimizers optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate) # Optionally resume from a checkpoint if resume_path: if os.path.isfile(resume_path): print("\nLoading checkpoint {} ...".format(resume_path)) checkpoint = torch.load(resume_path) start_epoch = checkpoint['epoch'] # In order to load state dict for optimizers correctly, model has to be loaded to gpu first if flag_train_multi_gpu: model.module.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint['model_state_dict']) optimizer_model.load_state_dict( checkpoint['optimizer_model_state_dict']) print( "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n" .format(start_epoch, epochs - start_epoch)) else: print( "WARNING: No checkpoint found at {}!\nTraining from scratch.". 
format(resume_path)) # Training loop print( "\nTraining using triplet loss on {} triplets starting for {} epochs:\n" .format(num_triplets_train, epochs - start_epoch)) total_time_start = time.time() start_epoch = start_epoch end_epoch = start_epoch + epochs l2_distance = PairwiseDistance(2).cuda() for epoch in range(start_epoch, end_epoch): flag_validate_lfw = (epoch + 1) % lfw_validation_epoch_interval == 0 or ( epoch + 1) % epochs == 0 triplet_loss_sum = 0 epoch_time_start = time.time() # Training the model model.train() progress_bar = tqdm(enumerate(train_dataloader)) for batch_idx, (batch_sample) in progress_bar: anc_img = batch_sample['anc_img'].cuda() pos_img = batch_sample['pos_img'].cuda() neg_img = batch_sample['neg_img'].cuda() # Forward pass - compute embeddings anc_embedding, pos_embedding, neg_embedding = model( anc_img), model(pos_img), model(neg_img) # Forward pass - choose hard negatives only for training pos_dist = l2_distance.forward(anc_embedding, pos_embedding) neg_dist = l2_distance.forward(anc_embedding, neg_embedding) all = (neg_dist - pos_dist < margin).cpu().numpy().flatten() hard_triplets = np.where(all == 1) if len(hard_triplets[0]) == 0: continue anc_hard_embedding = anc_embedding[hard_triplets].cuda() pos_hard_embedding = pos_embedding[hard_triplets].cuda() neg_hard_embedding = neg_embedding[hard_triplets].cuda() # Calculate triplet loss triplet_loss = TripletLoss(margin=margin).forward( anchor=anc_hard_embedding, positive=pos_hard_embedding, negative=neg_hard_embedding).cuda() triplet_loss_sum += triplet_loss.item() # Backward pass optimizer_model.zero_grad() triplet_loss.backward() optimizer_model.step() avg_triplet_loss = triplet_loss_sum / len(train_dataloader.dataset) epoch_time_end = time.time() # Print training and validation statistics and add to log print('Epoch {}:\tTriplet Loss: {:.4f}\tEpoch Time: {:.2f} minutes'. format(epoch + 1, avg_triplet_loss, (epoch_time_end - epoch_time_start) / 60)) with open('logs/{}_log_triplet.txt'.format(model_architecture), 'a') as f: val_list = [epoch + 1, avg_triplet_loss] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') # Validating on LFW dataset using KFold based on Euclidean distance metric if flag_validate_lfw: model.eval() with torch.no_grad(): distances, labels = [], [] print("Validating on LFW! 
...") progress_bar = tqdm(enumerate(lfw_dataloader)) for batch_index, (data_a, data_b, label) in progress_bar: data_a, data_b, label = data_a.cuda(), data_b.cuda( ), label.cuda() output_a, output_b = model(data_a), model(data_b) distance = l2_distance.forward( output_a, output_b) # Euclidean distance distances.append(distance.cpu().detach().numpy()) labels.append(label.cpu().detach().numpy()) labels = np.array( [sublabel for label in labels for sublabel in label]) distances = np.array([ subdist for distance in distances for subdist in distance ]) true_positive_rate, false_positive_rate, precision, recall, accuracy, auc, best_distance_threshold, \ tar, far = evaluate_lfw( distances=distances, labels=labels ) # Print statistics and add to log print( "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}\tRecall {:.4f}\tArea Under Curve: {:.4f}\t" "Best distance threshold: {:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}" .format(np.mean(accuracy), np.std(accuracy), np.mean(precision), np.mean(recall), auc, best_distance_threshold, np.mean(tar), np.std(tar), np.mean(far))) with open( 'logs/lfw_{}_log_triplet.txt'.format( model_architecture), 'a') as f: val_list = [ epoch + 1, np.mean(accuracy), np.std(accuracy), np.mean(precision), np.mean(recall), auc, best_distance_threshold, np.mean(tar) ] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') # Plot ROC curve plot_roc_lfw( false_positive_rate, true_positive_rate, figure_name="plots/roc_plots/roc_{}_epoch_{}_triplet.png". format(model_architecture, epoch + 1)) # Save model checkpoint state = { 'epoch': epoch + 1, 'embedding_dimension': embedding_dimension, 'batch_size_training': batch_size, 'model_state_dict': model.state_dict(), 'model_architecture': model_architecture, 'optimizer_model_state_dict': optimizer_model.state_dict() } # For storing data parallel model's state dictionary without 'module' parameter if flag_train_multi_gpu: state['model_state_dict'] = model.module.state_dict() # For storing best euclidean distance threshold during LFW validation if flag_validate_lfw: state['best_distance_threshold'] = best_distance_threshold torch.save(state, 'model_{}_triplet.pt'.format(model_architecture)) # Training loop end total_time_end = time.time() total_time_elapsed = total_time_end - total_time_start print("\nTraining finished: total time elapsed: {:.2f} minutes.".format( total_time_elapsed / 60)) # Plot lfw accuracies plot print("\nPlotting plot!") plot_accuracy_lfw( log_dir="logs/lfw_{}_log_triplet.txt".format(model_architecture), epochs=epochs, figure_name="plots/lfw_accuracies_{}_triplet.png".format( model_architecture)) print("\nDone.")
pos_img = batch_sample['pos_img'].cuda()
neg_img = batch_sample['neg_img'].cuda()

# Fetch the three mask images (one batch each)
mask_anc = batch_sample['mask_anc'].cuda()
mask_pos = batch_sample['mask_pos'].cuda()
mask_neg = batch_sample['mask_neg'].cuda()

# Forward pass - run the model on each (image, mask) pair to produce an embedding
# and an attention loss (during training the model takes an image plus a mask and
# also returns a loss; during validation it takes a single image and returns only
# the embedding)
anc_embedding, anc_attention_loss = model((anc_img, mask_anc))
pos_embedding, pos_attention_loss = model((pos_img, mask_pos))
neg_embedding, neg_attention_loss = model((neg_img, mask_neg))

# Hard example mining
# Compute the L2 distances between embeddings
pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
neg_dist = l2_distance.forward(anc_embedding, neg_embedding)

# Find the samples that meet the hard-example criterion
all = (neg_dist - pos_dist < config['margin']).cpu().numpy().flatten()
hard_triplets = np.where(all == 1)
if len(hard_triplets[0]) == 0:
    continue

# Selected hard examples - hard embeddings
anc_hard_embedding = anc_embedding[hard_triplets].cuda()
pos_hard_embedding = pos_embedding[hard_triplets].cuda()
neg_hard_embedding = neg_embedding[hard_triplets].cuda()

# Selected hard examples - attention losses corresponding to the hard examples
hard_anc_attention_loss = anc_attention_loss[hard_triplets]
hard_pos_attention_loss = pos_attention_loss[hard_triplets]
hard_neg_attention_loss = neg_attention_loss[hard_triplets]
def train_triplet(start_epoch, end_epoch, epochs, train_dataloader, lfw_dataloader, lfw_validation_epoch_interval, model, model_architecture, optimizer_model, embedding_dimension, batch_size, margin, flag_train_multi_gpu, optimizer, learning_rate, use_semihard_negatives): for epoch in range(start_epoch, end_epoch): flag_validate_lfw = (epoch + 1) % lfw_validation_epoch_interval == 0 or ( epoch + 1) % epochs == 0 triplet_loss_sum = 0 num_valid_training_triplets = 0 l2_distance = PairwiseDistance(p=2) # Training pass model.train() progress_bar = enumerate(tqdm(train_dataloader)) for batch_idx, (batch_sample) in progress_bar: # Skip last iteration to avoid the problem of having different number of tensors while calculating # pairwise distances (sizes of tensors must be the same for pairwise distance calculation) if batch_idx + 1 == len(train_dataloader): continue # Forward pass - compute embeddings anc_imgs = batch_sample['anc_img'] pos_imgs = batch_sample['pos_img'] neg_imgs = batch_sample['neg_img'] # Concatenate the input images into one tensor because doing multiple forward passes would create # weird GPU memory allocation behaviours later on during training which would cause GPU Out of Memory # issues all_imgs = torch.cat( (anc_imgs, pos_imgs, neg_imgs)) # Must be a tuple of Torch Tensors anc_embeddings, pos_embeddings, neg_embeddings, model, optimizer_model, flag_use_cpu = forward_pass( imgs=all_imgs, model=model, optimizer_model=optimizer_model, batch_idx=batch_idx, optimizer=optimizer, learning_rate=learning_rate, batch_size=batch_size, use_cpu=False) pos_dists = l2_distance.forward(anc_embeddings, pos_embeddings) neg_dists = l2_distance.forward(anc_embeddings, neg_embeddings) if use_semihard_negatives: # Semi-Hard Negative triplet selection # (negative_distance - positive_distance < margin) AND (positive_distance < negative_distance) # Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L295 first_condition = (neg_dists - pos_dists < margin).cpu().numpy().flatten() second_condition = (pos_dists < neg_dists).cpu().numpy().flatten() all = (np.logical_and(first_condition, second_condition)) semihard_negative_triplets = np.where(all == 1) if len(semihard_negative_triplets[0]) == 0: continue anc_valid_embeddings = anc_embeddings[ semihard_negative_triplets] pos_valid_embeddings = pos_embeddings[ semihard_negative_triplets] neg_valid_embeddings = neg_embeddings[ semihard_negative_triplets] del anc_embeddings, pos_embeddings, neg_embeddings, pos_dists, neg_dists gc.collect() else: # Hard Negative triplet selection # (negative_distance - positive_distance < margin) # Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L296 all = (neg_dists - pos_dists < margin).cpu().numpy().flatten() hard_negative_triplets = np.where(all == 1) if len(hard_negative_triplets[0]) == 0: continue anc_valid_embeddings = anc_embeddings[hard_negative_triplets] pos_valid_embeddings = pos_embeddings[hard_negative_triplets] neg_valid_embeddings = neg_embeddings[hard_negative_triplets] del anc_embeddings, pos_embeddings, neg_embeddings, pos_dists, neg_dists gc.collect() # Calculate triplet loss triplet_loss = TripletLoss(margin=margin).forward( anchor=anc_valid_embeddings, positive=pos_valid_embeddings, negative=neg_valid_embeddings) # Calculating loss and number of triplets that met the triplet selection method during the epoch triplet_loss_sum += triplet_loss.item() num_valid_training_triplets += len(anc_valid_embeddings) # Backward pass 
optimizer_model.zero_grad() triplet_loss.backward() optimizer_model.step() # Load model and optimizer back to GPU if CUDA Out of Memory Exception occurred and model and optimizer # were switched to CPU if flag_use_cpu: # According to https://github.com/pytorch/pytorch/issues/2830#issuecomment-336183179 # In order for the optimizer to keep training the model after changing to a different type or device, # optimizers have to be recreated, 'load_state_dict' can be used to restore the state from a # previous copy. As such, the optimizer state dict will be saved first and then reloaded when # the model's device is changed. torch.cuda.empty_cache() # Print number of valid triplets (troubleshooting out of memory causes) print("Number of valid triplets during OOM iteration = {}". format(len(anc_valid_embeddings))) torch.save( optimizer_model.state_dict(), 'model_training_checkpoints/out_of_memory_optimizer_checkpoint/optimizer_checkpoint.pt' ) # Load back to CUDA model.cuda() optimizer_model = set_optimizer(optimizer=optimizer, model=model, learning_rate=learning_rate) optimizer_model.load_state_dict( torch.load( 'model_training_checkpoints/out_of_memory_optimizer_checkpoint/optimizer_checkpoint.pt' )) # Copied from https://github.com/pytorch/pytorch/issues/2830#issuecomment-336194949 # No optimizer.cuda() available, this is the way to make an optimizer loaded with cpu tensors load # with cuda tensors. for state in optimizer_model.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.cuda() # Clear some memory at end of training iteration del triplet_loss, anc_valid_embeddings, pos_valid_embeddings, neg_valid_embeddings gc.collect() # Model only trains on triplets that fit the triplet selection method avg_triplet_loss = 0 if ( num_valid_training_triplets == 0) else triplet_loss_sum / num_valid_training_triplets # Print training statistics and add to log print( 'Epoch {}:\tAverage Triplet Loss: {:.4f}\tNumber of valid training triplets in epoch: {}' .format(epoch + 1, avg_triplet_loss, num_valid_training_triplets)) with open('logs/{}_log_triplet.txt'.format(model_architecture), 'a') as f: val_list = [ epoch + 1, avg_triplet_loss, num_valid_training_triplets ] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') try: # Plot Triplet losses plot plot_triplet_losses( log_dir="logs/{}_log_triplet.txt".format(model_architecture), epochs=epochs, figure_name="plots/triplet_losses_{}.png".format( model_architecture)) except Exception as e: print(e) # Evaluation pass on LFW dataset if flag_validate_lfw: best_distances = validate_lfw( model=model, lfw_dataloader=lfw_dataloader, model_architecture=model_architecture, epoch=epoch, epochs=epochs) # Save model checkpoint state = { 'epoch': epoch + 1, 'embedding_dimension': embedding_dimension, 'batch_size_training': batch_size, 'model_state_dict': model.state_dict(), 'model_architecture': model_architecture, 'optimizer_model_state_dict': optimizer_model.state_dict() } # For storing data parallel model's state dictionary without 'module' parameter if flag_train_multi_gpu: state['model_state_dict'] = model.module.state_dict() # For storing best euclidean distance threshold during LFW validation if flag_validate_lfw: state['best_distance_threshold'] = np.mean(best_distances) # Save model checkpoint torch.save( state, 'model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format( model_architecture, epoch + 1))
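# Minimal sketch of the optimizer-state device move performed in the OOM fallback
# above: after the model is put back on the GPU, every tensor held in the optimizer
# state must also be moved so optimizer.step() operates on matching devices.
import torch

def optimizer_state_to_cuda(optimizer):
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()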
def main(): dataroot = args.dataroot apd_dataroot = args.apd apd_batch_size = args.apd_batch_size apd_validation_epoch_interval = args.apd_validation_epoch_interval model_architecture = args.model epochs = args.epochs resume_path = args.resume_path batch_size = args.batch_size num_workers = args.num_workers validation_dataset_split_ratio = args.valid_split embedding_dimension = args.embedding_dim pretrained = args.pretrained optimizer = args.optimizer learning_rate = args.lr learning_rate_center_loss = args.center_loss_lr center_loss_weight = args.center_loss_weight start_epoch = 0 # Define image data pre-processing transforms # ToTensor() normalizes pixel values between [0, 1] # Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) normalizes pixel values between [-1, 1] # Size 182x182 RGB image -> Center crop size 160x160 RGB image for more model generalization data_transforms = transforms.Compose([ #transforms.RandomCrop(size=50), transforms.Resize(size=(160, 160)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # Size 160x160 RGB image apd_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # Load the dataset dataset = torchvision.datasets.ImageFolder(root=dataroot, transform=data_transforms) # Subset the dataset into training and validation datasets num_classes = len(dataset.classes) print("\nNumber of classes in dataset: {}".format(num_classes)) num_validation = int(num_classes * validation_dataset_split_ratio) num_train = num_classes - num_validation indices = list(range(num_classes)) np.random.seed(420) np.random.shuffle(indices) train_indices = indices[:num_train] validation_indices = indices[num_train:] #train_dataset = Subset(dataset=dataset, indices=train_indices) #validation_dataset = Subset(dataset=dataset, indices=validation_indices) #print("Number of classes in training dataset: {}".format(len(train_dataset))) #print("Number of classes in validation dataset: {}".format(len(validation_dataset))) # Define the dataloaders train_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) """ validation_dataloader = DataLoader( dataset=validation_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False ) """ logging.info('prepare APD dataset') apd_dataroot = '../APD' apdDataset = APDDataset( directory=apd_dataroot, #pairs_path='negative_pairs.txt', #pairs_path='positive_pairs.txt', pairs_path='full.txt', transform=apd_transforms) print('apdDataset', apdDataset) apd_dataloader = torch.utils.data.DataLoader(dataset=apdDataset, batch_size=apd_batch_size, num_workers=num_workers, shuffle=True) # Instantiate model if model_architecture == "resnet18": model = Resnet18Center(num_classes=num_classes, embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "resnet34": model = Resnet34Center(num_classes=num_classes, embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "resnet50": model = Resnet50Center(num_classes=num_classes, embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "resnet101": model = Resnet101Center(num_classes=num_classes, embedding_dimension=embedding_dimension, pretrained=pretrained) elif model_architecture == "inceptionresnetv2": model = InceptionResnetV2Center( num_classes=num_classes, embedding_dimension=embedding_dimension, pretrained=pretrained) 
print("\nUsing {} model architecture.".format(model_architecture)) # Load model to GPU or multiple GPUs if available flag_train_gpu = torch.cuda.is_available() flag_train_multi_gpu = False if flag_train_gpu and torch.cuda.device_count() > 1: model = nn.DataParallel(model) model.cuda() flag_train_multi_gpu = True print('Using multi-gpu training.') elif flag_train_gpu and torch.cuda.device_count() == 1: model.cuda() print('Using single-gpu training.') # Set loss functions criterion_crossentropy = nn.CrossEntropyLoss().cuda() criterion_centerloss = CenterLoss(num_classes=num_classes, feat_dim=embedding_dimension).cuda() # Set optimizers if optimizer == "sgd": optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate) optimizer_centerloss = torch.optim.SGD( criterion_centerloss.parameters(), lr=learning_rate_center_loss) elif optimizer == "adagrad": optimizer_model = torch.optim.Adagrad(model.parameters(), lr=learning_rate) optimizer_centerloss = torch.optim.Adagrad( criterion_centerloss.parameters(), lr=learning_rate_center_loss) elif optimizer == "rmsprop": optimizer_model = torch.optim.RMSprop(model.parameters(), lr=learning_rate) optimizer_centerloss = torch.optim.RMSprop( criterion_centerloss.parameters(), lr=learning_rate_center_loss) elif optimizer == "adam": optimizer_model = torch.optim.Adam(model.parameters(), lr=learning_rate) optimizer_centerloss = torch.optim.Adam( criterion_centerloss.parameters(), lr=learning_rate_center_loss) # Set learning rate decay scheduler learning_rate_scheduler = optim.lr_scheduler.MultiStepLR( optimizer=optimizer_model, milestones=[150, 225], gamma=0.1) # Optionally resume from a checkpoint if resume_path: if os.path.isfile(resume_path): print("\nLoading checkpoint {} ...".format(resume_path)) checkpoint = torch.load(resume_path) start_epoch = checkpoint['epoch'] # In order to load state dict for optimizers correctly, model has to be loaded to gpu first if flag_train_multi_gpu: model.module.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint['model_state_dict']) optimizer_model.load_state_dict( checkpoint['optimizer_model_state_dict']) optimizer_centerloss.load_state_dict( checkpoint['optimizer_centerloss_state_dict']) learning_rate_scheduler.load_state_dict( checkpoint['learning_rate_scheduler_state_dict']) print( "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n" .format(start_epoch, epochs - start_epoch)) else: print( "WARNING: No checkpoint found at {}!\nTraining from scratch.". 
format(resume_path)) # Start Training loop print( "\nTraining using cross entropy loss with center loss starting for {} epochs:\n" .format(epochs - start_epoch)) total_time_start = time.time() start_epoch = start_epoch end_epoch = start_epoch + epochs BATCH_NUM = len(train_dataloader.dataset) / batch_size APD_BATCH_NUM = len(apd_dataloader.dataset) / apd_batch_size for epoch in range(start_epoch, end_epoch): epoch_time_start = time.time() #flag_validate_apd = (epoch + 1) % lfw_validation_epoch_interval == 0 or (epoch + 1) % epochs == 0 flag_validate_apd = True train_loss_sum = 0 validation_loss_sum = 0 # Training the model model.train() learning_rate_scheduler.step() #progress_bar = enumerate(tqdm(train_dataloader)) progress_bar = enumerate(train_dataloader) for batch_index, (data, labels) in progress_bar: #break data, labels = data.cuda(), labels.cuda() print(data.shape) #print('label', labels) # Forward pass if flag_train_multi_gpu: embedding, logits = model.module.forward_training(data) else: embedding, logits = model.forward_training(data) # Calculate losses cross_entropy_loss = criterion_crossentropy( logits.cuda(), labels.cuda()) center_loss = criterion_centerloss(embedding, labels) loss = (center_loss * center_loss_weight) + cross_entropy_loss #loss = cross_entropy_loss logging.info("epoch:{}/{} batch_idx:{}/{} loss:{}".format( epoch, end_epoch, batch_index, BATCH_NUM, loss)) # Backward pass optimizer_centerloss.zero_grad() optimizer_model.zero_grad() loss.backward() optimizer_centerloss.step() optimizer_model.step() # Remove center_loss_weight impact on the learning of center vectors #for param in criterion_centerloss.parameters(): # param.grad.data *= (1. / center_loss_weight) # Update training loss sum train_loss_sum += loss.item() * data.size(0) # Calculate average losses in epoch avg_train_loss = train_loss_sum / len(train_dataloader.dataset) """ avg_validation_loss = validation_loss_sum / len(validation_dataloader.dataset) """ # Calculate training performance statistics in epoch #classification_accuracy = correct * 100. / total #classification_error = 100. - classification_accuracy epoch_time_end = time.time() print('Epoch {}:\t Average Training Loss: {:.4f}\t'.format( epoch + 1, avg_train_loss)) with open('logs/{}_log_center.txt'.format(model_architecture), 'a') as f: val_list = [ epoch + 1, avg_train_loss #avg_validation_loss, #classification_accuracy.item(), #classification_error.item() ] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') try: # Plot plot for Cross Entropy Loss and Center Loss on training and validation sets plot_training_validation_losses_center( log_dir="logs/{}_log_center.txt".format(model_architecture), epochs=epochs, figure_name="plots/training_validation_losses_{}_center.png". format(model_architecture)) except Exception as e: print(e) # Validating on LFW dataset using KFold based on Euclidean distance metric if flag_validate_apd: model.eval() with torch.no_grad(): l2_distance = PairwiseDistance(2).cuda() distances, labels = [], [] print("Validating on APD! 
...") #progress_bar = enumerate(tqdm(apd_dataloader)) progress_bar = enumerate(apd_dataloader) for batch_index, (data_a, data_b, label) in progress_bar: logging.info("epoch:{}/{} batch_idx:{}/{}".format( epoch, end_epoch, batch_index, APD_BATCH_NUM)) data_a, data_b, label = data_a.cuda(), data_b.cuda( ), label.cuda() output_a, output_b = model(data_a), model(data_b) distance = l2_distance.forward( output_a, output_b) # Euclidean distance #print(distance) #print(label) distances.append(distance.cpu().detach().numpy()) labels.append(label.cpu().detach().numpy()) labels = np.array( [sublabel for label in labels for sublabel in label]) distances = np.array([ subdist for distance in distances for subdist in distance ]) true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \ tar, far = evaluate_lfw( distances=distances, labels=labels ) # Print statistics and add to log print( "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\tROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}" .format(np.mean(accuracy), np.std(accuracy), np.mean(precision), np.std(precision), np.mean(recall), np.std(recall), roc_auc, np.mean(best_distances), np.std(best_distances), np.mean(tar), np.std(tar), np.mean(far))) with open( 'logs/lfw_{}_log_center.txt'.format( model_architecture), 'a') as f: val_list = [ epoch + 1, np.mean(accuracy), np.std(accuracy), np.mean(precision), np.std(precision), np.mean(recall), np.std(recall), roc_auc, np.mean(best_distances), np.std(best_distances), np.mean(tar) ] log = '\t'.join(str(value) for value in val_list) f.writelines(log + '\n') try: # Plot ROC curve plot_roc_lfw( false_positive_rate=false_positive_rate, true_positive_rate=true_positive_rate, figure_name= "plots/roc_plots_center/roc_{}_epoch_{}_center.png".format( model_architecture, epoch + 1), epochNum=epoch) # Plot LFW accuracies plot plot_accuracy_lfw( log_dir="logs/lfw_{}_log_center.txt".format( model_architecture), epochs=epochs, figure_name="plots/lfw_accuracies_{}_center.png".format( model_architecture)) except Exception as e: print(e) # Save model checkpoint state = { 'epoch': epoch + 1, 'num_classes': num_classes, 'embedding_dimension': embedding_dimension, 'batch_size_training': batch_size, 'model_state_dict': model.state_dict(), 'model_architecture': model_architecture, 'optimizer_model_state_dict': optimizer_model.state_dict(), 'optimizer_centerloss_state_dict': optimizer_centerloss.state_dict(), 'learning_rate_scheduler_state_dict': learning_rate_scheduler.state_dict() } # For storing data parallel model's state dictionary without 'module' parameter if flag_train_multi_gpu: state['model_state_dict'] = model.module.state_dict() # For storing best euclidean distance threshold during LFW validation if flag_validate_apd: state['best_distance_threshold'] = np.mean(best_distances) # Save model checkpoint torch.save( state, 'center_checkpoints/model_{}_center_epoch_{}.pt'.format( model_architecture, epoch + 1)) # Training loop end total_time_end = time.time() total_time_elapsed = total_time_end - total_time_start print("\nTraining finished: total time elapsed: {:.2f} hours.".format( total_time_elapsed / 3600))
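# Hedged sketch of the joint objective optimized above: cross-entropy on the class
# logits plus a weighted center loss on the embeddings. The two criterion modules
# are assumed to behave as in the training loop; the function name is illustrative.
def combined_center_loss(logits, embeddings, targets,
                         criterion_crossentropy, criterion_centerloss,
                         center_loss_weight):
    cross_entropy_loss = criterion_crossentropy(logits, targets)
    center_loss = criterion_centerloss(embeddings, targets)
    return cross_entropy_loss + center_loss_weight * center_loss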