def main():
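    """Train a face recognition model with a joint cross entropy + center loss
    objective, periodically evaluate face verification on the LFW benchmark, and
    save a model checkpoint after every epoch.
    """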
    dataroot = args.dataroot
    lfw_dataroot = args.lfw
    lfw_batch_size = args.lfw_batch_size
    lfw_validation_epoch_interval = args.lfw_validation_epoch_interval
    model_architecture = args.model
    epochs = args.epochs
    resume_path = args.resume_path
    batch_size = args.batch_size
    num_workers = args.num_workers
    validation_dataset_split_ratio = args.valid_split
    embedding_dimension = args.embedding_dim
    pretrained = args.pretrained
    optimizer = args.optimizer
    learning_rate = args.lr
    learning_rate_center_loss = args.center_loss_lr
    center_loss_weight = args.center_loss_weight
    start_epoch = 0

    # Define image data pre-processing transforms
    #   ToTensor() normalizes pixel values between [0, 1]
    #   Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) normalizes pixel values between [-1, 1]

    #  Size 182x182 RGB image -> Random crop to a 160x160 RGB image for better model generalization
    data_transforms = transforms.Compose([
        transforms.RandomCrop(size=160),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Size 160x160 RGB image
    lfw_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Load the dataset
    dataset = torchvision.datasets.ImageFolder(root=dataroot,
                                               transform=data_transforms)

    # Subset the dataset into training and validation datasets
    num_classes = len(dataset.classes)
    print("\nNumber of classes in dataset: {}".format(num_classes))
    num_samples = len(dataset)
    num_validation = int(num_samples * validation_dataset_split_ratio)
    num_train = num_samples - num_validation
    indices = list(range(num_samples))
    np.random.seed(420)
    np.random.shuffle(indices)
    train_indices = indices[:num_train]
    validation_indices = indices[num_train:]

    train_dataset = Subset(dataset=dataset, indices=train_indices)
    validation_dataset = Subset(dataset=dataset, indices=validation_indices)

    print("Number of samples in training dataset: {}".format(
        len(train_dataset)))
    print("Number of samples in validation dataset: {}".format(
        len(validation_dataset)))

    # Define the dataloaders
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=False)

    validation_dataloader = DataLoader(dataset=validation_dataset,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       shuffle=False)

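    # The LFW loader yields (image_a, image_b, same_identity_label) pairs for the
    # face verification benchmark defined by datasets/LFW_pairs.txt.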
    lfw_dataloader = torch.utils.data.DataLoader(dataset=LFWDataset(
        dir=lfw_dataroot,
        pairs_path='datasets/LFW_pairs.txt',
        transform=lfw_transforms),
                                                 batch_size=lfw_batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=False)

    # Instantiate model
    if model_architecture == "resnet18":
        model = Resnet18Center(num_classes=num_classes,
                               embedding_dimension=embedding_dimension,
                               pretrained=pretrained)
    elif model_architecture == "resnet34":
        model = Resnet34Center(num_classes=num_classes,
                               embedding_dimension=embedding_dimension,
                               pretrained=pretrained)
    elif model_architecture == "resnet50":
        model = Resnet50Center(num_classes=num_classes,
                               embedding_dimension=embedding_dimension,
                               pretrained=pretrained)
    elif model_architecture == "resnet101":
        model = Resnet101Center(num_classes=num_classes,
                                embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "inceptionresnetv2":
        model = InceptionResnetV2Center(
            num_classes=num_classes,
            embedding_dimension=embedding_dimension,
            pretrained=pretrained)
    print("\nUsing {} model architecture.".format(model_architecture))

    # Load model to GPU or multiple GPUs if available
    flag_train_gpu = torch.cuda.is_available()
    flag_train_multi_gpu = False

    if flag_train_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model.cuda()
        flag_train_multi_gpu = True
        print('Using multi-gpu training.')
    elif flag_train_gpu and torch.cuda.device_count() == 1:
        model.cuda()
        print('Using single-gpu training.')

    # Set loss functions
    criterion_crossentropy = nn.CrossEntropyLoss().cuda()
    criterion_centerloss = CenterLoss(num_classes=num_classes,
                                      feat_dim=embedding_dimension).cuda()
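    # CenterLoss keeps one learnable center vector per class as its own parameters,
    # which is why it gets a separate optimizer below.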

    # Set optimizers
    if optimizer == "sgd":
        optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate)
        optimizer_centerloss = torch.optim.SGD(
            criterion_centerloss.parameters(), lr=learning_rate_center_loss)

    elif optimizer == "adagrad":
        optimizer_model = torch.optim.Adagrad(model.parameters(),
                                              lr=learning_rate)
        optimizer_centerloss = torch.optim.Adagrad(
            criterion_centerloss.parameters(), lr=learning_rate_center_loss)

    elif optimizer == "rmsprop":
        optimizer_model = torch.optim.RMSprop(model.parameters(),
                                              lr=learning_rate)
        optimizer_centerloss = torch.optim.RMSprop(
            criterion_centerloss.parameters(), lr=learning_rate_center_loss)

    elif optimizer == "adam":
        optimizer_model = torch.optim.Adam(model.parameters(),
                                           lr=learning_rate)
        optimizer_centerloss = torch.optim.Adam(
            criterion_centerloss.parameters(), lr=learning_rate_center_loss)

    # Set learning rate decay scheduler
    learning_rate_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer_model, milestones=[150, 225], gamma=0.1)
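    # MultiStepLR multiplies the model learning rate by gamma=0.1 once epoch 150
    # and epoch 225 are reached.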

    # Optionally resume from a checkpoint
    if resume_path:

        if os.path.isfile(resume_path):
            print("\nLoading checkpoint {} ...".format(resume_path))

            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']

            # In order to load state dict for optimizers correctly, model has to be loaded to gpu first
            if flag_train_multi_gpu:
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])

            optimizer_model.load_state_dict(
                checkpoint['optimizer_model_state_dict'])
            optimizer_centerloss.load_state_dict(
                checkpoint['optimizer_centerloss_state_dict'])
            learning_rate_scheduler.load_state_dict(
                checkpoint['learning_rate_scheduler_state_dict'])

            print(
                "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n"
                .format(start_epoch, epochs - start_epoch))
        else:
            print(
                "WARNING: No checkpoint found at {}!\nTraining from scratch.".
                format(resume_path))

    # Start Training loop
    print(
        "\nTraining using cross entropy loss with center loss starting for {} epochs:\n"
        .format(epochs - start_epoch))

    total_time_start = time.time()
    end_epoch = epochs

    for epoch in range(start_epoch, end_epoch):
        epoch_time_start = time.time()

        flag_validate_lfw = ((epoch + 1) % lfw_validation_epoch_interval == 0
                             or (epoch + 1) == epochs)
        train_loss_sum = 0
        validation_loss_sum = 0

        # Training the model
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_index, (data, labels) in progress_bar:
            data, labels = data.cuda(), labels.cuda()

            # Forward pass
            if flag_train_multi_gpu:
                embedding, logits = model.module.forward_training(data)
            else:
                embedding, logits = model.forward_training(data)

            # Calculate losses
            cross_entropy_loss = criterion_crossentropy(logits, labels)
            center_loss = criterion_centerloss(embedding, labels)
            loss = (center_loss * center_loss_weight) + cross_entropy_loss

            # Backward pass
            optimizer_centerloss.zero_grad()
            optimizer_model.zero_grad()
            loss.backward()
            optimizer_model.step()

            # Remove the impact of center_loss_weight on the learning of the center
            # vectors by rescaling their gradients before stepping the center optimizer
            for param in criterion_centerloss.parameters():
                param.grad.data *= (1. / center_loss_weight)
            optimizer_centerloss.step()

            # Update training loss sum
            train_loss_sum += loss.item() * data.size(0)

        # Step the learning rate scheduler once per epoch, after the optimizer updates
        learning_rate_scheduler.step()

        # Validating the model
        model.eval()
        correct, total = 0, 0

        with torch.no_grad():

            progress_bar = enumerate(tqdm(validation_dataloader))

            for batch_index, (data, labels) in progress_bar:

                data, labels = data.cuda(), labels.cuda()

                # Forward pass
                if flag_train_multi_gpu:
                    embedding, logits = model.module.forward_training(data)
                else:
                    embedding, logits = model.forward_training(data)

                # Calculate losses
                cross_entropy_loss = criterion_crossentropy(logits, labels)
                center_loss = criterion_centerloss(embedding, labels)
                loss = (center_loss * center_loss_weight) + cross_entropy_loss

                # Update validation loss sum
                validation_loss_sum += loss.item() * data.size(0)

                # Calculate validation performance metrics
                predictions = torch.argmax(logits, dim=1)
                total += labels.size(0)
                correct += (predictions == labels).sum()

        # Calculate average losses in epoch
        avg_train_loss = train_loss_sum / len(train_dataloader.dataset)
        avg_validation_loss = validation_loss_sum / len(
            validation_dataloader.dataset)

        # Calculate validation performance statistics in epoch
        classification_accuracy = correct * 100. / total
        classification_error = 100. - classification_accuracy

        epoch_time_end = time.time()

        # Print training and validation statistics and add to log
        print(
            'Epoch {}:\t Average Training Loss: {:.4f}\tAverage Validation Loss: {:.4f}\tClassification Accuracy: {:.2f}%\tClassification Error: {:.2f}%\tEpoch Time: {:.3f} hours'
            .format(epoch + 1, avg_train_loss, avg_validation_loss,
                    classification_accuracy, classification_error,
                    (epoch_time_end - epoch_time_start) / 3600))
        with open('logs/{}_log_center.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_train_loss, avg_validation_loss,
                classification_accuracy.item(),
                classification_error.item()
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        try:
            # Plot cross entropy + center loss curves for the training and validation sets
            plot_training_validation_losses_center(
                log_dir="logs/{}_log_center.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/training_validation_losses_{}_center.png".
                format(model_architecture))
        except Exception as e:
            print(e)

        # Validating on LFW dataset using KFold based on Euclidean distance metric
        if flag_validate_lfw:

            model.eval()
            with torch.no_grad():
                l2_distance = PairwiseDistance(2).cuda()
                distances, labels = [], []
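                # Collect the pairwise Euclidean distances and ground-truth labels over
                # all LFW pairs; evaluate_lfw then derives the verification metrics
                # (accuracy, precision, recall, ROC AUC, TAR@FAR) from K-fold
                # cross-validated distance thresholds.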

                print("Validating on LFW! ...")
                progress_bar = enumerate(tqdm(lfw_dataloader))

                for batch_index, (data_a, data_b, label) in progress_bar:
                    data_a, data_b, label = data_a.cuda(), data_b.cuda(), label.cuda()

                    output_a, output_b = model(data_a), model(data_b)
                    distance = l2_distance.forward(
                        output_a, output_b)  # Euclidean distance

                    distances.append(distance.cpu().detach().numpy())
                    labels.append(label.cpu().detach().numpy())

                labels = np.array(
                    [sublabel for label in labels for sublabel in label])
                distances = np.array([
                    subdist for distance in distances for subdist in distance
                ])

                true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \
                    tar, far = evaluate_lfw(
                        distances=distances,
                        labels=labels
                     )

                # Print statistics and add to log
                print(
                    "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\tROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}"
                    .format(np.mean(accuracy), np.std(accuracy),
                            np.mean(precision), np.std(precision),
                            np.mean(recall), np.std(recall), roc_auc,
                            np.mean(best_distances), np.std(best_distances),
                            np.mean(tar), np.std(tar), np.mean(far)))
                with open(
                        'logs/lfw_{}_log_center.txt'.format(
                            model_architecture), 'a') as f:
                    val_list = [
                        epoch + 1,
                        np.mean(accuracy),
                        np.std(accuracy),
                        np.mean(precision),
                        np.std(precision),
                        np.mean(recall),
                        np.std(recall), roc_auc,
                        np.mean(best_distances),
                        np.std(best_distances),
                        np.mean(tar)
                    ]
                    log = '\t'.join(str(value) for value in val_list)
                    f.writelines(log + '\n')

            try:
                # Plot ROC curve
                plot_roc_lfw(
                    false_positive_rate=false_positive_rate,
                    true_positive_rate=true_positive_rate,
                    figure_name="plots/roc_plots/roc_{}_epoch_{}_center.png".
                    format(model_architecture, epoch + 1))
                # Plot LFW accuracy curve
                plot_accuracy_lfw(
                    log_dir="logs/lfw_{}_log_center.txt".format(
                        model_architecture),
                    epochs=epochs,
                    figure_name="plots/lfw_accuracies_{}_center.png".format(
                        model_architecture))
            except Exception as e:
                print(e)

        # Build the checkpoint state dictionary
        state = {
            'epoch': epoch + 1,
            'num_classes': num_classes,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict(),
            'optimizer_centerloss_state_dict': optimizer_centerloss.state_dict(),
            'learning_rate_scheduler_state_dict': learning_rate_scheduler.state_dict()
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_lfw:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state,
            'Model_training_checkpoints/model_{}_center_epoch_{}.pt'.format(
                model_architecture, epoch + 1))

    # Training loop end
    total_time_end = time.time()
    total_time_elapsed = total_time_end - total_time_start
    print("\nTraining finished: total time elapsed: {:.2f} hours.".format(
        total_time_elapsed / 3600))


# Triplet-loss training script (separate entry point from the center-loss training above)
def main():
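    """Train a face recognition model with triplet loss using online selection of
    hard (margin-violating) triplets, periodically evaluate face verification on
    the LFW benchmark, and save a model checkpoint after every epoch.
    """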
    dataroot = args.dataroot
    lfw_dataroot = args.lfw
    dataset_csv = args.dataset_csv
    lfw_batch_size = args.lfw_batch_size
    lfw_validation_epoch_interval = args.lfw_validation_epoch_interval
    model_architecture = args.model
    epochs = args.epochs
    training_triplets_path = args.training_triplets_path
    num_triplets_train = args.num_triplets_train
    resume_path = args.resume_path
    batch_size = args.batch_size
    num_workers = args.num_workers
    embedding_dimension = args.embedding_dim
    pretrained = args.pretrained
    optimizer = args.optimizer
    learning_rate = args.lr
    margin = args.margin
    start_epoch = 0

    # Define image data pre-processing transforms
    #   ToTensor() normalizes pixel values between [0, 1]
    #   Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) normalizes pixel values between [-1, 1]

    #  Size 182x182 RGB image -> Random crop to a 160x160 RGB image for better model generalization
    data_transforms = transforms.Compose([
        transforms.RandomCrop(size=160),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Size 160x160 RGB image
    lfw_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Set dataloaders
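    # TripletFaceDataset builds (anchor, positive, negative) triplets from the
    # identities listed in dataset_csv; if a precomputed triplet file is supplied via
    # training_triplets_path it is presumably loaded instead of sampling
    # num_triplets_train fresh triplets.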
    train_dataloader = torch.utils.data.DataLoader(dataset=TripletFaceDataset(
        root_dir=dataroot,
        csv_name=dataset_csv,
        num_triplets=num_triplets_train,
        training_triplets_path=training_triplets_path,
        transform=data_transforms),
                                                   batch_size=batch_size,
                                                   num_workers=num_workers,
                                                   shuffle=False)

    lfw_dataloader = torch.utils.data.DataLoader(dataset=LFWDataset(
        dir=lfw_dataroot,
        pairs_path='datasets/LFW_pairs.txt',
        transform=lfw_transforms),
                                                 batch_size=lfw_batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=False)

    # Instantiate model
    if model_architecture == "resnet18":
        model = Resnet18Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet34":
        model = Resnet34Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet50":
        model = Resnet50Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet101":
        model = Resnet101Triplet(embedding_dimension=embedding_dimension,
                                 pretrained=pretrained)
    elif model_architecture == "inceptionresnetv2":
        model = InceptionResnetV2Triplet(
            embedding_dimension=embedding_dimension, pretrained=pretrained)
    print("Using {} model architecture.".format(model_architecture))

    # Load model to GPU or multiple GPUs if available
    flag_train_gpu = torch.cuda.is_available()
    flag_train_multi_gpu = False

    if flag_train_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model.cuda()
        flag_train_multi_gpu = True
        print('Using multi-gpu training.')
    elif flag_train_gpu and torch.cuda.device_count() == 1:
        model.cuda()
        print('Using single-gpu training.')

    # Set optimizers
    if optimizer == "sgd":
        optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate)

    elif optimizer == "adagrad":
        optimizer_model = torch.optim.Adagrad(model.parameters(),
                                              lr=learning_rate)

    elif optimizer == "rmsprop":
        optimizer_model = torch.optim.RMSprop(model.parameters(),
                                              lr=learning_rate)

    elif optimizer == "adam":
        optimizer_model = torch.optim.Adam(model.parameters(),
                                           lr=learning_rate)

    # Optionally resume from a checkpoint
    if resume_path:

        if os.path.isfile(resume_path):
            print("\nLoading checkpoint {} ...".format(resume_path))

            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']

            # In order to load state dict for optimizers correctly, model has to be loaded to gpu first
            if flag_train_multi_gpu:
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])

            optimizer_model.load_state_dict(
                checkpoint['optimizer_model_state_dict'])

            print(
                "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n"
                .format(start_epoch, epochs - start_epoch))
        else:
            print(
                "WARNING: No checkpoint found at {}!\nTraining from scratch.".
                format(resume_path))

    # Start Training loop
    print(
        "\nTraining using triplet loss on {} triplets starting for {} epochs:\n"
        .format(num_triplets_train, epochs - start_epoch))

    total_time_start = time.time()
    end_epoch = epochs
    l2_distance = PairwiseDistance(2).cuda()
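    # PairwiseDistance(2) computes L2 (Euclidean) distances between embedding batches;
    # it is reused for both online hard-triplet selection and LFW evaluation.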

    for epoch in range(start_epoch, end_epoch):
        epoch_time_start = time.time()

        flag_validate_lfw = ((epoch + 1) % lfw_validation_epoch_interval == 0
                             or (epoch + 1) == epochs)
        triplet_loss_sum = 0
        num_valid_training_triplets = 0

        # Training pass
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_idx, batch_sample in progress_bar:
            anc_img = batch_sample['anc_img'].cuda()
            pos_img = batch_sample['pos_img'].cuda()
            neg_img = batch_sample['neg_img'].cuda()

            # Forward pass - compute embeddings
            anc_embedding, pos_embedding, neg_embedding = model(
                anc_img), model(pos_img), model(neg_img)

            # Forward pass - choose hard negatives only for training
            pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
            neg_dist = l2_distance.forward(anc_embedding, neg_embedding)

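            # Only triplets that violate the margin (neg_dist - pos_dist < margin) produce
            # a non-zero triplet loss, so easy triplets are filtered out before the
            # backward pass.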
            hard_mask = (neg_dist - pos_dist < margin).cpu().numpy().flatten()

            hard_triplets = np.where(hard_mask == 1)
            if len(hard_triplets[0]) == 0:
                continue

            anc_hard_embedding = anc_embedding[hard_triplets].cuda()
            pos_hard_embedding = pos_embedding[hard_triplets].cuda()
            neg_hard_embedding = neg_embedding[hard_triplets].cuda()

            # Calculate triplet loss
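            # TripletLoss is assumed here to implement the standard FaceNet objective,
            # roughly mean(max(d(anchor, positive) - d(anchor, negative) + margin, 0)),
            # over the selected hard triplets.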
            triplet_loss = TripletLoss(margin=margin).forward(
                anchor=anc_hard_embedding,
                positive=pos_hard_embedding,
                negative=neg_hard_embedding).cuda()

            # Accumulate the batch loss and count the valid (hard) triplets
            triplet_loss_sum += triplet_loss.item()
            num_valid_training_triplets += len(anc_hard_embedding)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

        # Model only trains on hard negative triplets
        avg_triplet_loss = 0 if (
            num_valid_training_triplets
            == 0) else triplet_loss_sum / num_valid_training_triplets
        epoch_time_end = time.time()

        # Print training statistics and add to log
        print(
            'Epoch {}:\tAverage Triplet Loss: {:.4f}\tEpoch Time: {:.3f} hours\tNumber of valid training triplets in epoch: {}'
            .format(epoch + 1, avg_triplet_loss,
                    (epoch_time_end - epoch_time_start) / 3600,
                    num_valid_training_triplets))
        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_triplet_loss, num_valid_training_triplets
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        try:
            # Plot triplet loss curve
            plot_triplet_losses(
                log_dir="logs/{}_log_triplet.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/triplet_losses_{}.png".format(
                    model_architecture))
        except Exception as e:
            print(e)

        # Evaluation pass on LFW dataset
        if flag_validate_lfw:

            model.eval()
            with torch.no_grad():
                distances, labels = [], []
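                # As in the center-loss script above: collect distances and labels over
                # all LFW pairs and let evaluate_lfw compute the cross-validated
                # verification metrics.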

                print("Validating on LFW! ...")
                progress_bar = enumerate(tqdm(lfw_dataloader))

                for batch_index, (data_a, data_b, label) in progress_bar:
                    data_a, data_b, label = data_a.cuda(), data_b.cuda(), label.cuda()

                    output_a, output_b = model(data_a), model(data_b)
                    distance = l2_distance.forward(
                        output_a, output_b)  # Euclidean distance

                    distances.append(distance.cpu().detach().numpy())
                    labels.append(label.cpu().detach().numpy())

                labels = np.array(
                    [sublabel for label in labels for sublabel in label])
                distances = np.array([
                    subdist for distance in distances for subdist in distance
                ])

                true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \
                    tar, far = evaluate_lfw(
                        distances=distances,
                        labels=labels
                    )

                # Print statistics and add to log
                print(
                    "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\tROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}"
                    .format(np.mean(accuracy), np.std(accuracy),
                            np.mean(precision), np.std(precision),
                            np.mean(recall), np.std(recall), roc_auc,
                            np.mean(best_distances), np.std(best_distances),
                            np.mean(tar), np.std(tar), np.mean(far)))
                with open(
                        'logs/lfw_{}_log_triplet.txt'.format(
                            model_architecture), 'a') as f:
                    val_list = [
                        epoch + 1,
                        np.mean(accuracy),
                        np.std(accuracy),
                        np.mean(precision),
                        np.std(precision),
                        np.mean(recall),
                        np.std(recall), roc_auc,
                        np.mean(best_distances),
                        np.std(best_distances),
                        np.mean(tar)
                    ]
                    log = '\t'.join(str(value) for value in val_list)
                    f.writelines(log + '\n')

            try:
                # Plot ROC curve
                plot_roc_lfw(
                    false_positive_rate=false_positive_rate,
                    true_positive_rate=true_positive_rate,
                    figure_name="plots/roc_plots/roc_{}_epoch_{}_triplet.png".
                    format(model_architecture, epoch + 1))
                # Plot LFW accuracy curve
                plot_accuracy_lfw(
                    log_dir="logs/lfw_{}_log_triplet.txt".format(
                        model_architecture),
                    epochs=epochs,
                    figure_name="plots/lfw_accuracies_{}_triplet.png".format(
                        model_architecture))
            except Exception as e:
                print(e)

        # Build the checkpoint state dictionary
        state = {
            'epoch': epoch + 1,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict()
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_lfw:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state,
            'Model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch + 1))

    # Training loop end
    total_time_end = time.time()
    total_time_elapsed = total_time_end - total_time_start
    print("\nTraining finished: total time elapsed: {:.2f} hours.".format(
        total_time_elapsed / 3600))