Exemplo n.º 1
0
    def test_gan(self, epoch):
        """Qualitatively inspect the translator on a small batch per split.

        For each of train/val/test: pull a mini-batch, flip every attribute
        label, reconstruct with the generator, and show each real/fake pair
        together with the discriminator's source and class scores.
        """
        for part in ('train', 'val', 'test'):
            ds = dataset.load_celeba('CelebA',
                                     batch_size,
                                     part=part,
                                     consumer='translator',
                                     smallbatch=10)

            batch_op = ds.make_one_shot_iterator().get_next()
            imgs, labels = K.get_session().run(batch_op)

            labels = 1 - labels  # translate to the opposite attribute labels
            print(labels)
            rec_imgs = self.generator.predict([imgs, labels])
            src_real, _, cls_real = self.discriminator.predict(imgs)
            src_fake, _, cls_fake = self.discriminator.predict(rec_imgs)
            batch_items = zip(imgs, rec_imgs, src_real, cls_real, src_fake,
                              cls_fake)
            for real, fake, s_real, c_real, s_fake, c_fake in batch_items:
                plot_images([img_renorm(real), img_renorm(fake)])
                print('real: ' + str(s_real) + ' cls ' + str(c_real) +
                      '   fake: ' + str(s_fake) + ' cls ' + str(c_fake))
Exemplo n.º 2
0
    # (Fragment: the enclosing model-builder function starts above this view.)
    # Classification head on top of the FaceNet embedding:
    # Dense(256) -> BatchNorm -> ReLU -> Dropout, all L2-regularized.
    x = Dense(256, kernel_regularizer=l2(weight_decay))(facenet_outputs)
    x = BatchNormalization(gamma_regularizer=l2(weight_decay),
                           beta_regularizer=l2(weight_decay))(x)
    x = ReLU()(x)
    x = Dropout(dropout_rate)(x)

    # Single sigmoid unit: binary attribute prediction.
    x = Dense(1,
              activation='sigmoid',
              kernel_regularizer=l2(weight_decay),
              bias_regularizer=l2(weight_decay))(x)

    # Model from the FaceNet input tensor straight to the binary output.
    return Model(facenet.input, x)


# Module-level CelebA pipelines for the classifier; each call presumably
# returns (tf.data pipeline, sample count) — `*_size` is used as a count
# elsewhere in this file. `batch_size` is assumed defined earlier.
x_train, train_size = dataset.load_celeba('CelebA',
                                          batch_size,
                                          part='train',
                                          consumer='classifier')
x_val, val_size = dataset.load_celeba('CelebA',
                                      batch_size,
                                      part='val',
                                      consumer='classifier')


def train(learning_rate=0.0002):
    """Build and compile the binary attribute classifier.

    NOTE(review): this definition appears truncated in this excerpt — the
    actual fitting code presumably follows outside the visible source.

    Args:
        learning_rate: Adam learning rate.
    """

    classifier = create_model()

    # Adam with beta_1=0.5 and a fixed epsilon; binary cross-entropy since
    # the model ends in a single sigmoid unit.
    opt = keras.optimizers.Adam(lr=learning_rate, beta_1=0.5, epsilon=1e-08)
    classifier.compile(optimizer=opt,
                       loss='binary_crossentropy',
                       metrics=['accuracy'])
Exemplo n.º 3
0
    # (Fragment: the enclosing VAE-builder function starts above this view.)
    # Encode, sample z via the reparameterization trick, then decode.
    z_mean, z_log_var = encoder(input)
    z = Lambda(sampling, name='z')([z_mean, z_log_var])
    rec_img = decoder(z)

    model = Model(input, rec_img, name='vae')

    if return_kl_loss_op:
        # Analytic KL divergence between q(z|x) = N(z_mean, exp(z_log_var))
        # and the standard normal prior, averaged over latent dimensions.
        kl_loss = -0.5 * K.mean(1 + z_log_var \
                                 - K.square(z_mean) \
                                 - K.exp(z_log_var), axis=-1)
        return model, kl_loss
    else:
        return model


# Module-level CelebA pipelines (default consumer); each call presumably
# returns (data pipeline, sample count) — verify against dataset.load_celeba.
x_train, train_size = dataset.load_celeba('CelebA', batch_size, part='train')
x_val, val_size = dataset.load_celeba('CelebA', batch_size, part='val')


def train(selected_pm_layers,
          alpha=1.0,
          latent_dim=1024,
          learning_rate=0.0005,
          norm_func_e=InstanceNormalization,
          norm_func_d=InstanceNormalization,
          trained_model=None):
    """Train a VAE with a perceptual loss taken from selected FaceNet layers.

    NOTE(review): this definition appears truncated in this excerpt — only
    the perceptual-model loading is visible; the remaining parameters
    (alpha, latent_dim, norm functions, trained_model) are presumably used
    further down, outside the visible source.
    """
    from tensorflow.keras.models import model_from_json

    #facenet model structure: https://github.com/serengil/tensorflow-101/blob/master/model/facenet_model.json
    pm = model_from_json(open("model/facenet_model.json", "r").read())
Exemplo n.º 4
0
def get_masked_img(img, bbox):
    """Return a copy of *img* with the region *bbox* filled with black."""
    masked = img.copy()
    ImageDraw.Draw(masked).rectangle(bbox, fill='black')

    return masked


if __name__ == '__main__':
    import os
    from dataset import load_celeba
    from main import Width, Height, mask_size
    os.makedirs('dataset/celeba/masked/images/', exist_ok=True)

    test_paths = load_celeba('small-test')

    # The mask is a centered square of side mask_size; it is identical for
    # every image, so compute the box once.
    left = (Width - mask_size) // 2
    right = (Width + mask_size) // 2
    top = (Height - mask_size) // 2
    bottom = (Height + mask_size) // 2
    bbox = (left, top, right, bottom)

    for path in test_paths:
        image = Image.open(path).resize((Height, Width))
        masked = get_masked_img(image, bbox)

        # masked.show()
        basename = os.path.basename(path)
        print(basename)
        masked.save(f'dataset/celeba/masked/images/{basename}')
Exemplo n.º 5
0
    def train(self, epochs):
        """Run the adversarial training loop for `epochs` epochs.

        Each step trains the discriminator on a batch; every `self.n_critic`
        steps the generator is also trained (WGAN-style schedule). After each
        epoch: evaluate both training models on the validation split, save
        the generator, run `test_gan`, and decay the learning rate.

        Relies on module-level `dataset`, `batch_size`, `lr_decay_ratio`,
        `save_model` and the Keras backend `K`.

        Args:
            epochs: number of epochs to train.
        """
        x_train, train_size = dataset.load_celeba('CelebA',
                                                  batch_size,
                                                  part='train',
                                                  consumer='translator')
        x_val, val_size = dataset.load_celeba('CelebA',
                                              batch_size,
                                              part='val',
                                              consumer='translator')

        x_train_itr = x_train.make_one_shot_iterator()
        x_train_next = x_train_itr.get_next()
        x_val_itr = x_val.make_one_shot_iterator()
        x_val_next = x_val_itr.get_next()

        steps_per_epoch = train_size // batch_size
        validation_steps = val_size // batch_size
        sess = K.get_session()

        # Fix: pre-initialize the losses. Previously `g_loss` (and `d_loss`
        # for an empty train split) could be referenced while unbound in the
        # epoch-summary log whenever steps_per_epoch < self.n_critic,
        # raising NameError.
        d_loss = 0.0
        g_loss = 0.0

        for epoch in range(epochs):
            epoch_start_time = time.time()
            for step in range(steps_per_epoch):
                train_img = sess.run(x_train_next)

                # Train discriminator on every step.
                d_loss = self.discriminator_training_model.train_on_batch(
                    train_img)

                # Train generator once per n_critic discriminator updates.
                if (step + 1) % self.n_critic == 0:
                    g_loss = self.generator_training_model.train_on_batch(
                        train_img)

                    # Progress log with elapsed time and ETA for this epoch.
                    et = time.time() - epoch_start_time
                    eta = et * (steps_per_epoch - step - 1) / (step + 1)
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    eta = str(datetime.timedelta(seconds=eta))[:-7]
                    log = "{}/{}   - Elapsed: {}, ETA: {}  - d_loss: {:.4f} , g_loss: {:.4f}".format(
                        step + 1, steps_per_epoch, et, eta, d_loss, g_loss)
                    print(log)

            # Validate once per epoch.
            d_val_loss = 0
            g_val_loss = 0
            # NOTE: img_rec_acc stays 0 — the reconstruction-accuracy
            # computation below is commented out.
            img_rec_acc = 0
            for step in range(validation_steps):
                val_img = sess.run(x_val_next)
                d_val_loss += self.discriminator_training_model.test_on_batch(
                    val_img)
                g_val_loss += self.generator_training_model.test_on_batch(
                    val_img)
                # rec_img = self.generator.predict_on_batch(val_img)
                # img_rec_acc += K.mean(1 - mae(K.batch_flatten(val_img), K.batch_flatten(rec_img)) / 2)

            d_val_loss /= validation_steps
            g_val_loss /= validation_steps
            img_rec_acc /= validation_steps

            # Fix: corrected 'ephoch' -> 'epoch' in the log message.
            log = "epoch {}   - d_val_loss: {:.4f} , g_val_loss: {:.4f} , img_rec_acc: {:.4f}  - d_loss: {:.4f} , g_loss: {:.4f}".format(
                epoch + 1, d_val_loss, g_val_loss, img_rec_acc, d_loss, g_loss)
            print(log)

            # Save generator snapshot per epoch.
            save_model(
                self.generator,
                'face_gan_epoch{:02d}-d_loss{:.4f}-g_loss{:.4f}-acc{:.4f}'.
                format(epoch + 1, d_val_loss, g_val_loss, img_rec_acc))

            # Qualitative check of the current generator.
            self.test_gan(epoch)

            # Decay the learning rate and apply it to both training models.
            lr = K.get_value(self.discriminator_training_model.optimizer.lr)
            lr *= lr_decay_ratio
            K.set_value(self.discriminator_training_model.optimizer.lr, lr)
            K.set_value(self.generator_training_model.optimizer.lr, lr)
            lr = K.get_value(self.discriminator_training_model.optimizer.lr)
            print(str(lr))
            lr = K.get_value(self.generator_training_model.optimizer.lr)
            print(str(lr))
Exemplo n.º 6
0
def main():
    """Evaluate a trained attribute-classifier checkpoint on the test split.

    Loads the checkpointed model (baseline or adversarially-trained variant),
    runs it over the test loader, accumulates mean/per-attribute accuracy and
    per-gender confusion matrices, computes fairness metrics (equality gaps
    and parity gap) on the final batch, and appends the results to the logs.

    Reads all configuration from the module-level `opt` namespace.
    """
    # pdb.set_trace()

    # Determine device
    device = getDevice(opt.gpu_id)
    # Number of predicted attributes; presumably the 40 CelebA attributes
    # minus the protected one — confirm against the dataset code.
    num_classes = 39

    # Create data loaders
    data_loaders = load_celeba(splits=['test'], batch_size=opt.batch_size, subset_percentage=opt.subset_percentage)
    test_data_loader = data_loaders['test']

    # Load checkpoint; `baseline` and `hidden_size` come from the checkpoint,
    # so the reconstructed architecture matches the saved weights.
    checkpoint = torch.load(os.path.join(opt.weights_dir, opt.out_dir, opt.weights), map_location=device)
    baseline = checkpoint['baseline']
    hidden_size = checkpoint['hyp']['hidden_size']

    # Create model
    if baseline:
        model = BaselineModel(hidden_size)
    else:
        model = OurModel(hidden_size)

    # Convert device
    model = model.to(device)

    test_batch_count = len(test_data_loader)

    # Load model
    model.load_state_dict(checkpoint['model'])

    # Evaluate
    model.eval()

    # Initialize meters, confusion matrices, and metrics
    mean_accuracy = AverageMeter()
    attr_accuracy = AverageMeter((1, num_classes), device=device)
    cm_m = None
    cm_f = None
    attr_equality_gap_0 = None
    attr_equality_gap_1 = None
    attr_parity_gap = None

    with tqdm(enumerate(test_data_loader), total=test_batch_count) as pbar:
        for i, (images, targets, genders, protected_labels) in pbar:
            images = Variable(images.to(device))
            targets = Variable(targets.to(device))
            genders = Variable(genders.to(device))

            with torch.no_grad():
                # Forward pass
                outputs = model.sample(images)
                targets = targets.type_as(outputs)

                # Convert genders: (batch_size, 1) -> (batch_size,)
                genders = genders.type_as(outputs).view(-1).bool()

                # Calculate accuracy
                eval_acc, eval_attr_acc = calculateAccuracy(outputs, targets)

                # Calculate confusion matrices; accumulate element-wise
                # across batches (matrices arrive as tuples, hence the
                # list round-trip for in-place addition).
                batch_cm_m, batch_cm_f = calculateGenderConfusionMatrices(outputs, targets, genders)
                if cm_m is None and cm_f is None:
                    cm_m = batch_cm_m
                    cm_f = batch_cm_f
                else:
                    cm_m = list(cm_m)
                    cm_f = list(cm_f)
                    for j in range(len(cm_m)):
                        cm_m[j] += batch_cm_m[j]
                        cm_f[j] += batch_cm_f[j]
                    cm_m = tuple(cm_m)
                    cm_f = tuple(cm_f)

                # Update averages
                mean_accuracy.update(eval_acc, images.size(0))
                attr_accuracy.update(eval_attr_acc, images.size(0))

                s_test = ('Accuracy: %.4f') % (mean_accuracy.avg)

                # Calculate fairness metrics on final batch
                # NOTE(review): if the test loader is empty, `s_test` is
                # never bound and the write below raises NameError.
                if i == test_batch_count - 1:
                    avg_equality_gap_0, avg_equality_gap_1, attr_equality_gap_0, attr_equality_gap_1 = \
                        calculateEqualityGap(cm_m, cm_f)
                    avg_parity_gap, attr_parity_gap = calculateParityGap(cm_m, cm_f)
                    s_test += (', Equality Gap 0: %.4f, Equality Gap 1: %.4f, Parity Gap: %.4f') % (avg_equality_gap_0, avg_equality_gap_1, avg_parity_gap)

                pbar.set_description(s_test)


        # Log results
        log_dir = os.path.join(opt.log_dir, opt.out_dir)
        with open(os.path.join(log_dir, opt.log), 'a+') as f:
            f.write('{}\n'.format(s_test))
        save_attr_metrics(attr_accuracy.avg, attr_equality_gap_0, attr_equality_gap_1, attr_parity_gap,
                          os.path.join(log_dir, opt.attr_metrics))

    print('Done!')
Exemplo n.º 7
0
    def train(self):
        """Full StarGAN-style training loop over module-level `epochs`.

        Per step: train the discriminator on (image, label, target_label)
        triples, and every `self.n_critic` steps also train the generator.
        Every `steps_4_log_and_lrupdate` steps, log the windowed losses and
        (after `epochs_lr_start_decay`) linearly decay both learning rates.
        Per epoch: compute validation losses and accuracy metrics, save both
        models (late epochs only), and occasionally run `test_gan`.

        Relies on module-level `dataset`, `batch_size`, `epochs`,
        `epochs_lr_start_decay`, `learning_rate`, `steps_4_log_and_lrupdate`,
        `save_model` and the Keras backend `K`.
        """
        x_train, train_size = dataset.load_celeba('CelebA',
                                                  batch_size,
                                                  part='train',
                                                  consumer='translator')
        x_val, val_size = dataset.load_celeba('CelebA',
                                              batch_size,
                                              part='val',
                                              consumer='translator')

        x_train_itr = x_train.make_one_shot_iterator()
        x_train_next = x_train_itr.get_next()
        x_val_itr = x_val.make_one_shot_iterator()
        x_val_next = x_val_itr.get_next()

        steps_per_epoch = train_size // batch_size
        # Per-update linear LR decay: chosen so the LR reaches ~0 after the
        # remaining (epochs - epochs_lr_start_decay) epochs' worth of
        # updates; +1 guards against division by zero.
        self.lr_decay_value_d = learning_rate / (
            ((epochs - epochs_lr_start_decay) *
             (steps_per_epoch // steps_4_log_and_lrupdate)) + 1)
        self.lr_decay_value_g = learning_rate / (
            ((epochs - epochs_lr_start_decay) *
             (steps_per_epoch // steps_4_log_and_lrupdate)) + 1)

        validation_steps = val_size // batch_size
        sess = K.get_session()

        # Fraction of predictions on the correct side of 0.5.
        def binary_accuracy(y_true, y_pre):
            return np.mean(np.fabs(y_true - y_pre) < 0.5)

        for epoch in range(epochs):
            epoch_start_time = time.time()
            # Loss accumulators: [total, wasserstein, gp/rec, cls].
            d_loss = np.array([0., 0., 0., 0.])
            g_loss = np.array([0., 0., 0., 0.])
            for step in range(steps_per_epoch):
                train_img, train_label = sess.run(x_train_next)
                # Target labels are the flipped binary attributes.
                train_target_label = 1 - train_label

                # Train discriminator
                d_loss += self.discriminator_training_model.train_on_batch(
                    [train_img, train_label, train_target_label])

                # Train Generator
                if (step + 1) % self.n_critic == 0:
                    g_loss += self.generator_training_model.train_on_batch(
                        [train_img, train_label, train_target_label])

                # print log...
                if (step + 1) % steps_4_log_and_lrupdate == 0:
                    et = time.time() - epoch_start_time
                    eta = et * (steps_per_epoch - step - 1) / (step + 1)
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    eta = str(datetime.timedelta(seconds=eta))[:-7]
                    '''
                    stat latest itrs_4_log_and_lrupdate itrs of losses, to make the data not too old and not too variant
                    '''
                    # Windowed averages: the generator only trains every
                    # n_critic steps, hence the different divisor.
                    d_loss /= steps_4_log_and_lrupdate
                    g_loss /= (steps_4_log_and_lrupdate / self.n_critic)
                    log = "{}/{}   - Elapsed: {}, ETA: {}  - d_loss: {:.4f} , w {:.4f} , gp {:.4f} , cls {:.4f}   , g_loss: {:.4f} , w {:.4f} , rec {:.4f} , cls {:.4f}"\
                        .format(step+1, steps_per_epoch, et, eta, d_loss[0], d_loss[1], d_loss[2], d_loss[3], g_loss[0], g_loss[1], g_loss[2], g_loss[3])
                    print(log)
                    # Reset the window for the next logging interval.
                    d_loss.fill(0.)
                    g_loss.fill(0.)

                    # update learning rate
                    if (epoch + 1) > epochs_lr_start_decay:
                        self.update_lr(self.discriminator_training_model,
                                       self.lr_decay_value_d)
                        self.update_lr(self.generator_training_model,
                                       self.lr_decay_value_g)

            # validate per epoch
            '''
            g: img_acc, gen_img_cls_acc, div_real_and_fake
            d: div_real_and_fake, real_cls_acc
            higher div_real_and_fake means better discriminator but worse generator
            use generator and discriminator, compose to get above metrics
            '''

            d_val_loss = np.array([0., 0., 0., 0.])
            g_val_loss = np.array([0., 0., 0., 0.])

            img_acc = 0
            gen_cls_acc = 0
            real_cls_acc = 0
            for step in range(validation_steps):
                val_img, val_label = sess.run(x_val_next)
                val_target_label = 1 - val_label
                d_val_loss += self.discriminator_training_model.test_on_batch(
                    [val_img, val_label, val_target_label])
                g_val_loss += self.generator_training_model.test_on_batch(
                    [val_img, val_label, val_target_label])

                # Translate to the flipped labels to score the classifier on
                # generated images...
                rec_img = self.generator.predict_on_batch(
                    [val_img, val_target_label])
                _, _, r_cls = self.discriminator.predict_on_batch(val_img)
                _, _, g_cls = self.discriminator.predict_on_batch(rec_img)
                # ...then reconstruct with the ORIGINAL labels to measure
                # pixel-level reconstruction accuracy.
                rec_img = self.generator.predict_on_batch([val_img, val_label])

                real_cls_acc += binary_accuracy(val_label, r_cls.flatten())
                gen_cls_acc += binary_accuracy(val_target_label,
                                               g_cls.flatten())
                # Mean absolute pixel error mapped to an accuracy in [0, 1]
                # (pixels assumed in [-1, 1], hence the /2).
                img_acc += (1 - np.mean(
                    np.fabs(val_img.flatten() - rec_img.flatten())) / 2)

            d_val_loss /= validation_steps
            g_val_loss /= validation_steps

            img_acc /= validation_steps
            gen_cls_acc /= validation_steps
            div_real_and_fake = -d_val_loss[
                1]  # it's wasserstein loss of d indeed, but turn to positive in order to be easier understanding
            real_cls_acc /= validation_steps

            # NOTE(review): 'ephoch' is a typo in this runtime log string;
            # left unchanged here since this is a doc-only pass.
            log = 'ephoch {}  validation:  - d_loss: {:.4f} , w {:.4f} , gp {:.4f} , cls {:.4f}   , '\
                        'g_loss: {:.4f} , w {:.4f} , rec {:.4f} , cls {:.4f}  -  '\
                        'img_acc: {:.4f},  gen_cls_acc: {:.4f}, real_cls_acc: {:.4f}, div_real_and_fake: {:.4f}'\
                        .format(epoch+1, d_val_loss[0], d_val_loss[1], d_val_loss[2], d_val_loss[3],
                                g_val_loss[0], g_val_loss[1], g_val_loss[2], g_val_loss[3],
                                img_acc, gen_cls_acc, real_cls_acc, div_real_and_fake)
            print(log)

            # save model per epoch (only once LR decay has started, to skip
            # early low-quality snapshots)
            if epoch > epochs_lr_start_decay:
                save_model(
                    self.generator,
                    'face_generator_epoch{:02d}-acc{:.4f}-g_cls{:.4f}-r_cls{:.4f}'
                    .format(epoch + 1, img_acc, gen_cls_acc, real_cls_acc))
                save_model(
                    self.discriminator,
                    'face_discriminator_epoch{:02d}-acc{:.4f}-g_cls{:.4f}-r_cls{:.4f}'
                    .format(epoch + 1, img_acc, gen_cls_acc, real_cls_acc))

            # test the model at a few fixed checkpoints and at the end
            if epoch == 5 or epoch == 15 or epoch == (epochs - 1):
                self.test_gan(epoch)
Exemplo n.º 8
0
def main():
    """Train the CelebA attribute classifier, optionally with an adversarial
    head that predicts gender from the representation (fairness debiasing).

    Per epoch: train on the train split (primary loss minus lambd * the
    adversarial loss for the debiased model), evaluate on the valid split,
    log accuracy plus equality/parity-gap fairness metrics, and save
    'last'/'best' checkpoints plus periodic backups.

    Reads all configuration from the module-level `opt` namespace.
    """
    # pdb.set_trace()
    # Model Hyperparams
    random.seed(opt.random_seed)
    baseline = opt.baseline
    hidden_size = opt.hidden_size
    lambd = opt.lambd
    learning_rate = opt.learning_rate
    adv_learning_rate = opt.adv_learning_rate
    save_after_x_epochs = 10
    # Number of predicted attributes; presumably the 40 CelebA attributes
    # minus the protected one — confirm against the dataset code.
    num_classes = 39

    # Determine device
    device = getDevice(opt.gpu_id)

    # Create data loaders
    data_loaders = load_celeba(
        splits=['train', 'valid'],
        batch_size=opt.batch_size,
        subset_percentage=opt.subset_percentage,
        protected_percentage=opt.protected_percentage,
        balance_protected=opt.balance_protected)
    train_data_loader = data_loaders['train']
    dev_data_loader = data_loaders['valid']

    # Load checkpoint; architecture hyperparameters come from the checkpoint
    # so the rebuilt model matches the saved weights.
    checkpoint = None
    if opt.weights != '':
        checkpoint = torch.load(opt.weights, map_location=device)
        baseline = checkpoint['baseline']
        hidden_size = checkpoint['hyp']['hidden_size']

    # Create model
    if baseline:
        model = BaselineModel(hidden_size)
    else:
        model = OurModel(hidden_size)

    # Convert device
    model = model.to(device)

    # Loss criterion
    criterion = nn.BCEWithLogitsLoss()  # For multi-label classification
    if not baseline:
        adversarial_criterion = nn.BCEWithLogitsLoss()

    # Create optimizers: the primary optimizer covers encoder + classifier;
    # the adversarial head gets its own optimizer.
    primary_optimizer_params = list(model.encoder.parameters()) + list(
        model.classifier.parameters())
    primary_optimizer = torch.optim.Adam(primary_optimizer_params,
                                         lr=learning_rate)
    if not baseline:
        adversarial_optimizer_params = list(model.adv_head.parameters())
        adversarial_optimizer = torch.optim.Adam(adversarial_optimizer_params,
                                                 lr=adv_learning_rate)

    start_epoch = 0
    best_acc = 0.0
    save_best = False

    train_batch_count = len(train_data_loader)
    dev_batch_count = len(dev_data_loader)

    if checkpoint is not None:
        # Load model weights
        model.load_state_dict(checkpoint['model'])

        # Load metadata to resume training.
        # Fix: compare against None rather than truthiness — the previous
        # checks silently skipped valid falsy values (epoch 0, best_acc 0.0,
        # lambd 0).
        if opt.resume:
            if checkpoint['epoch'] is not None:
                start_epoch = checkpoint['epoch'] + 1
            if checkpoint['best_acc'] is not None:
                best_acc = checkpoint['best_acc']
            if checkpoint['hyp']['lambd'] is not None:
                lambd = checkpoint['hyp']['lambd']
            if checkpoint['optimizers']['primary'] is not None:
                primary_optimizer.load_state_dict(
                    checkpoint['optimizers']['primary'])
            # 'adversarial' is saved as None for baseline checkpoints.
            if checkpoint['optimizers']['adversarial'] is not None:
                adversarial_optimizer.load_state_dict(
                    checkpoint['optimizers']['adversarial'])

    # Train loop
    # pdb.set_trace()
    adversarial_loss = None
    for epoch in range(start_epoch, opt.num_epochs):

        # Set model to train mode
        model.train()

        # Initialize meters and confusion matrices
        mean_accuracy = AverageMeter(device=device)
        cm_m = None
        cm_f = None

        with tqdm(enumerate(train_data_loader),
                  total=train_batch_count) as pbar:  # progress bar
            for i, (images, targets, genders, protected_labels) in pbar:

                # Shape: torch.Size([batch_size, 3, crop_size, crop_size])
                images = Variable(images.to(device))

                # Shape: torch.Size([batch_size, 39])
                targets = Variable(targets.to(device))

                # Shape: torch.Size([batch_size])
                genders = Variable(genders.to(device))

                # Shape: torch.Size([batch_size])
                protected_labels = Variable(
                    protected_labels.type(torch.BoolTensor).to(device))

                # Forward pass
                if baseline:
                    outputs, (a, a_detached) = model(images)
                else:
                    outputs, (a, a_detached) = model(images, protected_labels)
                targets = targets.type_as(outputs)
                genders = genders.type_as(outputs)

                # Zero out buffers
                # model.zero_grad() # either model or optimizer.zero_grad() is fine
                primary_optimizer.zero_grad()

                # CrossEntropyLoss is expecting:
                # Input:  (N, C) where C = number of classes
                classification_loss = criterion(outputs, targets)

                if baseline:
                    loss = classification_loss
                else:
                    # Fix: `a is not None` instead of `a != None` (PEP 8;
                    # also avoids relying on tensor-vs-None `__ne__`
                    # semantics when `a` is a Tensor).
                    if a is not None:
                        # Primary update: encourage the representation to be
                        # uninformative about gender (subtract the
                        # adversarial loss).
                        adversarial_loss = adversarial_criterion(
                            a, genders[protected_labels])
                        loss = classification_loss - lambd * adversarial_loss

                        # Backward pass (Primary)
                        loss.backward()
                        primary_optimizer.step()

                        # Zero out buffers
                        adversarial_optimizer.zero_grad()

                        # Calculate loss for adversarial head on the detached
                        # representation so only the head is updated.
                        adversarial_loss = adversarial_criterion(
                            a_detached, genders[protected_labels])

                        # Backward pass (Adversarial)
                        adversarial_loss.backward()
                        adversarial_optimizer.step()
                    else:
                        # No protected samples in this batch: plain
                        # classification update.
                        loss = classification_loss

                        # Backward pass (Primary)
                        loss.backward()
                        primary_optimizer.step()

                # Convert genders: (batch_size, 1) -> (batch_size,)
                genders = genders.view(-1).bool()

                # Calculate accuracy
                train_acc, _ = calculateAccuracy(outputs, targets)

                # Calculate confusion matrices; accumulate element-wise
                # across batches.
                batch_cm_m, batch_cm_f = calculateGenderConfusionMatrices(
                    outputs, targets, genders)
                if cm_m is None and cm_f is None:
                    cm_m = batch_cm_m
                    cm_f = batch_cm_f
                else:
                    cm_m = tuple(m + bm for m, bm in zip(cm_m, batch_cm_m))
                    cm_f = tuple(f + bf for f, bf in zip(cm_f, batch_cm_f))

                # Update averages
                mean_accuracy.update(train_acc, images.size(0))

                if baseline:
                    s_train = ('%10s Loss: %.4f, Accuracy: %.4f') % (
                        '%g/%g' % (epoch, opt.num_epochs - 1), loss.item(),
                        mean_accuracy.avg)
                else:
                    # Fix: `is None` instead of `== None`.
                    if adversarial_loss is None:
                        s_train = (
                            '%10s Classification Loss: %.4f, Total Loss: %.4f, Accuracy: %.4f'
                        ) % ('%g/%g' % (epoch, opt.num_epochs - 1),
                             classification_loss.item(), loss.item(),
                             mean_accuracy.avg)
                    else:
                        s_train = (
                            '%10s Classification Loss: %.4f, Adversarial Loss: %.4f, Total Loss: %.4f, Accuracy: %.4f'
                        ) % ('%g/%g' % (epoch, opt.num_epochs - 1),
                             classification_loss.item(),
                             adversarial_loss.item(), loss.item(),
                             mean_accuracy.avg)

                # Calculate fairness metrics on final batch
                if i == train_batch_count - 1:
                    avg_equality_gap_0, avg_equality_gap_1, _, _ = calculateEqualityGap(
                        cm_m, cm_f)
                    avg_parity_gap, _ = calculateParityGap(cm_m, cm_f)
                    s_train += (
                        ', Equality Gap 0: %.4f, Equality Gap 1: %.4f, Parity Gap: %.4f'
                    ) % (avg_equality_gap_0, avg_equality_gap_1,
                         avg_parity_gap)

                pbar.set_description(s_train)

        # end batch ------------------------------------------------------------------------------------------------

        # Evaluate
        # pdb.set_trace()
        model.eval()

        # Initialize meters, confusion matrices, and metrics
        mean_accuracy = AverageMeter()
        attr_accuracy = AverageMeter((1, num_classes), device=device)
        cm_m = None
        cm_f = None
        attr_equality_gap_0 = None
        attr_equality_gap_1 = None
        attr_parity_gap = None

        with tqdm(enumerate(dev_data_loader), total=dev_batch_count) as pbar:
            for i, (images, targets, genders, protected_labels) in pbar:
                images = Variable(images.to(device))
                targets = Variable(targets.to(device))
                genders = Variable(genders.to(device))

                with torch.no_grad():
                    # Forward pass
                    outputs = model.sample(images)
                    targets = targets.type_as(outputs)

                    # Convert genders: (batch_size, 1) -> (batch_size,)
                    genders = genders.type_as(outputs).view(-1).bool()

                    # Calculate accuracy
                    eval_acc, eval_attr_acc = calculateAccuracy(
                        outputs, targets)

                    # Calculate confusion matrices; accumulate element-wise
                    # across batches.
                    batch_cm_m, batch_cm_f = calculateGenderConfusionMatrices(
                        outputs, targets, genders)
                    if cm_m is None and cm_f is None:
                        cm_m = batch_cm_m
                        cm_f = batch_cm_f
                    else:
                        cm_m = tuple(
                            m + bm for m, bm in zip(cm_m, batch_cm_m))
                        cm_f = tuple(
                            f + bf for f, bf in zip(cm_f, batch_cm_f))

                    # Update averages
                    mean_accuracy.update(eval_acc, images.size(0))
                    attr_accuracy.update(eval_attr_acc, images.size(0))

                    s_eval = ('%10s Accuracy: %.4f') % (
                        '%g/%g' %
                        (epoch, opt.num_epochs - 1), mean_accuracy.avg)

                    # Calculate fairness metrics on final batch
                    if i == dev_batch_count - 1:
                        avg_equality_gap_0, avg_equality_gap_1, attr_equality_gap_0, attr_equality_gap_1 = \
                            calculateEqualityGap(cm_m, cm_f)
                        avg_parity_gap, attr_parity_gap = calculateParityGap(
                            cm_m, cm_f)
                        s_eval += (
                            ', Equality Gap 0: %.4f, Equality Gap 1: %.4f, Parity Gap: %.4f'
                        ) % (avg_equality_gap_0, avg_equality_gap_1,
                             avg_parity_gap)

                    pbar.set_description(s_eval)

        # Create output dirs
        for dir in [opt.log_dir, opt.weights_dir]:
            if not os.path.exists(dir):
                os.makedirs(dir)
            subdir = os.path.join(dir, opt.out_dir)
            if not os.path.exists(subdir):
                os.makedirs(subdir)
        log_dir = os.path.join(opt.log_dir, opt.out_dir)
        weights_dir = os.path.join(opt.weights_dir, opt.out_dir)

        # Log results
        with open(os.path.join(log_dir, opt.log), 'a+') as f:
            f.write('{}\n'.format(s_train))
            f.write('{}\n'.format(s_eval))
        save_attr_metrics(
            attr_accuracy.avg, attr_equality_gap_0, attr_equality_gap_1,
            attr_parity_gap,
            os.path.join(log_dir, opt.attr_metrics + '_' + str(epoch)))

        # Check against best accuracy
        mean_eval_acc = mean_accuracy.avg.cpu().item()
        if mean_eval_acc > best_acc:
            best_acc = mean_eval_acc
            save_best = True

        # Create checkpoint
        checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizers': {
                'primary':
                primary_optimizer.state_dict(),
                'adversarial':
                adversarial_optimizer.state_dict() if not baseline else None,
            },
            'best_acc': best_acc,
            'baseline': baseline,
            'hyp': {
                'hidden_size': hidden_size,
                'lambd': lambd
            }
        }

        # Save last checkpoint
        torch.save(checkpoint, os.path.join(weights_dir, 'last.pkl'))

        # Save best checkpoint
        if save_best:
            torch.save(checkpoint, os.path.join(weights_dir, 'best.pkl'))
            save_best = False

        # Save backup every 10 epochs (optional)
        if (epoch + 1) % save_after_x_epochs == 0:
            # Save our models
            print('!!! saving models at epoch: ' + str(epoch))
            torch.save(
                checkpoint,
                os.path.join(weights_dir,
                             'checkpoint-%d-%d.pkl' % (epoch + 1, 1)))

        # Delete checkpoint
        del checkpoint

    print('Done!')