Example 1
    def test(self, heatmap=True):
        """
        Test the model on the held-out test data.
        """
        # load the best checkpoint
        self.load_checkpoint(best=True)
        self.model.eval()
        with torch.no_grad():

            losses = []
            accuracies = []

            confusion_matrix = torch.zeros(self.num_classes, self.num_classes)
            heat_points = [[[] for i in range(self.num_glimpses)]
                           for j in range(self.num_digits + 1)]
            heat_mean = torch.zeros(self.num_digits, *self.im_size)

            var_means = torch.zeros(self.num_glimpses, len(self.test_loader))
            loss_means = torch.zeros(self.num_glimpses, len(self.test_loader))
            count_hits = torch.zeros(self.num_glimpses, 2)
            var_dist = torch.zeros(len(self.test_loader.dataset),
                                   self.num_glimpses, self.im_size[-1],
                                   self.im_size[-1])
            err_dist = torch.zeros(len(self.test_loader.dataset),
                                   self.num_glimpses, self.im_size[-1],
                                   self.im_size[-1])
            labels = torch.zeros(len(self.test_loader.dataset))

            dist = PairwiseDistance(2)
            dists = torch.zeros(len(self.test_loader.dataset),
                                self.num_glimpses)

            run_idx = 0
            for i, (_, x, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x = x.cuda()
                    y = y.cuda()

                true = y.detach().argmax(1)
                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                s_t, _, l_t = self.reset()

                targets = torch.zeros(self.num_glimpses, 3, *x.shape[2:])
                glimpses = torch.zeros(self.num_glimpses, 3, *x.shape[2:])
                means = torch.zeros(self.num_glimpses, 3, *x.shape[2:])
                vars = torch.zeros(self.num_glimpses, 3, *x.shape[2:])

                locs = torch.zeros(self.batch_size, self.num_glimpses, 2)
                acc_loss = 0
                sub_dir = self.plot_dir / f"{i:06d}_dir"
                if not sub_dir.exists() and i < 10:
                    sub_dir.mkdir(parents=True)

                for t in range(self.num_glimpses):
                    # forward pass through model
                    locs[:, t] = l_t.clone().detach()
                    s_t, l_pred, r_t, out_t, mu, logvar = self.model(
                        x,
                        l_t,
                        s_t,
                        samples=self.samples,
                        classify=self.classify,
                    )
                    draw = x[0].expand(3, -1, -1)
                    extent = self.patch_size
                    for patch in range(self.num_patches):
                        draw = draw_glimpse(
                            draw,
                            l_t[0],
                            extent=extent,
                        )
                        extent *= self.glimpse_scale

                    targets[t] = draw
                    glimpses[t] = self.model.sensor.glimpse[0][0].expand(
                        3, -1, -1)
                    means[t] = r_t[0][0].expand(3, -1, -1)
                    vars[t] = convert_heatmap(
                        r_t.var(0)[0].squeeze().cpu().numpy())

                    if t == 0:
                        var_first = r_t.var(0).squeeze().cpu().numpy()
                    var_last = r_t.var(0).squeeze().cpu().numpy()

                    loc_prev = loc = denormalize(x.size(-1), l_t)
                    l_t = get_loc_target(r_t.var(0),
                                         max=self.use_amax).type(l_t.type())
                    loc = denormalize(x.size(-1), l_t)
                    dists[run_idx:run_idx + len(x),
                          t] = dist.forward(loc_prev.float(),
                                            loc.float()).detach().cpu()
                    var_dist[run_idx:run_idx + len(x),
                             t] = r_t.var(0).detach().cpu().squeeze()

                    temp_loss = torch.zeros(self.batch_size,
                                            *err_dist.shape[2:])
                    for s in range(self.samples):
                        temp_loss += F.binary_cross_entropy(
                            r_t[s], x, reduction='none'
                        ).detach().cpu().squeeze() / self.samples
                    err_dist[run_idx:run_idx + len(x), t] = temp_loss

                    for k, l in enumerate(loc):
                        if x.squeeze()[k][l[0], l[1]] > 0:
                            count_hits[t, 0] += 1
                        else:
                            count_hits[t, 1] += 1

                    imgs = [targets[t], glimpses[t], means[t], vars[t]]
                    filename = sub_dir / f"{i:06d}_glimpse{t:01d}.png"
                    if i < 10:
                        torchvision.utils.save_image(
                            imgs,
                            filename,
                            nrow=1,
                            normalize=True,
                            pad_value=1,
                            padding=1,
                            scale_each=True,
                        )

                    var = r_t.var(0).reshape(self.batch_size, -1)
                    var_means[t, i] = var.max(-1)[0].mean(0)
                    loss_means[t, i] = temp_loss.mean()

                    for j in range(self.num_digits):
                        heat_points[j][t].extend(
                            denormalize(x.size(-1), l_t[true == j]).tolist())

                    heat_points[-1][t].extend(loc.tolist())

                labels[run_idx:run_idx + len(true)] = true
                run_idx += len(x)
                for s in range(self.samples):
                    loss = F.mse_loss(r_t[s], x) / self.samples
                    acc_loss += loss.item()  # store

                if self.classify:
                    if self.num_classes == 1:
                        pred = out_t.detach().squeeze().round()
                        target = (y.argmax(1) == self.target_class).float()
                    else:
                        pred = out_t.detach().argmax(1)
                        target = true
                        for p, tr in zip(pred.view(-1), target.view(-1)):
                            confusion_matrix[tr.long(), p.long()] += 1

                    accuracies.append(
                        torch.sum(pred == target).float().item() / len(y))

                for j in range(self.num_digits):
                    if len(x[true == j]):
                        heat_mean[j] += x[true == j].mean(0).cpu()

                losses.append(acc_loss)
                imgs = [targets, glimpses, means, vars]

                # min-max normalize each tensor; the list must be rebuilt,
                # since reassigning the loop variable has no effect
                imgs = [(im - im.min()) / (im.max() - im.min())
                        for im in imgs]
                # store images + reconstructions of largest scale
                filename = self.plot_dir / f"{i:06d}.png"
                if i < 10:
                    torchvision.utils.save_image(
                        torch.cat(imgs, 0),
                        filename,
                        nrow=6,
                        normalize=False,
                        pad_value=1,
                        padding=3,
                        scale_each=False,
                    )

            pkl.dump(dists, open(self.file_dir / 'dists.pkl', 'wb'))
            pkl.dump(var_dist, open(self.file_dir / 'var_dist.pkl', 'wb'))
            pkl.dump(err_dist, open(self.file_dir / 'err_dist.pkl', 'wb'))
            pkl.dump(heat_points, open(self.file_dir / 'locs.pkl', 'wb'))
            pkl.dump(labels, open(self.file_dir / 'labels.pkl', 'wb'))
            pkl.dump(heat_mean, open(self.file_dir / 'mean.pkl', 'wb'))
            print(count_hits[:, 0] / count_hits.sum(1))

            sn.set(font="serif",
                   font_scale=2,
                   context='paper',
                   style='dark',
                   rc={"lines.linewidth": 2.5})

            errs = err_dist.view(len(self.test_loader.dataset),
                                 self.num_glimpses, -1).cpu().numpy()
            vars = var_dist.view(len(self.test_loader.dataset),
                                 self.num_glimpses, -1).cpu().numpy()
            f, ax = plt.subplots(2, 1, sharex=True, figsize=(9, 12))
            ax[0].plot(range(self.num_glimpses),
                       errs.mean((0, -1)),
                       marker='o',
                       markersize=10,
                       c=PLOT_COLOR)
            ax[0].set_ylabel('Prediction Error')
            ax[1].plot(range(self.num_glimpses),
                       vars.max(-1).mean(0),
                       marker='o',
                       markersize=10,
                       c=PLOT_COLOR)
            ax[1].set_xlabel('Saccade number')
            ax[1].set_ylabel('Uncertainty')
            f.tight_layout()
            f.savefig(self.plot_dir / "comb_var_err.pdf",
                      bbox_inches='tight',
                      pad_inches=0,
                      dpi=600)

            print("#######################")
            if self.classify:
                print(confusion_matrix)
                print(confusion_matrix.diag() / confusion_matrix.sum(1))
                plot_confusion_matrix(confusion_matrix,
                                      self.plot_dir / "confusion_matrix.png")
                pkl.dump(confusion_matrix,
                         open(self.file_dir / 'conf_matrix.pkl', 'wb'))

            # create heatmaps
            for i in range(self.num_digits):
                img = array2img(heat_mean[i])
                points = np.array(heat_points[i])

                first = points[:3].reshape(-1, 2)
                heatmapper = Heatmapper(
                    point_diameter=1,    # the size of each point to be drawn
                    point_strength=0.3,  # the strength, between 0 and 1, of each point
                    opacity=0.85,
                )
                heatmap = heatmapper.heatmap_on_img(first, img)
                heatmap.save(self.plot_dir / f"heatmap_bef{i}.png")

                last = points[3:-1].reshape(-1, 2)
                heatmap = heatmapper.heatmap_on_img(last, img)
                heatmap.save(self.plot_dir / f"heatmap_aft{i}.png")

                for j in range(self.num_glimpses):
                    heatmap = heatmapper.heatmap_on_img(heat_points[i][j], img)
                    heatmap.save(self.plot_dir /
                                 f"heatmap_class{i}_glimpse{j}.png")

            self.logger.info(
                f"[*] Test loss: {np.mean(losses)}, Test accuracy: {np.mean(accuracies)}"
            )
            return np.mean(losses)
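Example 1 repeatedly calls denormalize(x.size(-1), l_t) to turn glimpse locations into pixel coordinates. Its source is not shown above; a minimal sketch, assuming the usual recurrent-attention convention of locations normalized to [-1, 1]:

import torch

def denormalize(size, coords):
    # Map normalized locations in [-1, 1] to pixel indices in [0, size - 1]
    # (assumed convention; the original helper is not shown above)
    return (0.5 * (coords + 1.0) * (size - 1)).long()

loc = torch.tensor([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
print(denormalize(28, loc))  # tensor([[ 0,  0], [13, 13], [27, 27]])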
Example 2
              end='')

        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        anc_fv = model(data['anc_img'].cuda())
        pos_fv = model(data['pos_img'].cuda())
        neg_fv = model(data['neg_img'].cuda())

        pos_dis = l2_distance.forward(anc_fv, pos_fv)
        neg_dis = l2_distance.forward(anc_fv, neg_fv)

        # Keep only triplets where the anchor is already closer to the
        # positive than to the negative; note this is the opposite of the
        # hard-negative rule (neg_dis - pos_dis < margin) used elsewhere
        mask = (pos_dis < neg_dis).cpu().numpy().flatten()

        losses = torch.clamp(pos_dis[mask] - neg_dis[mask] + margin, min=0.0)

        loss = torch.mean(losses)

        print(" Len :", len(pos_dis[mask]), end='')
        print(" Loss :", loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
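The margin objective Example 2 computes inline is the standard triplet loss: clamp(d(a, p) - d(a, n) + margin, min=0), averaged over the selected triplets. A self-contained sketch of just that computation, with torch.nn.PairwiseDistance standing in for the snippet's l2_distance:

import torch
from torch.nn import PairwiseDistance

def triplet_margin_loss(anc, pos, neg, margin=0.2):
    # clamp(d(a, p) - d(a, n) + margin, min=0), averaged over the batch
    dist = PairwiseDistance(2)
    return torch.clamp(dist(anc, pos) - dist(anc, neg) + margin,
                       min=0.0).mean()

anc, pos, neg = (torch.randn(8, 128) for _ in range(3))
print(triplet_margin_loss(anc, pos, neg))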
Example 3
def main():
    dataroot = args.dataroot
    apd_dataroot = args.apd
    dataset_csv = args.dataset_csv
    apd_batch_size = args.apd_batch_size
    apd_validation_epoch_interval = args.apd_validation_epoch_interval
    model_architecture = args.model
    epochs = args.epochs
    training_triplets_path = args.training_triplets_path
    num_triplets_train = args.num_triplets_train
    resume_path = args.resume_path
    batch_size = args.batch_size
    num_workers = args.num_workers
    embedding_dimension = args.embedding_dim
    pretrained = args.pretrained
    optimizer = args.optimizer
    learning_rate = args.lr
    margin = args.margin
    start_epoch = 0

    # Define image data pre-processing transforms
    #   ToTensor() normalizes pixel values between [0, 1]
    #   Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) normalizes pixel values between [-1, 1]

    #  Here the faces are simply resized to 50x50 and randomly flipped for augmentation
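    #  Worked example with mean=0.5 and std=0.5 per channel: a pixel value of
    #  0.0 maps to (0.0 - 0.5) / 0.5 = -1.0, and 1.0 maps to (1.0 - 0.5) / 0.5 = 1.0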
    data_transforms = transforms.Compose([
        #transforms.RandomCrop(size=10),
        transforms.Resize(size=(50, 50)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Size 160x160 RGB image
    apd_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    logging.info('prepare VGGNet dataloader...')
    # Set dataloaders
    train_dataloader = torch.utils.data.DataLoader(dataset=TripletFaceDataset(
        root_dir=dataroot,
        csv_name=dataset_csv,
        num_triplets=num_triplets_train,
        training_triplets_path=training_triplets_path,
        transform=data_transforms),
                                                   batch_size=batch_size,
                                                   num_workers=num_workers,
                                                   shuffle=False)

    logging.info('prepare APD dataset')
    apd_dataroot = '../APD'  # hardcoded path; note this overrides the args.apd value read above
    apdDataset = APDDataset(
        directory=apd_dataroot,
        #pairs_path='negative_pairs.txt',
        #pairs_path='positive_pairs.txt',
        pairs_path='full.txt',
        transform=apd_transforms)
    print('apdDataset', apdDataset)

    logging.info('prepare apd loader')
    apd_dataloader = torch.utils.data.DataLoader(dataset=apdDataset,
                                                 batch_size=apd_batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=True)
    print('apd_dataloader', apd_dataloader)

    logging.info('prepare model')
    # Instantiate model
    if model_architecture == "resnet18":
        model = Resnet18Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet34":
        model = Resnet34Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet50":
        model = Resnet50Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet101":
        model = Resnet101Triplet(embedding_dimension=embedding_dimension,
                                 pretrained=pretrained)
    elif model_architecture == "inceptionresnetv2":
        model = InceptionResnetV2Triplet(
            embedding_dimension=embedding_dimension, pretrained=pretrained)
    print("Using {} model architecture.".format(model_architecture))

    # Load model to GPU or multiple GPUs if available
    flag_train_gpu = torch.cuda.is_available()
    flag_train_multi_gpu = False

    if flag_train_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model.cuda()
        flag_train_multi_gpu = True
        print('Using multi-gpu training.')
    elif flag_train_gpu and torch.cuda.device_count() == 1:
        model.cuda()
        print('Using single-gpu training.')

    logging.info('set optimizer')
    # Set optimizers
    if optimizer == "sgd":
        optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate)

    elif optimizer == "adagrad":
        optimizer_model = torch.optim.Adagrad(model.parameters(),
                                              lr=learning_rate)

    elif optimizer == "rmsprop":
        optimizer_model = torch.optim.RMSprop(model.parameters(),
                                              lr=learning_rate)

    elif optimizer == "adam":
        optimizer_model = torch.optim.Adam(model.parameters(),
                                           lr=learning_rate)

    # Optionally resume from a checkpoint
    if resume_path:

        if os.path.isfile(resume_path):
            print("\nLoading checkpoint {} ...".format(resume_path))

            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']

            # In order to load state dict for optimizers correctly, model has to be loaded to gpu first
            if flag_train_multi_gpu:
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])

            optimizer_model.load_state_dict(
                checkpoint['optimizer_model_state_dict'])

            print(
                "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n"
                .format(start_epoch, epochs - start_epoch))
        else:
            print(
                "WARNING: No checkpoint found at {}!\nTraining from scratch.".
                format(resume_path))

    # Start Training loop
    print(
        "\nTraining using triplet loss on {} triplets starting for {} epochs:\n"
        .format(num_triplets_train, epochs - start_epoch))

    total_time_start = time.time()
    end_epoch = start_epoch + epochs
    l2_distance = PairwiseDistance(2).cuda()

    BATCH_NUM = len(train_dataloader)  # number of batches per epoch
    for epoch in range(start_epoch, end_epoch):
        epoch_time_start = time.time()

        #flag_validate_apd = (epoch + 1) % lfw_validation_epoch_interval == 0 or (epoch + 1) % epochs == 0
        flag_validate_apd = True

        triplet_loss_sum = 0
        num_valid_training_triplets = 0

        # Training pass
        model.train()
        #progress_bar = enumerate(tqdm(train_dataloader))
        progress_bar = enumerate(train_dataloader)

        for batch_idx, batch_sample in progress_bar:
            logging.info("epoch:{}/{} batch_idx:{}/{}".format(
                epoch, end_epoch, batch_idx, BATCH_NUM))

            anc_img = batch_sample['anc_img'].cuda()
            pos_img = batch_sample['pos_img'].cuda()
            neg_img = batch_sample['neg_img'].cuda()

            # Forward pass - compute embeddings
            anc_embedding, pos_embedding, neg_embedding = model(
                anc_img), model(pos_img), model(neg_img)

            # Forward pass - choose hard negatives only for training
            pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
            neg_dist = l2_distance.forward(anc_embedding, neg_embedding)

            valid = (neg_dist - pos_dist < margin).cpu().numpy().flatten()

            hard_triplets = np.where(valid == 1)
            if len(hard_triplets[0]) == 0:
                continue

            anc_hard_embedding = anc_embedding[hard_triplets].cuda()
            pos_hard_embedding = pos_embedding[hard_triplets].cuda()
            neg_hard_embedding = neg_embedding[hard_triplets].cuda()

            # Calculate triplet loss
            triplet_loss = TripletLoss(margin=margin).forward(
                anchor=anc_hard_embedding,
                positive=pos_hard_embedding,
                negative=neg_hard_embedding).cuda()

            # Calculating loss
            triplet_loss_sum += triplet_loss.item()
            num_valid_training_triplets += len(anc_hard_embedding)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

        # Model only trains on hard negative triplets
        avg_triplet_loss = (0 if num_valid_training_triplets == 0 else
                            triplet_loss_sum / num_valid_training_triplets)
        epoch_time_end = time.time()

        # Print training statistics and add to log
        print(
            'Epoch {}:\tAverage Triplet Loss: {:.4f}\tEpoch Time: {:.3f} hours\tNumber of valid training triplets in epoch: {}'
            .format(epoch + 1, avg_triplet_loss,
                    (epoch_time_end - epoch_time_start) / 3600,
                    num_valid_training_triplets))

        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_triplet_loss, num_valid_training_triplets
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        try:
            # Plot Triplet losses plot
            plot_triplet_losses(
                log_dir="logs/{}_log_triplet.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/triplet_losses_{}.png".format(
                    model_architecture))
        except Exception as e:
            print(e)

        # Evaluation pass on the APD dataset
        if flag_validate_apd:

            model.eval()
            with torch.no_grad():
                distances, labels = [], []

                print("Validating on APD! ...")
                progress_bar = enumerate(tqdm(apd_dataloader))

                for batch_index, (data_a, data_b, label) in progress_bar:
                    data_a, data_b, label = (data_a.cuda(), data_b.cuda(),
                                             label.cuda())

                    output_a, output_b = model(data_a), model(data_b)

                    distance = l2_distance.forward(
                        output_a, output_b)  # Euclidean distance

                    distances.append(distance.cpu().detach().numpy())
                    labels.append(label.cpu().detach().numpy())

                labels = np.array(
                    [sublabel for label in labels for sublabel in label])
                distances = np.array([
                    subdist for distance in distances for subdist in distance
                ])
                print('len(labels)', len(labels))
                print('len(distances)', len(distances))


                true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \
                    tar, far = evaluate_lfw(
                        distances=distances,
                        labels=labels
                    )

                # Print statistics and add to log
                print(
                    "Accuracy on APD: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\tROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}"
                    .format(np.mean(accuracy), np.std(accuracy),
                            np.mean(precision), np.std(precision),
                            np.mean(recall), np.std(recall), roc_auc,
                            np.mean(best_distances), np.std(best_distances),
                            np.mean(tar), np.std(tar), np.mean(far)))

                with open(
                        'logs/lfw_{}_log_triplet.txt'.format(
                            model_architecture), 'a') as f:
                    val_list = [
                        epoch + 1,
                        np.mean(accuracy),
                        np.std(accuracy),
                        np.mean(precision),
                        np.std(precision),
                        np.mean(recall),
                        np.std(recall), roc_auc,
                        np.mean(best_distances),
                        np.std(best_distances),
                        np.mean(tar)
                    ]
                    log = '\t'.join(str(value) for value in val_list)
                    f.writelines(log + '\n')

            try:
                # Plot ROC curve
                plot_roc_lfw(
                    false_positive_rate=false_positive_rate,
                    true_positive_rate=true_positive_rate,
                    figure_name=
                    "plots/roc_plots_triplet/roc_{}_epoch_{}_triplet.png".
                    format(model_architecture, epoch + 1),
                    epochNum=epoch)
                # Plot LFW accuracies plot
                plot_accuracy_lfw(
                    log_dir="logs/lfw_{}_log_triplet.txt".format(
                        model_architecture),
                    epochs=epochs,
                    figure_name="plots/apd_accuracies_{}_triplet.png".format(
                        model_architecture))
            except Exception as e:
                print(e)

        # Save model checkpoint
        state = {
            'epoch': epoch + 1,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict()
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_apd:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state,
            'Model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch + 1))

    # Training loop end
    total_time_end = time.time()
    total_time_elapsed = total_time_end - total_time_start
    print("\nTraining finished: total time elapsed: {:.2f} hours.".format(
        total_time_elapsed / 3600))
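The checkpoint dictionary Example 3 saves each epoch is restored through resume_path at startup. A minimal, runnable sketch of that round trip using the same keys; the Linear model and SGD optimizer here are stand-ins for the Resnet*Triplet models above:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the Resnet*Triplet models
optimizer_model = torch.optim.SGD(model.parameters(), lr=0.05)

state = {'epoch': 1,
         'model_state_dict': model.state_dict(),
         'optimizer_model_state_dict': optimizer_model.state_dict()}
torch.save(state, 'checkpoint.pt')

checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer_model.load_state_dict(checkpoint['optimizer_model_state_dict'])
start_epoch = checkpoint['epoch']  # training resumes from this epoch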
Example 4
def main():
    lfw_dataroot = args.lfw
    model_path = args.model_path
    far_target = args.far_target

    checkpoint = torch.load(model_path)
    model = Resnet18Triplet(
        embedding_dimension=checkpoint['embedding_dimension'])
    model.load_state_dict(checkpoint['model_state_dict'])

    flag_gpu_available = torch.cuda.is_available()

    if flag_gpu_available:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    lfw_transforms = transforms.Compose([
        transforms.Resize(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6068, 0.4517, 0.3800],
                             std=[0.2492, 0.2173, 0.2082])
    ])

    lfw_dataloader = torch.utils.data.DataLoader(dataset=LFWDataset(
        dir=lfw_dataroot,
        pairs_path='../datasets/LFW_pairs.txt',
        transform=lfw_transforms),
                                                 batch_size=256,
                                                 num_workers=2,
                                                 shuffle=False)

    model.to(device)
    model = model.eval()

    with torch.no_grad():
        l2_distance = PairwiseDistance(p=2)
        distances, labels = [], []

        progress_bar = enumerate(tqdm(lfw_dataloader))

        for batch_index, (data_a, data_b, label) in progress_bar:
            data_a = data_a.to(device)  # .cuda() would fail on CPU-only machines
            data_b = data_b.to(device)

            output_a, output_b = model(data_a), model(data_b)
            distance = l2_distance.forward(output_a,
                                           output_b)  # Euclidean distance

            distances.append(distance.cpu().detach().numpy())
            labels.append(label.cpu().detach().numpy())

        labels = np.array([sublabel for label in labels for sublabel in label])
        distances = np.array(
            [subdist for distance in distances for subdist in distance])

        _, _, _, _, _, _, _, tar, far = evaluate_lfw(distances=distances,
                                                     labels=labels,
                                                     far_target=far_target)

        print("TAR: {:.4f}+-{:.4f} @ FAR: {:.4f}".format(
            np.mean(tar), np.std(tar), np.mean(far)))
Example 5
def main():
    lfw_dataroot = args.lfw
    model_path = args.model_path
    far_target = args.far_target
    batch_size = args.batch_size

    flag_gpu_available = torch.cuda.is_available()

    if flag_gpu_available:
        device = torch.device("cuda")
        print('Using GPU')
    else:
        device = torch.device("cpu")
        print('Using CPU')

    checkpoint = torch.load(model_path, map_location=device)
    model = Resnet18Triplet(
        embedding_dimension=checkpoint['embedding_dimension'])
    model.load_state_dict(checkpoint['model_state_dict'])

    # desiredFaceWidth, desiredFaceHeight and desiredLeftEye work together to produce centered, aligned faces.
    # desiredLeftEye (x, y) controls where the face sits in the height x width window: a bigger y moves the face down,
    # a bigger x zooms the face out, and extreme values might flip the face upside down; keep it in the range [0.2, 0.36].
    desiredFaceHeight = 448  # non-equal width/height make the resize stretch the faces; values above 224 preserve more image quality
    desiredFaceWidth = 352
    desiredLeftEye = (0.28, 0.35)

    tf_align = transform_align(
        landmark_predictor_weight_path=landmark_predictor_weights,
        face_detector_path=face_detector_weights_path,
        desiredFaceWidth=desiredFaceWidth,
        desiredFaceHeight=desiredFaceHeight,
        desiredLeftEye=desiredLeftEye)

    lfw_transforms = transforms.Compose([
        tf_align,
        transforms.Resize(size=(140, 140)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6071, 0.4609, 0.3944],
                             std=[0.2457, 0.2175, 0.2129])
    ])

    lfw_dataloader = torch.utils.data.DataLoader(
        dataset=LFWDataset(dir=lfw_dataroot,
                           pairs_path='../datasets/LFW_pairs.txt',
                           transform=lfw_transforms),
        batch_size=batch_size,  # default = 256; 160 allows running under 2GB VRAM
        num_workers=2,
        shuffle=False)

    model.to(device)
    model = model.eval()

    with torch.no_grad():
        l2_distance = PairwiseDistance(p=2)
        distances, labels = [], []

        progress_bar = enumerate(tqdm(lfw_dataloader))

        for batch_index, (data_a, data_b, label) in progress_bar:
            data_a = data_a.to(device)  # data_a = data_a.cuda()
            data_b = data_b.to(device)  # data_b = data_b.cuda()

            output_a, output_b = model(data_a), model(data_b)
            distance = l2_distance.forward(output_a,
                                           output_b)  # Euclidean distance

            distances.append(distance.cpu().detach().numpy())
            labels.append(label.cpu().detach().numpy())

        labels = np.array([sublabel for label in labels for sublabel in label])
        distances = np.array(
            [subdist for distance in distances for subdist in distance])

        _, _, _, _, _, _, _, tar, far = evaluate_lfw(distances=distances,
                                                     labels=labels,
                                                     far_target=far_target)

        print("TAR: {:.4f}+-{:.4f} @ FAR: {:.4f}".format(
            np.mean(tar), np.std(tar), np.mean(far)))
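Both variants above report "TAR @ FAR" from evaluate_lfw, whose source is not shown. A sketch of the underlying definitions at one fixed distance threshold; the real evaluate_lfw presumably adds threshold search and cross-validation folds, hence the mean/std it returns:

import numpy as np

def tar_far(distances, labels, threshold):
    # TAR: fraction of same-identity pairs accepted (distance <= threshold)
    # FAR: fraction of different-identity pairs wrongly accepted
    accept = distances <= threshold
    same = labels.astype(bool)
    return accept[same].mean(), accept[~same].mean()

distances = np.array([0.3, 0.9, 0.4, 1.2])
labels = np.array([1, 0, 1, 0])  # 1 = same identity
print(tar_far(distances, labels, threshold=0.5))  # (1.0, 0.0)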
Example 6
    print('No pretrained model found!')

l2_distance = PairwiseDistance(2)
with torch.no_grad():  # no gradients needed for evaluation
    distances, labels = [], []
    progress_bar = enumerate(tqdm(test_dataloader))
    for batch_index, (data_a, data_b, label) in progress_bar:
        # data_a, data_b and label are each one batch of tensors
        data_a = data_a.to(device)
        data_b = data_b.to(device)
        label = label.to(device)
        output_a, output_b = model(data_a), model(data_b)
        # L2-normalize each embedding row-wise; a bare torch.norm would
        # divide every row by the norm of the whole batch tensor
        output_a = torch.div(output_a, torch.norm(output_a, dim=1, keepdim=True))
        output_b = torch.div(output_b, torch.norm(output_b, dim=1, keepdim=True))
        distance = l2_distance.forward(output_a, output_b)
        # append one array per batch; flattened below
        labels.append(label.cpu().detach().numpy())
        distances.append(distance.cpu().detach().numpy())
    print("computed distances for all image pairs")

    labels = np.array([sublabel for label in labels for sublabel in label])
    distances = np.array(
        [subdist for distance in distances for subdist in distance])
    true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \
    tar, far = evaluate_lfw(
        distances=distances,
        labels=labels,
        epoch='',
    )
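The nested list comprehensions used throughout these examples flatten a list of per-batch arrays into one flat array. np.concatenate does the same thing more directly, as this small equivalence check shows:

import numpy as np

batches = [np.array([0.31, 0.92]), np.array([0.44])]  # per-batch distances
flat_a = np.array([d for batch in batches for d in batch])
flat_b = np.concatenate(batches)
assert np.array_equal(flat_a, flat_b)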
Example 7
def main():
    dataroot = args.dataroot
    lfw_dataroot = args.lfw
    dataset_csv = args.dataset_csv
    epochs = args.epochs
    iterations_per_epoch = args.iterations_per_epoch
    model_architecture = args.model_architecture
    pretrained = args.pretrained
    embedding_dimension = args.embedding_dimension
    num_human_identities_per_batch = args.num_human_identities_per_batch
    batch_size = args.batch_size
    lfw_batch_size = args.lfw_batch_size
    num_generate_triplets_processes = args.num_generate_triplets_processes
    resume_path = args.resume_path
    num_workers = args.num_workers
    optimizer = args.optimizer
    learning_rate = args.learning_rate
    margin = args.margin
    image_size = args.image_size
    use_semihard_negatives = args.use_semihard_negatives
    training_triplets_path = args.training_triplets_path
    flag_training_triplets_path = False
    start_epoch = 0

    if training_triplets_path is not None:
        flag_training_triplets_path = True  # Load triplets file for the first training epoch

    # Define image data pre-processing transforms
    #   ToTensor() normalizes pixel values between [0, 1]
    #   Normalize(mean=[0.6068, 0.4517, 0.3800], std=[0.2492, 0.2173, 0.2082]) normalizes pixel values to be mean
    #    of zero and standard deviation of 1 according to the calculated VGGFace2 with tightly-cropped faces
    #    dataset RGB channels' mean and std values by calculate_vggface2_rgb_mean_std.py in 'datasets' folder.
    data_transforms = transforms.Compose([
        transforms.Resize(size=image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6068, 0.4517, 0.3800],
                             std=[0.2492, 0.2173, 0.2082])
    ])

    lfw_transforms = transforms.Compose([
        transforms.Resize(size=image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6068, 0.4517, 0.3800],
                             std=[0.2492, 0.2173, 0.2082])
    ])

    lfw_dataloader = torch.utils.data.DataLoader(dataset=LFWDataset(
        dir=lfw_dataroot,
        pairs_path='datasets/LFW_pairs.txt',
        transform=lfw_transforms),
                                                 batch_size=lfw_batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=False)

    # Instantiate model
    model = set_model_architecture(model_architecture=model_architecture,
                                   pretrained=pretrained,
                                   embedding_dimension=embedding_dimension)

    # Load model to GPU or multiple GPUs if available
    model, flag_train_multi_gpu = set_model_gpu_mode(model)

    # Set optimizer
    optimizer_model = set_optimizer(optimizer=optimizer,
                                    model=model,
                                    learning_rate=learning_rate)

    # Resume from a model checkpoint
    if resume_path:
        if os.path.isfile(resume_path):
            print("Loading checkpoint {} ...".format(resume_path))
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch'] + 1
            optimizer_model.load_state_dict(
                checkpoint['optimizer_model_state_dict'])

            # In order to load state dict for optimizers correctly, model has to be loaded to gpu first
            if flag_train_multi_gpu:
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])
            print("Checkpoint loaded: start epoch from checkpoint = {}".format(
                start_epoch))
        else:
            print(
                "WARNING: No checkpoint found at {}!\nTraining from scratch.".
                format(resume_path))

    if use_semihard_negatives:
        print("Using Semi-Hard negative triplet selection!")
    else:
        print("Using Hard negative triplet selection!")

    print("Training using triplet loss starting for {} epochs:\n".format(
        epochs - start_epoch))

    for epoch in range(start_epoch, epochs):
        num_valid_training_triplets = 0
        l2_distance = PairwiseDistance(p=2)
        _training_triplets_path = None

        if flag_training_triplets_path:
            _training_triplets_path = training_triplets_path
            flag_training_triplets_path = False  # Only load triplets file for the first epoch

        # Re-instantiate training dataloader to generate a triplet list for this training epoch
        train_dataloader = torch.utils.data.DataLoader(
            dataset=TripletFaceDataset(
                root_dir=dataroot,
                csv_name=dataset_csv,
                num_triplets=iterations_per_epoch * batch_size,
                num_generate_triplets_processes=num_generate_triplets_processes,
                num_human_identities_per_batch=num_human_identities_per_batch,
                triplet_batch_size=batch_size,
                epoch=epoch,
                training_triplets_path=_training_triplets_path,
                transform=data_transforms),
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False  # Shuffling is not required: triplets are generated with a set number of identities per batch
        )

        # Training pass
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_idx, (batch_sample) in progress_bar:

            # Forward pass - compute embeddings
            anc_imgs = batch_sample['anc_img']
            pos_imgs = batch_sample['pos_img']
            neg_imgs = batch_sample['neg_img']

            # Concatenate the input images into one tensor because doing multiple forward passes would create
            #  weird GPU memory allocation behaviours later on during training which would cause GPU Out of Memory
            #  issues
            all_imgs = torch.cat(
                (anc_imgs, pos_imgs,
                 neg_imgs))  # Must be a tuple of Torch Tensors

            anc_embeddings, pos_embeddings, neg_embeddings, model = forward_pass(
                imgs=all_imgs, model=model, batch_size=batch_size)

            pos_dists = l2_distance.forward(anc_embeddings, pos_embeddings)
            neg_dists = l2_distance.forward(anc_embeddings, neg_embeddings)

            if use_semihard_negatives:
                # Semi-Hard Negative triplet selection
                #  (negative_distance - positive_distance < margin) AND (positive_distance < negative_distance)
                #   Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L295
                first_condition = (neg_dists - pos_dists <
                                   margin).cpu().numpy().flatten()
                second_condition = (pos_dists <
                                    neg_dists).cpu().numpy().flatten()
                valid = np.logical_and(first_condition, second_condition)
                valid_triplets = np.where(valid == 1)
            else:
                # Hard Negative triplet selection
                #  (negative_distance - positive_distance < margin)
                #   Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L296
                valid = (neg_dists - pos_dists < margin).cpu().numpy().flatten()
                valid_triplets = np.where(valid == 1)

            anc_valid_embeddings = anc_embeddings[valid_triplets]
            pos_valid_embeddings = pos_embeddings[valid_triplets]
            neg_valid_embeddings = neg_embeddings[valid_triplets]

            del anc_embeddings, pos_embeddings, neg_embeddings, pos_dists, neg_dists
            gc.collect()

            # Calculate triplet loss
            triplet_loss = TripletLoss(margin=margin).forward(
                anchor=anc_valid_embeddings,
                positive=pos_valid_embeddings,
                negative=neg_valid_embeddings)

            # Calculating number of triplets that met the triplet selection method during the epoch
            num_valid_training_triplets += len(anc_valid_embeddings)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

            # Clear some memory at end of training iteration
            del triplet_loss, anc_valid_embeddings, pos_valid_embeddings, neg_valid_embeddings
            gc.collect()

        # Print training statistics for epoch and add to log
        print(
            'Epoch {}:\tNumber of valid training triplets in epoch: {}'.format(
                epoch, num_valid_training_triplets))

        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [epoch, num_valid_training_triplets]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        # Evaluation pass on LFW dataset
        best_distances = validate_lfw(model=model,
                                      lfw_dataloader=lfw_dataloader,
                                      model_architecture=model_architecture,
                                      epoch=epoch,
                                      epochs=epochs)

        # Save model checkpoint
        state = {
            'epoch': epoch,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict(),
            'best_distance_threshold': np.mean(best_distances)
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # Save model checkpoint
        torch.save(
            state,
            'model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch))
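The two selection rules in Example 7 differ only by the extra pos_dists < neg_dists condition. A toy illustration with margin = 0.2:

import torch

pos_dists = torch.tensor([0.5, 0.5, 0.5])
neg_dists = torch.tensor([0.4, 0.6, 0.9])
margin = 0.2

hard = neg_dists - pos_dists < margin        # tensor([ True,  True, False])
semihard = hard & (pos_dists < neg_dists)    # tensor([False,  True, False])
print(hard, semihard)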
Example 8
for epoch in range(0, 5):
    triplet_loss_sum = 0
    num_valid_training_triplets = 0
    progress_bar = enumerate(tqdm(train_dataloader))
    for batch_idx, (batch_sample) in progress_bar:
        anc_image = batch_sample['anc_img'].view(-1, 3, 220, 220).cuda()
        pos_image = batch_sample['pos_img'].view(-1, 3, 220, 220).cuda()
        neg_image = batch_sample['neg_img'].view(-1, 3, 220, 220).cuda()

        anc_embedding = net(anc_image)
        pos_embedding = net(pos_image)
        neg_embedding = net(neg_image)

        
        pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
        neg_dist = l2_distance.forward(anc_embedding, neg_embedding)
        print(pos_dist)
        print(neg_dist)
        print('\n')

        allcd = (neg_dist - pos_dist < margin).cpu().numpy().flatten()
        hard_triplets = np.where(allcd == 1)
        anc_hard_embedding = anc_embedding[hard_triplets].cuda()
        pos_hard_embedding = pos_embedding[hard_triplets].cuda()
        neg_hard_embedding = neg_embedding[hard_triplets].cuda()

        triplet_loss = TripletLoss(margin=margin).forward(
                    anchor=anc_hard_embedding,
                    positive=pos_hard_embedding,
                    negative=neg_hard_embedding)
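Example 8 breaks off at the loss call. Judging from Examples 3 and 9, an iteration normally finishes with loss bookkeeping and a backward pass; a self-contained sketch of one full triplet step, using torch's built-in TripletMarginLoss as a stand-in for the snippet's TripletLoss class:

import torch
import torch.nn as nn

net = nn.Linear(8, 4)  # stand-in embedding network
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.TripletMarginLoss(margin=0.5)

anc, pos, neg = (net(torch.randn(16, 8)) for _ in range(3))
triplet_loss = criterion(anc, pos, neg)

optimizer.zero_grad()
triplet_loss.backward()
optimizer.step()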
Example 9
def validate_lfw(model, lfw_dataloader, model_architecture, epoch, epochs):
    model.eval()
    with torch.no_grad():
        l2_distance = PairwiseDistance(p=2)
        distances, labels = [], []

        print("Validating on LFW! ...")
        progress_bar = enumerate(tqdm(lfw_dataloader))

        for batch_index, (data_a, data_b, label) in progress_bar:
            data_a = data_a.cuda()
            data_b = data_b.cuda()

            output_a, output_b = model(data_a), model(data_b)
            distance = l2_distance.forward(output_a,
                                           output_b)  # Euclidean distance

            distances.append(distance.cpu().detach().numpy())
            labels.append(label.cpu().detach().numpy())

        labels = np.array([sublabel for label in labels for sublabel in label])
        distances = np.array(
            [subdist for distance in distances for subdist in distance])

        true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \
        tar, far = evaluate_lfw(
            distances=distances,
            labels=labels,
            far_target=1e-3
        )
        # Print statistics and add to log
        print(
            "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\t"
            "ROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\t"
            "TAR: {:.4f}+-{:.4f} @ FAR: {:.4f}".format(np.mean(accuracy),
                                                       np.std(accuracy),
                                                       np.mean(precision),
                                                       np.std(precision),
                                                       np.mean(recall),
                                                       np.std(recall), roc_auc,
                                                       np.mean(best_distances),
                                                       np.std(best_distances),
                                                       np.mean(tar),
                                                       np.std(tar),
                                                       np.mean(far)))
        with open('logs/lfw_{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch,
                np.mean(accuracy),
                np.std(accuracy),
                np.mean(precision),
                np.std(precision),
                np.mean(recall),
                np.std(recall), roc_auc,
                np.mean(best_distances),
                np.std(best_distances),
                np.mean(tar)
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

    try:
        # Plot ROC curve
        plot_roc_lfw(
            false_positive_rate=false_positive_rate,
            true_positive_rate=true_positive_rate,
            figure_name="plots/roc_plots/roc_{}_epoch_{}_triplet.png".format(
                model_architecture, epoch))
        # Plot LFW accuracies plot
        plot_accuracy_lfw(
            log_dir="logs/lfw_{}_log_triplet.txt".format(model_architecture),
            epochs=epochs,
            figure_name="plots/lfw_accuracies_{}_triplet.png".format(
                model_architecture))
    except Exception as e:
        print(e)

    return best_distances
Example 10
def main():
    dataroot = args.dataroot
    lfw_dataroot = args.lfw
    dataset_csv = args.dataset_csv
    lfw_batch_size = args.lfw_batch_size
    lfw_validation_epoch_interval = args.lfw_validation_epoch_interval
    model_architecture = args.model
    epochs = args.epochs
    num_triplets_train = args.num_triplets_train
    resume_path = args.resume_path
    batch_size = args.batch_size
    num_workers = args.num_workers
    embedding_dimension = args.embedding_dim
    pretrained = args.pretrained
    learning_rate = args.lr
    margin = args.margin
    start_epoch = 0

    # Define image data pre-processing transforms
    #   ToTensor() normalizes pixel values between [0, 1]
    #   Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) normalizes pixel values between [-1, 1]

    #  Size 182x182 RGB image -> Random crop to 160x160 RGB image for more model generalization
    data_transforms = transforms.Compose([
        transforms.RandomCrop(size=160),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Size 160x160 RGB image
    lfw_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Set dataloaders
    train_dataloader = torch.utils.data.DataLoader(TripletFaceDataset(
        root_dir=dataroot,
        csv_name=dataset_csv,
        num_triplets=num_triplets_train,
        transform=data_transforms),
                                                   batch_size=batch_size,
                                                   num_workers=num_workers,
                                                   shuffle=False)

    lfw_dataloader = torch.utils.data.DataLoader(LFWDataset(
        dir=lfw_dataroot,
        pairs_path='datasets/LFW_pairs.txt',
        transform=lfw_transforms),
                                                 batch_size=lfw_batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=False)
    # Instantiate model
    if model_architecture == "resnet34":
        model = Resnet34Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet50":
        model = Resnet50Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet101":
        model = Resnet101Triplet(embedding_dimension=embedding_dimension,
                                 pretrained=pretrained)
    print("Using {} model architecture.".format(model_architecture))

    # Load model to GPU or multiple GPUs if available
    flag_train_gpu = torch.cuda.is_available()
    flag_train_multi_gpu = False

    if flag_train_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model.cuda()
        flag_train_multi_gpu = True
        print('Using multi-gpu training.')
    elif flag_train_gpu and torch.cuda.device_count() == 1:
        model.cuda()
        print('Using single-gpu training.')

    # Set optimizers
    optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # Optionally resume from a checkpoint
    if resume_path:
        if os.path.isfile(resume_path):
            print("\nLoading checkpoint {} ...".format(resume_path))

            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']
            # In order to load state dict for optimizers correctly, model has to be loaded to gpu first
            if flag_train_multi_gpu:
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])

            optimizer_model.load_state_dict(
                checkpoint['optimizer_model_state_dict'])

            print(
                "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n"
                .format(start_epoch, epochs - start_epoch))
        else:
            print(
                "WARNING: No checkpoint found at {}!\nTraining from scratch.".
                format(resume_path))

    # Training loop
    print(
        "\nTraining using triplet loss on {} triplets starting for {} epochs:\n"
        .format(num_triplets_train, epochs - start_epoch))
    total_time_start = time.time()

    end_epoch = start_epoch + epochs

    l2_distance = PairwiseDistance(2).cuda()
    for epoch in range(start_epoch, end_epoch):
        flag_validate_lfw = ((epoch + 1) % lfw_validation_epoch_interval == 0
                             or (epoch + 1) % epochs == 0)
        triplet_loss_sum = 0

        epoch_time_start = time.time()

        # Training the model
        model.train()
        progress_bar = tqdm(enumerate(train_dataloader))

        for batch_idx, (batch_sample) in progress_bar:
            anc_img = batch_sample['anc_img'].cuda()
            pos_img = batch_sample['pos_img'].cuda()
            neg_img = batch_sample['neg_img'].cuda()

            # Forward pass - compute embeddings
            anc_embedding, pos_embedding, neg_embedding = model(
                anc_img), model(pos_img), model(neg_img)

            # Forward pass - choose hard negatives only for training
            pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
            neg_dist = l2_distance.forward(anc_embedding, neg_embedding)

            all = (neg_dist - pos_dist < margin).cpu().numpy().flatten()

            hard_triplets = np.where(all == 1)
            if len(hard_triplets[0]) == 0:
                continue

            anc_hard_embedding = anc_embedding[hard_triplets].cuda()
            pos_hard_embedding = pos_embedding[hard_triplets].cuda()
            neg_hard_embedding = neg_embedding[hard_triplets].cuda()

            # Calculate triplet loss
            triplet_loss = TripletLoss(margin=margin).forward(
                anchor=anc_hard_embedding,
                positive=pos_hard_embedding,
                negative=neg_hard_embedding).cuda()

            triplet_loss_sum += triplet_loss.item()

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

        avg_triplet_loss = triplet_loss_sum / len(train_dataloader.dataset)
        epoch_time_end = time.time()

        # Print training and validation statistics and add to log
        print('Epoch {}:\tTriplet Loss: {:.4f}\tEpoch Time: {:.2f} minutes'.
              format(epoch + 1, avg_triplet_loss,
                     (epoch_time_end - epoch_time_start) / 60))
        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [epoch + 1, avg_triplet_loss]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        # Validating on LFW dataset using KFold based on Euclidean distance metric
        if flag_validate_lfw:
            model.eval()
            with torch.no_grad():
                distances, labels = [], []

                print("Validating on LFW! ...")
                progress_bar = tqdm(enumerate(lfw_dataloader))
                for batch_index, (data_a, data_b, label) in progress_bar:
                    data_a, data_b, label = (data_a.cuda(), data_b.cuda(),
                                             label.cuda())

                    output_a, output_b = model(data_a), model(data_b)
                    distance = l2_distance.forward(
                        output_a, output_b)  # Euclidean distance
                    distances.append(distance.cpu().detach().numpy())
                    labels.append(label.cpu().detach().numpy())

                labels = np.array(
                    [sublabel for label in labels for sublabel in label])
                distances = np.array([
                    subdist for distance in distances for subdist in distance
                ])

                true_positive_rate, false_positive_rate, precision, recall, accuracy, auc, best_distance_threshold, \
                    tar, far = evaluate_lfw(
                        distances=distances, labels=labels
                    )

                # Print statistics and add to log
                print(
                    "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}\tRecall {:.4f}\tArea Under Curve: {:.4f}\t"
                    "Best distance threshold: {:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}"
                    .format(np.mean(accuracy), np.std(accuracy),
                            np.mean(precision), np.mean(recall), auc,
                            best_distance_threshold, np.mean(tar), np.std(tar),
                            np.mean(far)))
                with open(
                        'logs/lfw_{}_log_triplet.txt'.format(
                            model_architecture), 'a') as f:
                    val_list = [
                        epoch + 1,
                        np.mean(accuracy),
                        np.std(accuracy),
                        np.mean(precision),
                        np.mean(recall), auc, best_distance_threshold,
                        np.mean(tar)
                    ]
                    log = '\t'.join(str(value) for value in val_list)
                    f.writelines(log + '\n')

            # Plot ROC curve
            plot_roc_lfw(
                false_positive_rate,
                true_positive_rate,
                figure_name="plots/roc_plots/roc_{}_epoch_{}_triplet.png".
                format(model_architecture, epoch + 1))

        # Save model checkpoint
        state = {
            'epoch': epoch + 1,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict()
        }

        # For storing the data parallel model's state dictionary without the
        #  'module.' prefix added by nn.DataParallel
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_lfw:
            state['best_distance_threshold'] = best_distance_threshold

        torch.save(state, 'model_{}_triplet.pt'.format(model_architecture))

    # Training loop end
    total_time_end = time.time()
    total_time_elapsed = total_time_end - total_time_start

    print("\nTraining finished: total time elapsed: {:.2f} minutes.".format(
        total_time_elapsed / 60))

    # Plot LFW accuracies plot
    print("\nPlotting LFW accuracies plot!")
    plot_accuracy_lfw(
        log_dir="logs/lfw_{}_log_triplet.txt".format(model_architecture),
        epochs=epochs,
        figure_name="plots/lfw_accuracies_{}_triplet.png".format(
            model_architecture))
    print("\nDone.")
Example n. 11
0
        pos_img = batch_sample['pos_img'].cuda()
        neg_img = batch_sample['neg_img'].cuda()
        # Fetch the three mask images (one batch of images each)
        mask_anc = batch_sample['mask_anc'].cuda()
        mask_pos = batch_sample['mask_pos'].cuda()
        mask_neg = batch_sample['mask_neg'].cuda()

        # Run the model
        # Forward pass: run the three images through the model to produce an
        #  embedding and a loss each (during training the input is an
        #  image/mask pair and the output includes a loss; during validation
        #  the input is a single image and the output is only the embedding)
        anc_embedding, anc_attention_loss = model((anc_img, mask_anc))
        pos_embedding, pos_attention_loss = model((pos_img, mask_pos))
        neg_embedding, neg_attention_loss = model((neg_img, mask_neg))

        # Hard-sample mining
        # Compute the L2 distance between embeddings
        pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
        neg_dist = l2_distance.forward(anc_embedding, neg_embedding)
        # Find the samples that satisfy the hard-sample criterion (the mask is
        #  renamed from `all`, which shadowed the Python built-in)
        hard_mask = (neg_dist - pos_dist <
                     config['margin']).cpu().numpy().flatten()
        hard_triplets = np.where(hard_mask == 1)
        if len(hard_triplets[0]) == 0:
            continue

        # Selected hard samples: their embeddings
        anc_hard_embedding = anc_embedding[hard_triplets].cuda()
        pos_hard_embedding = pos_embedding[hard_triplets].cuda()
        neg_hard_embedding = neg_embedding[hard_triplets].cuda()
        # Selected hard samples: their corresponding attention losses
        hard_anc_attention_loss = anc_attention_loss[hard_triplets]
        hard_pos_attention_loss = pos_attention_loss[hard_triplets]
        hard_neg_attention_loss = neg_attention_loss[hard_triplets]
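
The fragment stops after the hard samples are selected. Below is a plausible continuation under stated assumptions: the attention_loss_weight config key and the simple averaging of the three attention losses are hypothetical, not taken from the source.

        # Hypothetical continuation: combine the hard-triplet loss with the
        #  mean attention loss of the same hard samples
        triplet_loss = TripletLoss(margin=config['margin'])(
            anchor=anc_hard_embedding,
            positive=pos_hard_embedding,
            negative=neg_hard_embedding)
        attention_loss = (hard_anc_attention_loss.mean() +
                          hard_pos_attention_loss.mean() +
                          hard_neg_attention_loss.mean()) / 3
        loss = (triplet_loss +
                config.get('attention_loss_weight', 1.0) * attention_loss)

        # Standard backward pass
        optimizer_model.zero_grad()
        loss.backward()
        optimizer_model.step()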
Example n. 12
0
def train_triplet(start_epoch, end_epoch, epochs, train_dataloader,
                  lfw_dataloader, lfw_validation_epoch_interval, model,
                  model_architecture, optimizer_model, embedding_dimension,
                  batch_size, margin, flag_train_multi_gpu, optimizer,
                  learning_rate, use_semihard_negatives):

    for epoch in range(start_epoch, end_epoch):
        flag_validate_lfw = (epoch +
                             1) % lfw_validation_epoch_interval == 0 or (
                                 epoch + 1) % epochs == 0
        triplet_loss_sum = 0
        num_valid_training_triplets = 0
        l2_distance = PairwiseDistance(p=2)

        # Training pass
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_idx, batch_sample in progress_bar:
            # Skip the last (possibly smaller) batch: pairwise distance
            #  calculation requires tensors of equal size
            if batch_idx + 1 == len(train_dataloader):
                continue

            # Forward pass - compute embeddings
            anc_imgs = batch_sample['anc_img']
            pos_imgs = batch_sample['pos_img']
            neg_imgs = batch_sample['neg_img']

            # Concatenate the input images into one tensor: running three
            #  separate forward passes causes fragmented GPU memory allocation
            #  later in training, which can lead to GPU Out of Memory errors
            all_imgs = torch.cat(
                (anc_imgs, pos_imgs,
                 neg_imgs))  # torch.cat expects a tuple of tensors

            anc_embeddings, pos_embeddings, neg_embeddings, model, optimizer_model, flag_use_cpu = forward_pass(
                imgs=all_imgs,
                model=model,
                optimizer_model=optimizer_model,
                batch_idx=batch_idx,
                optimizer=optimizer,
                learning_rate=learning_rate,
                batch_size=batch_size,
                use_cpu=False)

            pos_dists = l2_distance.forward(anc_embeddings, pos_embeddings)
            neg_dists = l2_distance.forward(anc_embeddings, neg_embeddings)

            if use_semihard_negatives:
                # Semi-Hard Negative triplet selection
                #  (negative_distance - positive_distance < margin) AND (positive_distance < negative_distance)
                #   Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L295
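                #  Worked example with illustrative numbers: for margin = 0.2,
                #  pos_dist = 0.8 and neg_dist = 0.9 satisfy both conditions
                #  (0.9 - 0.8 = 0.1 < 0.2 and 0.8 < 0.9), so the triplet is
                #  semi-hard; neg_dist = 1.2 would fail the first condition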

                first_condition = (neg_dists - pos_dists <
                                   margin).cpu().numpy().flatten()
                second_condition = (pos_dists <
                                    neg_dists).cpu().numpy().flatten()
                # Mask renamed from `all`, which shadowed the Python built-in
                semihard_mask = np.logical_and(first_condition,
                                               second_condition)

                semihard_negative_triplets = np.where(semihard_mask == 1)
                if len(semihard_negative_triplets[0]) == 0:
                    continue

                anc_valid_embeddings = anc_embeddings[
                    semihard_negative_triplets]
                pos_valid_embeddings = pos_embeddings[
                    semihard_negative_triplets]
                neg_valid_embeddings = neg_embeddings[
                    semihard_negative_triplets]

                del anc_embeddings, pos_embeddings, neg_embeddings, pos_dists, neg_dists
                gc.collect()

            else:
                # Hard Negative triplet selection
                #  (negative_distance - positive_distance < margin)
                #   Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L296

                # Mask renamed from `all`, which shadowed the Python built-in
                hard_mask = (neg_dists - pos_dists <
                             margin).cpu().numpy().flatten()

                hard_negative_triplets = np.where(hard_mask == 1)
                if len(hard_negative_triplets[0]) == 0:
                    continue

                anc_valid_embeddings = anc_embeddings[hard_negative_triplets]
                pos_valid_embeddings = pos_embeddings[hard_negative_triplets]
                neg_valid_embeddings = neg_embeddings[hard_negative_triplets]

                del anc_embeddings, pos_embeddings, neg_embeddings, pos_dists, neg_dists
                gc.collect()

            # Calculate triplet loss (calling the module instance rather than
            #  .forward() directly, which is the idiomatic nn.Module usage)
            triplet_loss = TripletLoss(margin=margin)(
                anchor=anc_valid_embeddings,
                positive=pos_valid_embeddings,
                negative=neg_valid_embeddings)

            # Calculating loss and number of triplets that met the triplet selection method during the epoch
            triplet_loss_sum += triplet_loss.item()
            num_valid_training_triplets += len(anc_valid_embeddings)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

            # Load model and optimizer back to GPU if CUDA Out of Memory Exception occurred and model and optimizer
            #  were switched to CPU
            if flag_use_cpu:
                # According to https://github.com/pytorch/pytorch/issues/2830#issuecomment-336183179
                #  In order for the optimizer to keep training the model after changing to a different type or device,
                #  optimizers have to be recreated, 'load_state_dict' can be used to restore the state from a
                #  previous copy. As such, the optimizer state dict will be saved first and then reloaded when
                #  the model's device is changed.
                torch.cuda.empty_cache()

                # Print number of valid triplets (troubleshooting out of memory causes)
                print("Number of valid triplets during OOM iteration = {}".
                      format(len(anc_valid_embeddings)))

                torch.save(
                    optimizer_model.state_dict(),
                    'model_training_checkpoints/out_of_memory_optimizer_checkpoint/optimizer_checkpoint.pt'
                )

                # Load back to CUDA
                model.cuda()

                optimizer_model = set_optimizer(optimizer=optimizer,
                                                model=model,
                                                learning_rate=learning_rate)

                optimizer_model.load_state_dict(
                    torch.load(
                        'model_training_checkpoints/out_of_memory_optimizer_checkpoint/optimizer_checkpoint.pt'
                    ))

                # Copied from https://github.com/pytorch/pytorch/issues/2830#issuecomment-336194949
                # No optimizer.cuda() available, this is the way to make an optimizer loaded with cpu tensors load
                #  with cuda tensors.
                for state in optimizer_model.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

            # Clear some memory at end of training iteration
            del triplet_loss, anc_valid_embeddings, pos_valid_embeddings, neg_valid_embeddings
            gc.collect()

        # Model only trains on triplets that fit the triplet selection method
        avg_triplet_loss = (triplet_loss_sum / num_valid_training_triplets
                            if num_valid_training_triplets > 0 else 0)

        # Print training statistics and add to log
        print(
            'Epoch {}:\tAverage Triplet Loss: {:.4f}\tNumber of valid training triplets in epoch: {}'
            .format(epoch + 1, avg_triplet_loss, num_valid_training_triplets))

        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_triplet_loss, num_valid_training_triplets
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        try:
            # Plot Triplet losses plot
            plot_triplet_losses(
                log_dir="logs/{}_log_triplet.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/triplet_losses_{}.png".format(
                    model_architecture))
        except Exception as e:
            print(e)

        # Evaluation pass on LFW dataset
        if flag_validate_lfw:
            best_distances = validate_lfw(
                model=model,
                lfw_dataloader=lfw_dataloader,
                model_architecture=model_architecture,
                epoch=epoch,
                epochs=epochs)

        # Save model checkpoint
        state = {
            'epoch': epoch + 1,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict()
        }

        # For storing the data parallel model's state dictionary without the
        #  'module.' prefix added by nn.DataParallel
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_lfw:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state,
            'model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch + 1))
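
The forward_pass helper called in this example is not shown. Below is a hypothetical sketch consistent with its call site and return values; the OOM fallback is only stubbed, set_optimizer is the helper already used elsewhere in the example, and the real implementation may differ:

import torch

def forward_pass(imgs, model, optimizer_model, batch_idx, optimizer,
                 learning_rate, batch_size, use_cpu=False):
    # batch_idx and batch_size are accepted for interface compatibility only.
    # One forward pass over the concatenated anchor/positive/negative batch.
    flag_use_cpu = use_cpu
    try:
        embeddings = model(imgs.cuda())
    except RuntimeError:
        # Assumed CUDA Out of Memory fallback: move the model to the CPU,
        #  recreate its optimizer there, and retry, signalling the caller
        #  through flag_use_cpu
        flag_use_cpu = True
        model = model.cpu()
        optimizer_model = set_optimizer(optimizer=optimizer, model=model,
                                        learning_rate=learning_rate)
        embeddings = model(imgs.cpu())
    # Split the stacked embeddings back into three equal-sized chunks
    anc_embeddings, pos_embeddings, neg_embeddings = torch.chunk(embeddings, 3)
    return (anc_embeddings, pos_embeddings, neg_embeddings, model,
            optimizer_model, flag_use_cpu)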
Example n. 13
0
def main():
    dataroot = args.dataroot
    apd_dataroot = args.apd
    apd_batch_size = args.apd_batch_size
    apd_validation_epoch_interval = args.apd_validation_epoch_interval
    model_architecture = args.model
    epochs = args.epochs
    resume_path = args.resume_path
    batch_size = args.batch_size
    num_workers = args.num_workers
    validation_dataset_split_ratio = args.valid_split
    embedding_dimension = args.embedding_dim
    pretrained = args.pretrained
    optimizer = args.optimizer
    learning_rate = args.lr
    learning_rate_center_loss = args.center_loss_lr
    center_loss_weight = args.center_loss_weight
    start_epoch = 0

    # Define image data pre-processing transforms
    #   ToTensor() scales pixel values to the [0, 1] range
    #   Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) then maps them to [-1, 1]

    #  Size 182x182 RGB images are resized to 160x160 RGB images for more model generalization
    data_transforms = transforms.Compose([
        #transforms.RandomCrop(size=50),
        transforms.Resize(size=(160, 160)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
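    # Illustrative check of the mapping: ToTensor() turns a raw pixel value of
    #  255 into 1.0, and Normalize then yields (1.0 - 0.5) / 0.5 = 1.0, while a
    #  raw value of 0 becomes (0.0 - 0.5) / 0.5 = -1.0
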
    # Size 160x160 RGB image
    apd_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Load the dataset
    dataset = torchvision.datasets.ImageFolder(root=dataroot,
                                               transform=data_transforms)

    # Subset the dataset into training and validation datasets
    num_classes = len(dataset.classes)
    print("\nNumber of classes in dataset: {}".format(num_classes))
    num_validation = int(num_classes * validation_dataset_split_ratio)
    num_train = num_classes - num_validation
    indices = list(range(num_classes))
    np.random.seed(420)
    np.random.shuffle(indices)
    train_indices = indices[:num_train]
    validation_indices = indices[num_train:]

    #train_dataset = Subset(dataset=dataset, indices=train_indices)
    #validation_dataset = Subset(dataset=dataset, indices=validation_indices)

    #print("Number of classes in training dataset: {}".format(len(train_dataset)))
    #print("Number of classes in validation dataset: {}".format(len(validation_dataset)))

    # Define the dataloaders
    train_dataloader = DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=True)
    """
    validation_dataloader = DataLoader(
        dataset=validation_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False
    )
    """

    logging.info('prepare APD dataset')
    # Note: this hard-coded path overrides the --apd command line argument
    apd_dataroot = '../APD'
    apdDataset = APDDataset(
        directory=apd_dataroot,
        #pairs_path='negative_pairs.txt',
        #pairs_path='positive_pairs.txt',
        pairs_path='full.txt',
        transform=apd_transforms)
    print('apdDataset', apdDataset)

    apd_dataloader = torch.utils.data.DataLoader(dataset=apdDataset,
                                                 batch_size=apd_batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=True)

    # Instantiate model
    if model_architecture == "resnet18":
        model = Resnet18Center(num_classes=num_classes,
                               embedding_dimension=embedding_dimension,
                               pretrained=pretrained)
    elif model_architecture == "resnet34":
        model = Resnet34Center(num_classes=num_classes,
                               embedding_dimension=embedding_dimension,
                               pretrained=pretrained)
    elif model_architecture == "resnet50":
        model = Resnet50Center(num_classes=num_classes,
                               embedding_dimension=embedding_dimension,
                               pretrained=pretrained)
    elif model_architecture == "resnet101":
        model = Resnet101Center(num_classes=num_classes,
                                embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "inceptionresnetv2":
        model = InceptionResnetV2Center(
            num_classes=num_classes,
            embedding_dimension=embedding_dimension,
            pretrained=pretrained)
    print("\nUsing {} model architecture.".format(model_architecture))

    # Load model to GPU or multiple GPUs if available
    flag_train_gpu = torch.cuda.is_available()
    flag_train_multi_gpu = False

    if flag_train_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model.cuda()
        flag_train_multi_gpu = True
        print('Using multi-gpu training.')
    elif flag_train_gpu and torch.cuda.device_count() == 1:
        model.cuda()
        print('Using single-gpu training.')

    # Set loss functions
    criterion_crossentropy = nn.CrossEntropyLoss().cuda()
    criterion_centerloss = CenterLoss(num_classes=num_classes,
                                      feat_dim=embedding_dimension).cuda()
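    # Center loss (Wen et al.) penalises the distance between each embedding
    #  and its learnable class centre, L_c = 1/2 * sum_i ||x_i - c_{y_i}||^2;
    #  the centres are parameters of criterion_centerloss and are trained by
    #  their own optimizer below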

    # Set optimizers
    if optimizer == "sgd":
        optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate)
        optimizer_centerloss = torch.optim.SGD(
            criterion_centerloss.parameters(), lr=learning_rate_center_loss)

    elif optimizer == "adagrad":
        optimizer_model = torch.optim.Adagrad(model.parameters(),
                                              lr=learning_rate)
        optimizer_centerloss = torch.optim.Adagrad(
            criterion_centerloss.parameters(), lr=learning_rate_center_loss)

    elif optimizer == "rmsprop":
        optimizer_model = torch.optim.RMSprop(model.parameters(),
                                              lr=learning_rate)
        optimizer_centerloss = torch.optim.RMSprop(
            criterion_centerloss.parameters(), lr=learning_rate_center_loss)

    elif optimizer == "adam":
        optimizer_model = torch.optim.Adam(model.parameters(),
                                           lr=learning_rate)
        optimizer_centerloss = torch.optim.Adam(
            criterion_centerloss.parameters(), lr=learning_rate_center_loss)

    # Set learning rate decay scheduler
    learning_rate_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer_model, milestones=[150, 225], gamma=0.1)
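    # Illustrative schedule (assuming an initial lr of 0.1): lr = 0.1 for
    #  epochs 0-149, 0.01 for epochs 150-224, and 0.001 from epoch 225 onwards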

    # Optionally resume from a checkpoint
    if resume_path:

        if os.path.isfile(resume_path):
            print("\nLoading checkpoint {} ...".format(resume_path))

            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']

            # In order to load state dict for optimizers correctly, model has to be loaded to gpu first
            if flag_train_multi_gpu:
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])

            optimizer_model.load_state_dict(
                checkpoint['optimizer_model_state_dict'])
            optimizer_centerloss.load_state_dict(
                checkpoint['optimizer_centerloss_state_dict'])
            learning_rate_scheduler.load_state_dict(
                checkpoint['learning_rate_scheduler_state_dict'])

            print(
                "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n"
                .format(start_epoch, epochs - start_epoch))
        else:
            print(
                "WARNING: No checkpoint found at {}!\nTraining from scratch.".
                format(resume_path))

    # Start Training loop
    print(
        "\nTraining using cross entropy loss with center loss starting for {} epochs:\n"
        .format(epochs - start_epoch))

    total_time_start = time.time()
    end_epoch = start_epoch + epochs

    BATCH_NUM = len(train_dataloader.dataset) / batch_size
    APD_BATCH_NUM = len(apd_dataloader.dataset) / apd_batch_size

    for epoch in range(start_epoch, end_epoch):
        epoch_time_start = time.time()

        #flag_validate_apd = (epoch + 1) % lfw_validation_epoch_interval == 0 or (epoch + 1) % epochs == 0
        flag_validate_apd = True
        train_loss_sum = 0
        validation_loss_sum = 0

        # Training the model
        model.train()
        # Note: since PyTorch 1.1, scheduler.step() is expected after the
        #  optimizer steps (i.e. at the end of the epoch); calling it before
        #  them follows the older pre-1.1 convention
        learning_rate_scheduler.step()
        #progress_bar = enumerate(tqdm(train_dataloader))
        progress_bar = enumerate(train_dataloader)

        for batch_index, (data, labels) in progress_bar:
            data, labels = data.cuda(), labels.cuda()

            # Forward pass
            if flag_train_multi_gpu:
                embedding, logits = model.module.forward_training(data)
            else:
                embedding, logits = model.forward_training(data)

            # Calculate losses (data and labels are already on the GPU, so the
            #  extra .cuda() calls were redundant)
            cross_entropy_loss = criterion_crossentropy(logits, labels)
            center_loss = criterion_centerloss(embedding, labels)
            loss = (center_loss * center_loss_weight) + cross_entropy_loss
            #loss = cross_entropy_loss
            logging.info("epoch:{}/{} batch_idx:{}/{} loss:{}".format(
                epoch, end_epoch, batch_index, BATCH_NUM, loss))

            # Backward pass
            optimizer_centerloss.zero_grad()
            optimizer_model.zero_grad()
            loss.backward()
            optimizer_centerloss.step()
            optimizer_model.step()

            # Remove center_loss_weight impact on the learning of center vectors
            #for param in criterion_centerloss.parameters():
            #    param.grad.data *= (1. / center_loss_weight)

            # Update training loss sum
            train_loss_sum += loss.item() * data.size(0)

        # Calculate average losses in epoch
        avg_train_loss = train_loss_sum / len(train_dataloader.dataset)
        """
        avg_validation_loss = validation_loss_sum / len(validation_dataloader.dataset)
        """

        # Calculate training performance statistics in epoch
        #classification_accuracy = correct * 100. / total
        #classification_error = 100. - classification_accuracy

        epoch_time_end = time.time()

        print('Epoch {}:\t Average Training Loss: {:.4f}\t'.format(
            epoch + 1, avg_train_loss))

        with open('logs/{}_log_center.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_train_loss
                #avg_validation_loss,
                #classification_accuracy.item(),
                #classification_error.item()
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        try:
            # Plot plot for Cross Entropy Loss and Center Loss on training and validation sets
            plot_training_validation_losses_center(
                log_dir="logs/{}_log_center.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/training_validation_losses_{}_center.png".
                format(model_architecture))
        except Exception as e:
            print(e)

        # Validating on APD dataset using KFold based on Euclidean distance metric
        if flag_validate_apd:

            model.eval()
            with torch.no_grad():
                l2_distance = PairwiseDistance(2).cuda()
                distances, labels = [], []

                print("Validating on APD! ...")
                #progress_bar = enumerate(tqdm(apd_dataloader))
                progress_bar = enumerate(apd_dataloader)

                for batch_index, (data_a, data_b, label) in progress_bar:
                    logging.info("epoch:{}/{} batch_idx:{}/{}".format(
                        epoch, end_epoch, batch_index, APD_BATCH_NUM))
                    data_a, data_b, label = (data_a.cuda(), data_b.cuda(),
                                             label.cuda())

                    output_a, output_b = model(data_a), model(data_b)
                    distance = l2_distance.forward(
                        output_a, output_b)  # Euclidean distance
                    #print(distance)
                    #print(label)

                    distances.append(distance.cpu().detach().numpy())
                    labels.append(label.cpu().detach().numpy())

                labels = np.array(
                    [sublabel for label in labels for sublabel in label])
                distances = np.array([
                    subdist for distance in distances for subdist in distance
                ])

                true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \
                    tar, far = evaluate_lfw(
                        distances=distances,
                        labels=labels
                     )
                # Print statistics and add to log
                print(
                    "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\tROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}"
                    .format(np.mean(accuracy), np.std(accuracy),
                            np.mean(precision), np.std(precision),
                            np.mean(recall), np.std(recall), roc_auc,
                            np.mean(best_distances), np.std(best_distances),
                            np.mean(tar), np.std(tar), np.mean(far)))
                with open(
                        'logs/lfw_{}_log_center.txt'.format(
                            model_architecture), 'a') as f:
                    val_list = [
                        epoch + 1,
                        np.mean(accuracy),
                        np.std(accuracy),
                        np.mean(precision),
                        np.std(precision),
                        np.mean(recall),
                        np.std(recall), roc_auc,
                        np.mean(best_distances),
                        np.std(best_distances),
                        np.mean(tar)
                    ]
                    log = '\t'.join(str(value) for value in val_list)
                    f.writelines(log + '\n')

            try:
                # Plot ROC curve
                plot_roc_lfw(
                    false_positive_rate=false_positive_rate,
                    true_positive_rate=true_positive_rate,
                    figure_name=
                    "plots/roc_plots_center/roc_{}_epoch_{}_center.png".format(
                        model_architecture, epoch + 1),
                    epochNum=epoch)
                # Plot APD accuracies plot (the log file keeps the lfw_ prefix)
                plot_accuracy_lfw(
                    log_dir="logs/lfw_{}_log_center.txt".format(
                        model_architecture),
                    epochs=epochs,
                    figure_name="plots/lfw_accuracies_{}_center.png".format(
                        model_architecture))
            except Exception as e:
                print(e)

        # Save model checkpoint
        state = {
            'epoch': epoch + 1,
            'num_classes': num_classes,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict(),
            'optimizer_centerloss_state_dict':
                optimizer_centerloss.state_dict(),
            'learning_rate_scheduler_state_dict':
                learning_rate_scheduler.state_dict()
        }

        # For storing the data parallel model's state dictionary without the
        #  'module.' prefix added by nn.DataParallel
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during APD validation
        if flag_validate_apd:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state, 'center_checkpoints/model_{}_center_epoch_{}.pt'.format(
                model_architecture, epoch + 1))

    # Training loop end
    total_time_end = time.time()
    total_time_elapsed = total_time_end - total_time_start
    print("\nTraining finished: total time elapsed: {:.2f} hours.".format(
        total_time_elapsed / 3600))
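
Below is a minimal sketch of how one of the checkpoints saved above could be restored. The epoch number in the path, the resnet18 choice, and the placeholder learning rate are illustrative; the keys follow the state dictionary defined in the example.

import torch

# Illustrative checkpoint path following the save call above
checkpoint = torch.load(
    'center_checkpoints/model_resnet18_center_epoch_1.pt')

model = Resnet18Center(
    num_classes=checkpoint['num_classes'],
    embedding_dimension=checkpoint['embedding_dimension'])
model.load_state_dict(checkpoint['model_state_dict'])
model.cuda()

# Recreate the optimizer before restoring its state (lr is a placeholder)
optimizer_model = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer_model.load_state_dict(checkpoint['optimizer_model_state_dict'])

start_epoch = checkpoint['epoch']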