batch_size=hyperparams.batch_size, shuffle=False, collate_fn=WikiDataset.pad_collate) trainer = Trainer( model=baseline_model, loss_function=nn.CrossEntropyLoss( ignore_index=train_dataset.label2idx['<PAD>']), optimizer=Adam(baseline_model.parameters(), lr=1e-6, weight_decay=1e-5), label_vocab=train_dataset.label2idx ) trainer.train(train_dataset_, dev_dataset_, epochs=2) save_model_path = os.path.join(RESOURCES_PATH, f'{baseline_model.name}_model.pt') torch.save(baseline_model.state_dict(), save_model_path) # load model # model = BaselineModel(hyperparams) # model.load_state_dict(torch.load(save_model_path)) # evaluate model # test_loss, test_acc = trainer.evaluate(test_dataset) # print(f"Test set\nLoss: {test_loss:.5f}, Acc: {test_acc * 100:.5f}%") scores = compute_scores(baseline_model, dev_dataset_) per_class_precision = scores["per_class_precision"] precision_, recall_, f1score_, _ = scores['macro_precision_recall_fscore'] print(f"Macro Precision: {precision_}")
def main():
    """Train a CelebA multi-label attribute classifier, optionally with an
    adversarial debiasing head, and checkpoint the best/most-recent weights.

    All settings come from the module-level ``opt`` namespace (presumably
    argparse output -- verify against the surrounding file). Each epoch runs
    one training pass and one evaluation pass over CelebA, logs accuracy plus
    equality-gap / parity-gap fairness metrics, and writes 'last.pkl',
    'best.pkl' and periodic checkpoints under ``opt.weights_dir/opt.out_dir``.
    """

    def _merge_cms(totals, batch):
        # Element-wise accumulation of a tuple of confusion-matrix counts.
        # On the first batch, adopt the batch tuple as the running total.
        if totals is None:
            return batch
        return tuple(t + b for t, b in zip(totals, batch))

    # Model hyperparams
    random.seed(opt.random_seed)
    baseline = opt.baseline
    hidden_size = opt.hidden_size
    lambd = opt.lambd  # weight of the adversarial term in the primary loss
    learning_rate = opt.learning_rate
    adv_learning_rate = opt.adv_learning_rate
    save_after_x_epochs = 10
    num_classes = 39

    # Determine device
    device = getDevice(opt.gpu_id)

    # Create data loaders
    data_loaders = load_celeba(
        splits=['train', 'valid'],
        batch_size=opt.batch_size,
        subset_percentage=opt.subset_percentage,
        protected_percentage=opt.protected_percentage,
        balance_protected=opt.balance_protected)
    train_data_loader = data_loaders['train']
    dev_data_loader = data_loaders['valid']

    # Load checkpoint. Its architecture settings override the CLI ones so the
    # saved weights always match the model we are about to build.
    checkpoint = None
    if opt.weights != '':
        checkpoint = torch.load(opt.weights, map_location=device)
        baseline = checkpoint['baseline']
        hidden_size = checkpoint['hyp']['hidden_size']

    # Create model and move it to the chosen device
    model = BaselineModel(hidden_size) if baseline else OurModel(hidden_size)
    model = model.to(device)

    # Losses: multi-label attribute classification and (optionally) the
    # adversary's gender prediction, both binary cross-entropy on logits.
    criterion = nn.BCEWithLogitsLoss()
    if not baseline:
        adversarial_criterion = nn.BCEWithLogitsLoss()

    # Optimizers: the primary one updates encoder + classifier only; the
    # adversarial one updates only the adversarial head.
    primary_optimizer_params = list(model.encoder.parameters()) + list(
        model.classifier.parameters())
    primary_optimizer = torch.optim.Adam(primary_optimizer_params,
                                         lr=learning_rate)
    if not baseline:
        adversarial_optimizer_params = list(model.adv_head.parameters())
        adversarial_optimizer = torch.optim.Adam(adversarial_optimizer_params,
                                                 lr=adv_learning_rate)

    start_epoch = 0
    best_acc = 0.0
    save_best = False

    train_batch_count = len(train_data_loader)
    dev_batch_count = len(dev_data_loader)

    if checkpoint is not None:
        # Load model weights
        model.load_state_dict(checkpoint['model'])

        # Load metadata to resume training. Compare against None rather than
        # truthiness so epoch 0, best_acc 0.0 and lambd 0 are restored too.
        if opt.resume:
            if checkpoint['epoch'] is not None:
                start_epoch = checkpoint['epoch'] + 1
            if checkpoint['best_acc'] is not None:
                best_acc = checkpoint['best_acc']
            if checkpoint['hyp']['lambd'] is not None:
                lambd = checkpoint['hyp']['lambd']
            if checkpoint['optimizers']['primary'] is not None:
                primary_optimizer.load_state_dict(
                    checkpoint['optimizers']['primary'])
            # Baseline checkpoints store None here and have no adversarial
            # optimizer object to restore into.
            if not baseline and checkpoint['optimizers'][
                    'adversarial'] is not None:
                adversarial_optimizer.load_state_dict(
                    checkpoint['optimizers']['adversarial'])

    # Train loop
    adversarial_loss = None
    for epoch in range(start_epoch, opt.num_epochs):
        # Set model to train mode
        model.train()

        # Running accuracy and per-gender confusion matrices for the epoch
        mean_accuracy = AverageMeter(device=device)
        cm_m = None
        cm_f = None

        with tqdm(enumerate(train_data_loader),
                  total=train_batch_count) as pbar:  # progress bar
            for i, (images, targets, genders, protected_labels) in pbar:
                # NOTE(review): Variable() has been a no-op since torch 0.4;
                # kept for parity with the rest of the project.
                # Shape: torch.Size([batch_size, 3, crop_size, crop_size])
                images = Variable(images.to(device))
                # Shape: torch.Size([batch_size, 39])
                targets = Variable(targets.to(device))
                # Shape: torch.Size([batch_size])
                genders = Variable(genders.to(device))
                # Shape: torch.Size([batch_size])
                protected_labels = Variable(
                    protected_labels.type(torch.BoolTensor).to(device))

                # Forward pass. `a` holds the adversary's logits for the
                # protected subset (None when unavailable for this batch),
                # `a_detached` the same computed from a detached encoding.
                if baseline:
                    outputs, (a, a_detached) = model(images)
                else:
                    outputs, (a, a_detached) = model(images, protected_labels)

                targets = targets.type_as(outputs)
                genders = genders.type_as(outputs)

                # Zero out buffers (either model.zero_grad() or
                # optimizer.zero_grad() is fine here)
                primary_optimizer.zero_grad()

                # BCEWithLogitsLoss expects logits of shape (N, C),
                # C = number of attribute classes
                classification_loss = criterion(outputs, targets)

                if baseline or a is None:
                    # Plain classification update: baseline model, or no
                    # adversarial logits available for this batch.
                    loss = classification_loss

                    # Backward pass (Primary)
                    loss.backward()
                    primary_optimizer.step()
                else:
                    # Primary update: minimize classification loss while
                    # *maximizing* the adversary's loss (hence the minus).
                    adversarial_loss = adversarial_criterion(
                        a, genders[protected_labels])
                    loss = classification_loss - lambd * adversarial_loss

                    # Backward pass (Primary)
                    loss.backward()
                    primary_optimizer.step()

                    # Zero out buffers
                    adversarial_optimizer.zero_grad()

                    # Adversarial update on the detached encoding so only the
                    # adversarial head receives gradients.
                    adversarial_loss = adversarial_criterion(
                        a_detached, genders[protected_labels])

                    # Backward pass (Adversarial)
                    adversarial_loss.backward()
                    adversarial_optimizer.step()

                # Convert genders: (batch_size, 1) -> (batch_size,)
                genders = genders.view(-1).bool()

                # Calculate accuracy
                train_acc, _ = calculateAccuracy(outputs, targets)

                # Accumulate per-gender confusion matrices
                batch_cm_m, batch_cm_f = calculateGenderConfusionMatrices(
                    outputs, targets, genders)
                cm_m = _merge_cms(cm_m, batch_cm_m)
                cm_f = _merge_cms(cm_f, batch_cm_f)

                # Update averages
                mean_accuracy.update(train_acc, images.size(0))

                if baseline:
                    s_train = ('%10s Loss: %.4f, Accuracy: %.4f') % (
                        '%g/%g' % (epoch, opt.num_epochs - 1), loss.item(),
                        mean_accuracy.avg)
                else:
                    # NOTE(review): when a batch has no adversarial logits,
                    # adversarial_loss keeps its value from an earlier batch,
                    # so the displayed adversarial loss may be stale.
                    if adversarial_loss is None:
                        s_train = (
                            '%10s Classification Loss: %.4f, Total Loss: %.4f, Accuracy: %.4f'
                        ) % ('%g/%g' % (epoch, opt.num_epochs - 1),
                             classification_loss.item(), loss.item(),
                             mean_accuracy.avg)
                    else:
                        s_train = (
                            '%10s Classification Loss: %.4f, Adversarial Loss: %.4f, Total Loss: %.4f, Accuracy: %.4f'
                        ) % ('%g/%g' % (epoch, opt.num_epochs - 1),
                             classification_loss.item(),
                             adversarial_loss.item(), loss.item(),
                             mean_accuracy.avg)

                # Calculate fairness metrics on final batch
                if i == train_batch_count - 1:
                    avg_equality_gap_0, avg_equality_gap_1, _, _ = calculateEqualityGap(
                        cm_m, cm_f)
                    avg_parity_gap, _ = calculateParityGap(cm_m, cm_f)
                    s_train += (
                        ', Equality Gap 0: %.4f, Equality Gap 1: %.4f, Parity Gap: %.4f'
                    ) % (avg_equality_gap_0, avg_equality_gap_1,
                         avg_parity_gap)

                pbar.set_description(s_train)
        # end batch ------------------------------------------------------------------------------------------------

        # Evaluate
        model.eval()

        # Fresh meters, confusion matrices, and metrics for the dev pass.
        # NOTE(review): unlike the train meter, this AverageMeter is built
        # without device= -- confirm that is intentional.
        mean_accuracy = AverageMeter()
        attr_accuracy = AverageMeter((1, num_classes), device=device)
        cm_m = None
        cm_f = None
        attr_equality_gap_0 = None
        attr_equality_gap_1 = None
        attr_parity_gap = None

        with tqdm(enumerate(dev_data_loader), total=dev_batch_count) as pbar:
            for i, (images, targets, genders, protected_labels) in pbar:
                images = Variable(images.to(device))
                targets = Variable(targets.to(device))
                genders = Variable(genders.to(device))

                with torch.no_grad():
                    # Forward pass
                    outputs = model.sample(images)

                    targets = targets.type_as(outputs)

                    # Convert genders: (batch_size, 1) -> (batch_size,)
                    genders = genders.type_as(outputs).view(-1).bool()

                    # Calculate accuracy
                    eval_acc, eval_attr_acc = calculateAccuracy(
                        outputs, targets)

                    # Accumulate per-gender confusion matrices
                    batch_cm_m, batch_cm_f = calculateGenderConfusionMatrices(
                        outputs, targets, genders)
                    cm_m = _merge_cms(cm_m, batch_cm_m)
                    cm_f = _merge_cms(cm_f, batch_cm_f)

                    # Update averages
                    mean_accuracy.update(eval_acc, images.size(0))
                    attr_accuracy.update(eval_attr_acc, images.size(0))

                    s_eval = ('%10s Accuracy: %.4f') % (
                        '%g/%g' % (epoch, opt.num_epochs - 1),
                        mean_accuracy.avg)

                    # Calculate fairness metrics on final batch
                    if i == dev_batch_count - 1:
                        avg_equality_gap_0, avg_equality_gap_1, attr_equality_gap_0, attr_equality_gap_1 = \
                            calculateEqualityGap(cm_m, cm_f)
                        avg_parity_gap, attr_parity_gap = calculateParityGap(
                            cm_m, cm_f)
                        s_eval += (
                            ', Equality Gap 0: %.4f, Equality Gap 1: %.4f, Parity Gap: %.4f'
                        ) % (avg_equality_gap_0, avg_equality_gap_1,
                             avg_parity_gap)

                    pbar.set_description(s_eval)

        # Create output dirs (<root>/<out_dir> under both the log and weights
        # roots); exist_ok avoids the check-then-create race.
        for out_root in [opt.log_dir, opt.weights_dir]:
            os.makedirs(os.path.join(out_root, opt.out_dir), exist_ok=True)
        log_dir = os.path.join(opt.log_dir, opt.out_dir)
        weights_dir = os.path.join(opt.weights_dir, opt.out_dir)

        # Log results
        with open(os.path.join(log_dir, opt.log), 'a+') as f:
            f.write('{}\n'.format(s_train))
            f.write('{}\n'.format(s_eval))
        save_attr_metrics(
            attr_accuracy.avg, attr_equality_gap_0, attr_equality_gap_1,
            attr_parity_gap,
            os.path.join(log_dir, opt.attr_metrics + '_' + str(epoch)))

        # Check against best accuracy
        mean_eval_acc = mean_accuracy.avg.cpu().item()
        if mean_eval_acc > best_acc:
            best_acc = mean_eval_acc
            save_best = True

        # Create checkpoint
        checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizers': {
                'primary': primary_optimizer.state_dict(),
                'adversarial':
                adversarial_optimizer.state_dict() if not baseline else None,
            },
            'best_acc': best_acc,
            'baseline': baseline,
            'hyp': {
                'hidden_size': hidden_size,
                'lambd': lambd
            }
        }

        # Save last checkpoint
        torch.save(checkpoint, os.path.join(weights_dir, 'last.pkl'))

        # Save best checkpoint
        if save_best:
            torch.save(checkpoint, os.path.join(weights_dir, 'best.pkl'))
            save_best = False

        # Save backup every 10 epochs (optional)
        if (epoch + 1) % save_after_x_epochs == 0:
            # Save our models
            print('!!! saving models at epoch: ' + str(epoch))
            torch.save(
                checkpoint,
                os.path.join(weights_dir,
                             'checkpoint-%d-%d.pkl' % (epoch + 1, 1)))

        # Delete checkpoint
        del checkpoint

    print('Done!')