def train_model(cnn, criterion, optimizer, num_epochs=100):
    loss_history = []
    train_acc_history = []
    val_acc_history = []
    epoch_history = []
    best_val_acc = 0.0
    p = 0.5
    learning_rate = [0.001]
    for lr in learning_rate:
        lr_msg = 'Learning Rate for this model: {}'.format(lr)
        print(lr_msg)
        epoch_history.append(lr_msg)
        p_msg = 'Dropout p for this model: {}'.format(p)
        print(p_msg)
        epoch_history.append(p_msg)
        cnn = CNN(p)
        if use_gpu:
            cnn.cuda()
        for epoch in range(num_epochs):
            optimizer = torch.optim.SGD(cnn.parameters(), lr=lr, momentum=0.9)

            print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
            #             print('Learning Rate for this epoch: {}'.format(learning_rate))
            print('Learning Rate for this epoch: {}'.format(lr))

            i = 0
            for gender_labels, race_labels, img_names, images in train_loader:
                gender_labels = torch.from_numpy(np.asarray(gender_labels))
                race_labels = torch.from_numpy(np.asarray(race_labels))
                images = Variable(images)
                labels = Variable(gender_labels)
                race_labels = Variable(race_labels)
                if use_gpu:
                    images, labels = images.cuda(), labels.cuda()
                    race_labels = race_labels.cuda()

                pred_labels, _, _ = cnn(images)

                loss = criterion(pred_labels, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if (i + 1) % 5 == 0:
                    print('Epoch [%d/%d], Iter [%d/%d] CNN Loss: %.4f'
                          % (epoch + 1, num_epochs, i + 1,
                             len(train_data) // 50, loss.data))
                i = i + 1

            # Evaluate and checkpoint on the first, every 5th, and the last epoch
            if epoch == 0 or epoch % 5 == 0 or epoch == num_epochs - 1:
                train_acc = check_acc(cnn, train_loader)
                train_acc_history.append(train_acc)
                train_msg = 'Train accuracy for epoch {}: {}'.format(
                    epoch + 1, train_acc)
                print(train_msg)
                epoch_history.append(train_msg)

                val_acc = check_acc(cnn, test_loader)
                val_acc_history.append(val_acc)
                val_msg = 'Validation accuracy for epoch {}: {}'.format(
                    epoch + 1, val_acc)
                print(val_msg)
                epoch_history.append(val_msg)

                is_best = val_acc > best_val_acc
                best_val_acc = max(val_acc, best_val_acc)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': cnn.state_dict(),
                        'best_val_acc': best_val_acc,
                        'optimizer': optimizer.state_dict()
                    }, is_best)

            np.savetxt("training_log_cnn.out", epoch_history, fmt='%s')

test_transform = transforms.Compose([
    transforms.CenterCrop(227),
    transforms.ToTensor()
])
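
# gender_race_dataset is not defined in this excerpt. A rough sketch of what it is
# assumed to look like, based on how it is used here (each item yields a gender label,
# a race label, the image filename, and the transformed image); the CSV column names
# below are hypothetical:
import os
import pandas as pd
from PIL import Image

class gender_race_dataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        row = self.labels.iloc[idx]
        img_name = row['file']                    # hypothetical column name
        image = Image.open(os.path.join(self.root_dir, img_name)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return row['gender'], row['race'], img_name, image   # hypothetical column names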

print('Loading images...')
batch_size = 50
root = 'UTKFace/val'

val_data = gender_race_dataset("val_labels_all.csv", root, test_transform)
val_loader = torch.utils.data.DataLoader(val_data,
                                         batch_size=batch_size,
                                         shuffle=False)

if model == "cnn":
    print("Using CNN")

    cnn = CNN()
    if use_gpu:
        cnn.cuda()
    optimizer = torch.optim.SGD(cnn.parameters(), lr=0.001, momentum=0.9)

    SAVED_MODEL_PATH = 'cnn_model_best_SGD_adversary.pth.tar'
    checkpoint = torch.load(SAVED_MODEL_PATH)
    cnn.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    best_val_acc = checkpoint['best_val_acc']
    
    print("best model saved from epoch: ", epoch)

    print("best val_acc = ", best_val_acc)
batch_size = 50
validation_split = .2
shuffle_dataset = True
random_seed = 42
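
# validation_split, shuffle_dataset, and random_seed are set above, but the loaders
# below use pre-split datasets directly. For reference, one common way to apply these
# settings is an index split with SubsetRandomSampler (a sketch, not necessarily how
# train_data / val_data / test_data were built here):
# dataset_size = len(train_data)
# indices = list(range(dataset_size))
# split = int(np.floor(validation_split * dataset_size))
# if shuffle_dataset:
#     np.random.seed(random_seed)
#     np.random.shuffle(indices)
# train_sampler = torch.utils.data.SubsetRandomSampler(indices[split:])
# val_sampler = torch.utils.data.SubsetRandomSampler(indices[:split])
# train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
# val_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=val_sampler)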

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data,
                                         batch_size=batch_size,
                                         shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=batch_size,
                                          shuffle=False)

cnn = CNN()
if use_gpu:
    print('Using GPU for cnn')
    cnn.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(cnn.parameters(), lr=0.001, momentum=0.9)
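
# The NN adversary instantiated as NN(512) in the adversarial trainer below is not
# defined in this excerpt. A minimal sketch, assuming a small MLP that maps the CNN's
# 512-dimensional penultimate features to race logits (the hidden size and number of
# race classes are assumptions):
class NN(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_races=5):
        super(NN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_races))

    def forward(self, x):
        return self.net(x)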


def train_model(cnn, adversary, criterion, nn_criterion, optimizer, nn_optimizer, num_epochs = 100):
    loss_history = []
    train_acc_history = []
    val_acc_history = []
    epoch_history = []
#     learning_rate = 0.001
#     learning_rate = np.logspace(-6,-2, num=15)
#     learning_rate = np.logspace(-3,-2, num=8)
    best_val_acc = 0.0
#     alpha = 1.0
#     p_vals = np.array([0.5])
    p = 0.5
    learning_rate = [0.0026826957952797246] #0.001 for Adam
    alpha = [0.9]
#     alpha = np.array([0.9, 1.0])

    for layer in range(1,2):
        for lr in learning_rate:
            for a in alpha:
                layer_msg = 'Layer for this model: {}'.format(layer)
                print(layer_msg)
                epoch_history.append(layer_msg)
                lr_msg = 'Learning Rate for this model: {}'.format(lr)
                print(lr_msg)
                epoch_history.append(lr_msg)
                p_msg = 'Dropout p for this model: {}'.format(p)
                print(p_msg)
                epoch_history.append(p_msg)
                a_msg = 'Alpha for this model: {}'.format(a)
                print(a_msg)
                epoch_history.append(a_msg)
                cnn = CNN(p)
                adversary = NN(512)
                if use_gpu:
                    cnn.cuda()
                    adversary.cuda()

                for epoch in range(num_epochs):
                    # Note: re-creating the optimizers each epoch resets their momentum buffers
                    optimizer = torch.optim.SGD(cnn.parameters(), lr=lr, momentum=0.9)
                    nn_optimizer = torch.optim.SGD(adversary.parameters(), lr=lr, momentum=0.9)
#                     optimizer = torch.optim.Adam(cnn.parameters(), lr=lr, betas=(0.9, 0.999))
#                     nn_optimizer = torch.optim.Adam(adversary.parameters(), lr=lr, betas=(0.9, 0.999))

                    print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
        #             print('Learning Rate for this epoch: {}'.format(learning_rate))
#                     print('Learning Rate for this epoch: {}'.format(lr))

                    i = 0
                    for gender_labels, race_labels, img_names, images in train_loader:
                        gender_labels = torch.from_numpy(np.asarray(gender_labels))
                        race_labels = torch.from_numpy(np.asarray(race_labels))
                        images = Variable(images)
                        labels = Variable(gender_labels)
                        race_labels = Variable(race_labels)
                        if use_gpu:
                            images, labels = images.cuda(), labels.cuda()
                            race_labels = race_labels.cuda()

                        pred_labels, penultimate_weights, layer_2 = cnn(images)
                        if layer == 1:
                            nn_pred_labels = adversary(penultimate_weights)
                        elif layer == 2:
                            nn_pred_labels = adversary(layer_2)

                        # Adversary loss: how well race can be predicted from the CNN's features
                        nn_loss = nn_criterion(nn_pred_labels, race_labels)

                        cnn_loss = criterion(pred_labels, labels)

                        # Loss for the gender prediction model: minimize the gender loss while
                        # maximizing the adversary's race loss, weighted by alpha
#                         loss = cnn_loss - alpha*nn_loss
                        loss = cnn_loss - a * nn_loss
                        optimizer.zero_grad()
                        # Retain the graph so nn_loss.backward() below can reuse it
                        loss.backward(retain_graph=True)
                        optimizer.step()

                        # Loss for the adversary model: it only minimizes its own race loss
                        nn_optimizer.zero_grad()
                        nn_loss.backward()
                        nn_optimizer.step()

                        if (i + 1) % 5 == 0:
                            print('Epoch [%d/%d], Iter [%d/%d] CNN Loss: %.4f Adversary Loss: %.4f'
                                  % (epoch + 1, num_epochs, i + 1, len(train_data) // 50,
                                     loss.data, nn_loss.data))
                        i = i + 1

            #         if epoch % 10 == 0:
            #             learning_rate = learning_rate * 0.9

        #             if epoch % 5 ==0 or epoch == num_epochs-1:
                    train_acc = check_acc(cnn, train_loader)
                    train_acc_history.append(train_acc)
                    train_msg = 'Train accuracy for epoch {}: {}'.format(epoch + 1, train_acc)
                    print(train_msg)
                    epoch_history.append(train_msg)

                    val_acc = check_acc(cnn, test_loader)
                    val_acc_history.append(val_acc)
                    val_msg = 'Validation accuracy for epoch {}: {}'.format(epoch + 1, val_acc)
                    print(val_msg)
                    epoch_history.append(val_msg)
                    # plot_performance_curves(train_acc_history, val_acc_history, epoch_history)

                    is_best = val_acc > best_val_acc
                    best_val_acc = max(val_acc, best_val_acc)
                    save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': cnn.state_dict(),
                            'best_val_acc': best_val_acc,
                            'optimizer': optimizer.state_dict()
                        }, is_best)

                    np.savetxt("training_log_cnn_SGD.out", epoch_history, fmt='%s')