def __init__(self, model_name, loss_name, margin, num_classes):
    """Build the weighted multi-loss container.

    ``loss_name`` is a '+'-separated list of ``weight*type`` terms, e.g.
    ``'1*CrossEntropy+0.5*Triplet'``. Supported types: ``CrossEntropy``,
    ``SmoothCrossEntropy`` (label-smoothed cross entropy) and ``Triplet``.

    :param model_name: name of the model; str
    :param loss_name: loss specification string; str
    :param margin: margin passed to TripletLoss; float
    :param num_classes: number of classes for the smoothed cross entropy
    :raises ValueError: if ``loss_name`` contains an unknown loss type
    """
    super(Loss, self).__init__()
    self.model_name = model_name
    self.loss_name = loss_name
    self.loss_struct = []

    for loss in self.loss_name.split('+'):
        weight, loss_type = loss.split('*')
        if loss_type == 'CrossEntropy':
            loss_function = nn.CrossEntropyLoss()
        elif loss_type == 'SmoothCrossEntropy':
            loss_function = CrossEntropyLabelSmooth(
                num_classes=num_classes)
        elif loss_type == 'Triplet':
            loss_function = TripletLoss(margin)
        else:
            # BUGFIX: the original `assert "loss: ..."` asserted a
            # non-empty string, which is always truthy, so an unknown
            # loss type silently fell through with `loss_function`
            # unbound (or stale from the previous iteration). Fail fast
            # instead.
            raise ValueError(
                "loss: {} not support yet".format(self.loss_name))

        self.loss_struct.append({
            'type': loss_type,
            'weight': float(weight),
            'function': loss_function
        })

    # When several losses are combined, append a pseudo entry that will
    # hold their weighted sum.
    if len(self.loss_struct) > 1:
        self.loss_struct.append({
            'type': 'Total',
            'weight': 0,
            'function': None
        })

    # Register the real loss modules (skip the 'Total' placeholder).
    self.loss_module = nn.ModuleList([
        l['function'] for l in self.loss_struct
        if l['function'] is not None
    ])

    # One slot per entry of self.loss_struct: the first slots hold the
    # per-loss values of one iteration, the last one holds their sum.
    self.log = torch.zeros(len(self.loss_struct))
    self.log_sum = torch.zeros(len(self.loss_struct))

    if torch.cuda.is_available():
        self.loss_module = torch.nn.DataParallel(self.loss_module)
        self.loss_module.cuda()
# --- Example 2 (scraped snippet separator) ---
def build_losses(config):
    """Construct the classification and pairwise criteria from *config*.

    Returns a ``(criterion_cla, criterion_pair)`` tuple.

    :raises KeyError: if ``config.LOSS.CLA_LOSS`` or
        ``config.LOSS.PAIR_LOSS`` names an unsupported loss.
    """
    # Lazily-evaluated factories: only the selected criterion is built.
    cla_factories = {
        'crossentropy': lambda: nn.CrossEntropyLoss(),
        'crossentropylabelsmooth': lambda: CrossEntropyLabelSmooth(),
        'arcface': lambda: ArcFaceLoss(scale=config.LOSS.CLA_S,
                                       margin=config.LOSS.CLA_M),
        'cosface': lambda: CosFaceLoss(scale=config.LOSS.CLA_S,
                                       margin=config.LOSS.CLA_M),
        'circle': lambda: CircleLoss(scale=config.LOSS.CLA_S,
                                     margin=config.LOSS.CLA_M),
    }
    pair_factories = {
        'triplet': lambda: TripletLoss(margin=config.LOSS.PAIR_M,
                                       distance=config.TEST.DISTANCE),
        'contrastive': lambda: ContrastiveLoss(scale=config.LOSS.PAIR_S),
        'cosface': lambda: PairwiseCosFaceLoss(scale=config.LOSS.PAIR_S,
                                               margin=config.LOSS.PAIR_M),
        'circle': lambda: PairwiseCircleLoss(scale=config.LOSS.PAIR_S,
                                             margin=config.LOSS.PAIR_M),
    }

    # Build classification loss
    cla_name = config.LOSS.CLA_LOSS
    if cla_name not in cla_factories:
        raise KeyError("Invalid classification loss: '{}'".format(cla_name))
    criterion_cla = cla_factories[cla_name]()

    # Build pairwise loss
    pair_name = config.LOSS.PAIR_LOSS
    if pair_name not in pair_factories:
        raise KeyError("Invalid pairwise loss: '{}'".format(pair_name))
    criterion_pair = pair_factories[pair_name]()

    return criterion_cla, criterion_pair
# --- Example 3 (scraped snippet separator) ---
 def _init_criterion(self):
     """Instantiate the training criteria: identity (cross entropy),
     reconstruction (L1) and metric-learning (triplet) losses."""
     triplet_margin = 0.5
     self.triplet_loss = TripletLoss(triplet_margin)
     self.reconst_loss = nn.L1Loss()
     self.id_loss = nn.CrossEntropyLoss()
# --- Example 4 (scraped snippet separator) ---
def train_triplet(start_epoch, end_epoch, epochs, train_dataloader,
                  lfw_dataloader, lfw_validation_epoch_interval, model,
                  model_architecture, optimizer_model, embedding_dimension,
                  batch_size, margin, flag_train_multi_gpu, optimizer,
                  learning_rate):
    """Train ``model`` with triplet loss on hard-negative triplets.

    For every epoch in ``[start_epoch, end_epoch)``: run a training pass
    over ``train_dataloader``, keep only the "hard" triplets
    (``neg_dist - pos_dist < margin``), backpropagate the triplet loss on
    them, append epoch statistics to ``logs/<arch>_log_triplet.txt`` and
    plot them, validate on LFW every ``lfw_validation_epoch_interval``
    epochs (and on the last epoch), and save a checkpoint under
    ``model_training_checkpoints/``. If ``forward_pass`` reports that it
    fell back to CPU after a CUDA OOM, the model and optimizer are moved
    back to the GPU before the next batch.
    """

    for epoch in range(start_epoch, end_epoch):
        # Validate on every lfw_validation_epoch_interval-th epoch and on
        # the final epoch.
        flag_validate_lfw = (epoch +
                             1) % lfw_validation_epoch_interval == 0 or (
                                 epoch + 1) % epochs == 0
        triplet_loss_sum = 0
        num_valid_training_triplets = 0
        l2_distance = PairwiseDistance(2)  # Euclidean (p=2) distance

        # Training pass
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_idx, (batch_sample) in progress_bar:

            # Forward pass - compute embeddings
            anc_imgs = batch_sample['anc_img']
            pos_imgs = batch_sample['pos_img']
            neg_imgs = batch_sample['neg_img']

            # forward_pass may hand back a (possibly device-moved) model and
            # optimizer, plus a flag saying it fell back to CPU on OOM.
            anc_embeddings, pos_embeddings, neg_embeddings, model, optimizer_model, flag_use_cpu = forward_pass(
                anc_imgs=anc_imgs,
                pos_imgs=pos_imgs,
                neg_imgs=neg_imgs,
                model=model,
                optimizer_model=optimizer_model,
                batch_idx=batch_idx,
                optimizer=optimizer,
                learning_rate=learning_rate,
                use_cpu=False)

            # Forward pass - choose hard negatives only for training
            pos_dists = l2_distance.forward(anc_embeddings, pos_embeddings)
            neg_dists = l2_distance.forward(anc_embeddings, neg_embeddings)

            # NOTE(review): 'all' shadows the built-in of the same name.
            all = (neg_dists - pos_dists < margin).cpu().numpy().flatten()

            hard_triplets = np.where(all == 1)
            if len(hard_triplets[0]) == 0:
                continue

            anc_hard_embeddings = anc_embeddings[hard_triplets]
            pos_hard_embeddings = pos_embeddings[hard_triplets]
            neg_hard_embeddings = neg_embeddings[hard_triplets]

            # Calculate triplet loss
            # NOTE(review): a fresh TripletLoss module is constructed on
            # every batch; harmless if it is stateless, but hoistable.
            triplet_loss = TripletLoss(margin=margin).forward(
                anchor=anc_hard_embeddings,
                positive=pos_hard_embeddings,
                negative=neg_hard_embeddings)

            # Calculating loss
            triplet_loss_sum += triplet_loss.item()
            num_valid_training_triplets += len(anc_hard_embeddings)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

            # Load model and optimizer back to GPU if CUDA Out of Memory Exception occured and model and optimizer
            #  were switched to CPU
            if flag_use_cpu:
                # According to https://github.com/pytorch/pytorch/issues/2830#issuecomment-336183179
                #  In order for the optimizer to keep training the model after changing to a different type or device,
                #  optimizers have to be recreated, 'load_state_dict' can be used to restore the state from a
                #  previous copy. As such, the optimizer state dict will be saved first and then reloaded when
                #  the model's device is changed.

                # Load back to CUDA
                torch.save(
                    optimizer_model.state_dict(),
                    'model_training_checkpoints/out_of_memory_optimizer_checkpoint/optimizer_checkpoint.pt'
                )

                model = model.cuda()

                optimizer_model = set_optimizer(optimizer=optimizer,
                                                model=model,
                                                learning_rate=learning_rate)

                optimizer_model.load_state_dict(
                    torch.load(
                        'model_training_checkpoints/out_of_memory_optimizer_checkpoint/optimizer_checkpoint.pt'
                    ))

                # Copied from https://github.com/pytorch/pytorch/issues/2830#issuecomment-336194949
                # No optimizer.cuda() available, this is the way to make an optimizer loaded with cpu tensors load
                #  with cuda tensors.
                for state in optimizer_model.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

                # Ensure model is correctly set to be trainable
                model.train()

        # Model only trains on hard negative triplets
        avg_triplet_loss = 0 if (
            num_valid_training_triplets
            == 0) else triplet_loss_sum / num_valid_training_triplets

        # Print training statistics and add to log
        print(
            'Epoch {}:\tAverage Triplet Loss: {:.4f}\tNumber of valid training triplets in epoch: {}'
            .format(epoch + 1, avg_triplet_loss, num_valid_training_triplets))
        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_triplet_loss, num_valid_training_triplets
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        try:
            # Plot Triplet losses plot; plotting failures are non-fatal.
            plot_triplet_losses(
                log_dir="logs/{}_log_triplet.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/triplet_losses_{}.png".format(
                    model_architecture))
        except Exception as e:
            print(e)

        # Evaluation pass on LFW dataset
        if flag_validate_lfw:
            best_distances = validate_lfw(
                model=model,
                lfw_dataloader=lfw_dataloader,
                model_architecture=model_architecture,
                epoch=epoch,
                epochs=epochs)

        # Save model checkpoint
        state = {
            'epoch': epoch + 1,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict()
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_lfw:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state,
            'model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch + 1))
def main():
    """Script entry point: train a face-embedding model with triplet loss.

    Reads all options from the module-level ``args`` (argparse namespace),
    builds the LFW evaluation dataloader, instantiates the model and
    optimizer (optionally resuming from a checkpoint), then for each epoch
    regenerates a triplet dataset, trains on the triplets selected by the
    configured negative-mining strategy (hard or semi-hard), validates on
    LFW and saves a checkpoint.
    """
    dataroot = args.dataroot
    lfw_dataroot = args.lfw
    training_dataset_csv_path = args.training_dataset_csv_path
    epochs = args.epochs
    iterations_per_epoch = args.iterations_per_epoch
    model_architecture = args.model_architecture
    pretrained = args.pretrained
    embedding_dimension = args.embedding_dimension
    num_human_identities_per_batch = args.num_human_identities_per_batch
    batch_size = args.batch_size
    lfw_batch_size = args.lfw_batch_size
    resume_path = args.resume_path
    num_workers = args.num_workers
    optimizer = args.optimizer
    learning_rate = args.learning_rate
    margin = args.margin
    image_size = args.image_size
    use_semihard_negatives = args.use_semihard_negatives
    training_triplets_path = args.training_triplets_path
    flag_training_triplets_path = False
    start_epoch = 0

    if training_triplets_path is not None:
        flag_training_triplets_path = True  # Load triplets file for the first training epoch

    # Define image data pre-processing transforms
    #   ToTensor() normalizes pixel values between [0, 1]
    #   Normalize(mean=[0.6071, 0.4609, 0.3944], std=[0.2457, 0.2175, 0.2129]) according to the calculated glint360k
    #   dataset with tightly-cropped faces dataset RGB channels' mean and std values by
    #   calculate_glint360k_rgb_mean_std.py in 'datasets' folder.
    data_transforms = transforms.Compose([
        transforms.Resize(size=image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6071, 0.4609, 0.3944],
                             std=[0.2457, 0.2175, 0.2129])
    ])

    # Evaluation transforms: no augmentation, same normalization.
    lfw_transforms = transforms.Compose([
        transforms.Resize(size=image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6071, 0.4609, 0.3944],
                             std=[0.2457, 0.2175, 0.2129])
    ])

    lfw_dataloader = torch.utils.data.DataLoader(dataset=LFWDataset(
        dir=lfw_dataroot,
        pairs_path='datasets/LFW_pairs.txt',
        transform=lfw_transforms),
                                                 batch_size=lfw_batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=False)

    # Instantiate model
    model = set_model_architecture(model_architecture=model_architecture,
                                   pretrained=pretrained,
                                   embedding_dimension=embedding_dimension)

    # Load model to GPU or multiple GPUs if available
    model, flag_train_multi_gpu = set_model_gpu_mode(model)

    # Set optimizer
    optimizer_model = set_optimizer(optimizer=optimizer,
                                    model=model,
                                    learning_rate=learning_rate)

    # Resume from a model checkpoint
    if resume_path:
        if os.path.isfile(resume_path):
            print("Loading checkpoint {} ...".format(resume_path))
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch'] + 1
            optimizer_model.load_state_dict(
                checkpoint['optimizer_model_state_dict'])

            # In order to load state dict for optimizers correctly, model has to be loaded to gpu first
            if flag_train_multi_gpu:
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])
            print("Checkpoint loaded: start epoch from checkpoint = {}".format(
                start_epoch))
        else:
            print(
                "WARNING: No checkpoint found at {}!\nTraining from scratch.".
                format(resume_path))

    if use_semihard_negatives:
        print("Using Semi-Hard negative triplet selection!")
    else:
        print("Using Hard negative triplet selection!")

    # NOTE(review): no-op self-assignment; safe to remove.
    start_epoch = start_epoch

    print("Training using triplet loss starting for {} epochs:\n".format(
        epochs - start_epoch))

    for epoch in range(start_epoch, epochs):
        num_valid_training_triplets = 0
        l2_distance = PairwiseDistance(p=2)  # Euclidean distance
        _training_triplets_path = None

        if flag_training_triplets_path:
            _training_triplets_path = training_triplets_path
            flag_training_triplets_path = False  # Only load triplets file for the first epoch

        # Re-instantiate training dataloader to generate a triplet list for this training epoch
        train_dataloader = torch.utils.data.DataLoader(
            dataset=TripletFaceDataset(
                root_dir=dataroot,
                training_dataset_csv_path=training_dataset_csv_path,
                num_triplets=iterations_per_epoch * batch_size,
                num_human_identities_per_batch=num_human_identities_per_batch,
                triplet_batch_size=batch_size,
                epoch=epoch,
                training_triplets_path=_training_triplets_path,
                transform=data_transforms),
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=
            False  # Shuffling for triplets with set amount of human identities per batch is not required
        )

        # Training pass
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_idx, (batch_sample) in progress_bar:

            # Forward pass - compute embeddings
            anc_imgs = batch_sample['anc_img']
            pos_imgs = batch_sample['pos_img']
            neg_imgs = batch_sample['neg_img']

            # Concatenate the input images into one tensor because doing multiple forward passes would create
            #  weird GPU memory allocation behaviours later on during training which would cause GPU Out of Memory
            #  issues
            all_imgs = torch.cat(
                (anc_imgs, pos_imgs,
                 neg_imgs))  # Must be a tuple of Torch Tensors

            anc_embeddings, pos_embeddings, neg_embeddings, model = forward_pass(
                imgs=all_imgs, model=model, batch_size=batch_size)

            pos_dists = l2_distance.forward(anc_embeddings, pos_embeddings)
            neg_dists = l2_distance.forward(anc_embeddings, neg_embeddings)

            if use_semihard_negatives:
                # Semi-Hard Negative triplet selection
                #  (negative_distance - positive_distance < margin) AND (positive_distance < negative_distance)
                #   Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L295
                first_condition = (neg_dists - pos_dists <
                                   margin).cpu().numpy().flatten()
                second_condition = (pos_dists <
                                    neg_dists).cpu().numpy().flatten()
                # NOTE(review): 'all' shadows the built-in of the same name.
                all = (np.logical_and(first_condition, second_condition))
                valid_triplets = np.where(all == 1)
            else:
                # Hard Negative triplet selection
                #  (negative_distance - positive_distance < margin)
                #   Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L296
                all = (neg_dists - pos_dists < margin).cpu().numpy().flatten()
                valid_triplets = np.where(all == 1)

            anc_valid_embeddings = anc_embeddings[valid_triplets]
            pos_valid_embeddings = pos_embeddings[valid_triplets]
            neg_valid_embeddings = neg_embeddings[valid_triplets]

            # Calculate triplet loss
            triplet_loss = TripletLoss(margin=margin).forward(
                anchor=anc_valid_embeddings,
                positive=pos_valid_embeddings,
                negative=neg_valid_embeddings)

            # Calculating number of triplets that met the triplet selection method during the epoch
            num_valid_training_triplets += len(anc_valid_embeddings)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

        # Print training statistics for epoch and add to log
        print(
            'Epoch {}:\tNumber of valid training triplets in epoch: {}'.format(
                epoch, num_valid_training_triplets))

        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [epoch, num_valid_training_triplets]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        # Evaluation pass on LFW dataset
        best_distances = validate_lfw(model=model,
                                      lfw_dataloader=lfw_dataloader,
                                      model_architecture=model_architecture,
                                      epoch=epoch)

        # Save model checkpoint
        state = {
            'epoch': epoch,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict(),
            'best_distance_threshold': np.mean(best_distances)
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # Save model checkpoint
        torch.save(
            state,
            'model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch))
# --- Example 6 (scraped snippet separator) ---
# Script-level setup. NOTE: `opt`, `dir_name` and `model` are created
# earlier in this script (not shown here).

# record every run: snapshot the training scripts into the run directory
# for reproducibility.
copyfile('./train.py', dir_name + '/train.py')
copyfile('models/base_model.py', dir_name + '/base_model.py')
if opt.LSTM:
    copyfile('models/lstm_model.py', dir_name + '/lstm_model.py')
if opt.GGNN:
    copyfile('models/ggnn_model.py', dir_name + '/ggnn_model.py')

# save opts
with open('%s/opts.yaml' % dir_name, 'w') as fp:
    yaml.dump(vars(opt), fp, default_flow_style=False)

# model to gpu
model = model.cuda()

# Criteria used by loss_func below: triplet loss with margin 0.3 and
# label-smoothed cross entropy over opt.nclasses classes.
triplet = TripletLoss(0.3)
xent = CrossEntropyLabelSmooth(num_classes=opt.nclasses)


def loss_func(score, feat, target):
    if opt.use_triplet_loss:
        if opt.label_smoothing:
            return xent(score, target) + triplet(feat, target)[0]
        else:
            return F.cross_entropy(score, target) + triplet(feat, target)[0]
    else:
        if opt.label_smoothing:
            return xent(score, target)
        else:
            return F.cross_entropy(score, target)
def train_triplet(start_epoch, end_epoch, epochs, train_dataloader,
                  lfw_dataloader, lfw_validation_epoch_interval, model,
                  model_architecture, optimizer_model, embedding_dimension,
                  batch_size, margin, flag_train_multi_gpu, optimizer,
                  learning_rate, use_semihard_negatives):
    """Train ``model`` with triplet loss using hard or semi-hard mining.

    For every epoch in ``[start_epoch, end_epoch)``: embed concatenated
    anchor/positive/negative batches in a single forward pass, select
    triplets via semi-hard mining (``neg - pos < margin`` AND
    ``pos < neg``) or hard mining (``neg - pos < margin``) depending on
    ``use_semihard_negatives``, backpropagate the triplet loss on them,
    log/plot epoch statistics, validate on LFW every
    ``lfw_validation_epoch_interval`` epochs (and on the last epoch), and
    save a checkpoint. If ``forward_pass`` fell back to CPU after a CUDA
    OOM, the model and optimizer are restored to the GPU.
    """

    for epoch in range(start_epoch, end_epoch):
        # Validate on every lfw_validation_epoch_interval-th epoch and on
        # the final epoch.
        flag_validate_lfw = (epoch +
                             1) % lfw_validation_epoch_interval == 0 or (
                                 epoch + 1) % epochs == 0
        triplet_loss_sum = 0
        num_valid_training_triplets = 0
        l2_distance = PairwiseDistance(p=2)  # Euclidean distance

        # Training pass
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_idx, (batch_sample) in progress_bar:
            # Skip last iteration to avoid the problem of having different number of tensors while calculating
            #  pairwise distances (sizes of tensors must be the same for pairwise distance calculation)
            if batch_idx + 1 == len(train_dataloader):
                continue

            # Forward pass - compute embeddings
            anc_imgs = batch_sample['anc_img']
            pos_imgs = batch_sample['pos_img']
            neg_imgs = batch_sample['neg_img']

            # Concatenate the input images into one tensor because doing multiple forward passes would create
            #  weird GPU memory allocation behaviours later on during training which would cause GPU Out of Memory
            #  issues
            all_imgs = torch.cat(
                (anc_imgs, pos_imgs,
                 neg_imgs))  # Must be a tuple of Torch Tensors

            anc_embeddings, pos_embeddings, neg_embeddings, model, optimizer_model, flag_use_cpu = forward_pass(
                imgs=all_imgs,
                model=model,
                optimizer_model=optimizer_model,
                batch_idx=batch_idx,
                optimizer=optimizer,
                learning_rate=learning_rate,
                batch_size=batch_size,
                use_cpu=False)

            pos_dists = l2_distance.forward(anc_embeddings, pos_embeddings)
            neg_dists = l2_distance.forward(anc_embeddings, neg_embeddings)

            if use_semihard_negatives:
                # Semi-Hard Negative triplet selection
                #  (negative_distance - positive_distance < margin) AND (positive_distance < negative_distance)
                #   Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L295

                first_condition = (neg_dists - pos_dists <
                                   margin).cpu().numpy().flatten()
                second_condition = (pos_dists <
                                    neg_dists).cpu().numpy().flatten()
                # NOTE(review): 'all' shadows the built-in of the same name.
                all = (np.logical_and(first_condition, second_condition))

                semihard_negative_triplets = np.where(all == 1)
                if len(semihard_negative_triplets[0]) == 0:
                    continue

                anc_valid_embeddings = anc_embeddings[
                    semihard_negative_triplets]
                pos_valid_embeddings = pos_embeddings[
                    semihard_negative_triplets]
                neg_valid_embeddings = neg_embeddings[
                    semihard_negative_triplets]

                # Drop the full-batch tensors as early as possible to keep
                # peak GPU memory down.
                del anc_embeddings, pos_embeddings, neg_embeddings, pos_dists, neg_dists
                gc.collect()

            else:
                # Hard Negative triplet selection
                #  (negative_distance - positive_distance < margin)
                #   Based on: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py#L296

                all = (neg_dists - pos_dists < margin).cpu().numpy().flatten()

                hard_negative_triplets = np.where(all == 1)
                if len(hard_negative_triplets[0]) == 0:
                    continue

                anc_valid_embeddings = anc_embeddings[hard_negative_triplets]
                pos_valid_embeddings = pos_embeddings[hard_negative_triplets]
                neg_valid_embeddings = neg_embeddings[hard_negative_triplets]

                del anc_embeddings, pos_embeddings, neg_embeddings, pos_dists, neg_dists
                gc.collect()

            # Calculate triplet loss
            triplet_loss = TripletLoss(margin=margin).forward(
                anchor=anc_valid_embeddings,
                positive=pos_valid_embeddings,
                negative=neg_valid_embeddings)

            # Calculating loss and number of triplets that met the triplet selection method during the epoch
            triplet_loss_sum += triplet_loss.item()
            num_valid_training_triplets += len(anc_valid_embeddings)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

            # Load model and optimizer back to GPU if CUDA Out of Memory Exception occurred and model and optimizer
            #  were switched to CPU
            if flag_use_cpu:
                # According to https://github.com/pytorch/pytorch/issues/2830#issuecomment-336183179
                #  In order for the optimizer to keep training the model after changing to a different type or device,
                #  optimizers have to be recreated, 'load_state_dict' can be used to restore the state from a
                #  previous copy. As such, the optimizer state dict will be saved first and then reloaded when
                #  the model's device is changed.
                torch.cuda.empty_cache()

                # Print number of valid triplets (troubleshooting out of memory causes)
                print("Number of valid triplets during OOM iteration = {}".
                      format(len(anc_valid_embeddings)))

                torch.save(
                    optimizer_model.state_dict(),
                    'model_training_checkpoints/out_of_memory_optimizer_checkpoint/optimizer_checkpoint.pt'
                )

                # Load back to CUDA
                model.cuda()

                optimizer_model = set_optimizer(optimizer=optimizer,
                                                model=model,
                                                learning_rate=learning_rate)

                optimizer_model.load_state_dict(
                    torch.load(
                        'model_training_checkpoints/out_of_memory_optimizer_checkpoint/optimizer_checkpoint.pt'
                    ))

                # Copied from https://github.com/pytorch/pytorch/issues/2830#issuecomment-336194949
                # No optimizer.cuda() available, this is the way to make an optimizer loaded with cpu tensors load
                #  with cuda tensors.
                for state in optimizer_model.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

            # Clear some memory at end of training iteration
            del triplet_loss, anc_valid_embeddings, pos_valid_embeddings, neg_valid_embeddings
            gc.collect()

        # Model only trains on triplets that fit the triplet selection method
        avg_triplet_loss = 0 if (
            num_valid_training_triplets
            == 0) else triplet_loss_sum / num_valid_training_triplets

        # Print training statistics and add to log
        print(
            'Epoch {}:\tAverage Triplet Loss: {:.4f}\tNumber of valid training triplets in epoch: {}'
            .format(epoch + 1, avg_triplet_loss, num_valid_training_triplets))

        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_triplet_loss, num_valid_training_triplets
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

        try:
            # Plot Triplet losses plot; plotting failures are non-fatal.
            plot_triplet_losses(
                log_dir="logs/{}_log_triplet.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/triplet_losses_{}.png".format(
                    model_architecture))
        except Exception as e:
            print(e)

        # Evaluation pass on LFW dataset
        if flag_validate_lfw:
            best_distances = validate_lfw(
                model=model,
                lfw_dataloader=lfw_dataloader,
                model_architecture=model_architecture,
                epoch=epoch,
                epochs=epochs)

        # Save model checkpoint
        state = {
            'epoch': epoch + 1,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict()
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_lfw:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state,
            'model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch + 1))
def train_triplet(start_epoch, end_epoch, epochs, train_dataloader,
                  lfw_dataloader, lfw_validation_epoch_interval, model,
                  model_architecture, optimizer_model, embedding_dimension,
                  batch_size, margin, flag_train_multi_gpu):
    """Train a face-embedding model with triplet loss and hard-negative mining.

    For each epoch: compute anchor/positive/negative embeddings, keep only
    triplets that violate the margin (hard triplets), backpropagate the
    triplet loss on those, log the epoch average, optionally validate on
    LFW, and save a checkpoint.

    :param start_epoch: first epoch index (0-based) to run
    :param end_epoch: epoch index (exclusive) at which to stop
    :param epochs: total configured epoch count (used for logging/plotting)
    :param train_dataloader: yields dicts with 'anc_img', 'pos_img', 'neg_img'
    :param lfw_dataloader: dataloader for the LFW validation pairs
    :param lfw_validation_epoch_interval: run LFW validation every N epochs
    :param model: embedding network (possibly wrapped in nn.DataParallel)
    :param model_architecture: name used in log and checkpoint file paths
    :param optimizer_model: optimizer over the model parameters
    :param embedding_dimension: embedding size, stored in the checkpoint
    :param batch_size: training batch size, stored in the checkpoint
    :param margin: triplet margin used both for mining and for the loss
    :param flag_train_multi_gpu: True if model is DataParallel (changes how
        the state dict is extracted for saving)
    """
    # Hoist loop-invariant modules: neither depends on the epoch or batch,
    # so constructing them once avoids per-epoch/per-batch allocations.
    l2_distance = PairwiseDistance(2).cuda()
    triplet_loss_fn = TripletLoss(margin=margin)

    for epoch in range(start_epoch, end_epoch):
        # Validate on LFW every `lfw_validation_epoch_interval` epochs and
        # on the final configured epoch.
        flag_validate_lfw = (
            (epoch + 1) % lfw_validation_epoch_interval == 0
            or (epoch + 1) % epochs == 0)
        triplet_loss_sum = 0
        num_valid_training_triplets = 0

        # Training pass
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_idx, batch_sample in progress_bar:

            anc_img = batch_sample['anc_img'].cuda()
            pos_img = batch_sample['pos_img'].cuda()
            neg_img = batch_sample['neg_img'].cuda()

            # Forward pass - compute embeddings
            anc_embedding = model(anc_img)
            pos_embedding = model(pos_img)
            neg_embedding = model(neg_img)

            # Hard-negative mining: keep only triplets violating the margin
            pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
            neg_dist = l2_distance.forward(anc_embedding, neg_embedding)

            # Renamed from `all`, which shadowed the builtin of that name.
            hard_mask = (neg_dist - pos_dist < margin).cpu().numpy().flatten()

            hard_triplets = np.where(hard_mask == 1)
            if len(hard_triplets[0]) == 0:
                # No margin violations in this batch: nothing to learn from.
                continue

            anc_hard_embedding = anc_embedding[hard_triplets].cuda()
            pos_hard_embedding = pos_embedding[hard_triplets].cuda()
            neg_hard_embedding = neg_embedding[hard_triplets].cuda()

            # Calculate triplet loss on the hard triplets only
            triplet_loss = triplet_loss_fn.forward(
                anchor=anc_hard_embedding,
                positive=pos_hard_embedding,
                negative=neg_hard_embedding).cuda()

            # Accumulate epoch statistics
            triplet_loss_sum += triplet_loss.item()
            num_valid_training_triplets += len(anc_hard_embedding)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

        # Model only trains on hard negative triplets; guard the average
        # against an epoch where no triplet violated the margin.
        avg_triplet_loss = (
            0 if num_valid_training_triplets == 0
            else triplet_loss_sum / num_valid_training_triplets)

        # Print training statistics and append them to the log file
        print(
            'Epoch {}:\tAverage Triplet Loss: {:.4f}\tNumber of valid training triplets in epoch: {}'
            .format(epoch + 1, avg_triplet_loss, num_valid_training_triplets))
        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_triplet_loss, num_valid_training_triplets
            ]
            log = '\t'.join(str(value) for value in val_list)
            # write() is correct for a single string; writelines() would
            # merely iterate it character by character.
            f.write(log + '\n')

        try:
            # Plot triplet losses (best effort; a plotting failure must not
            # abort training).
            plot_triplet_losses(
                log_dir="logs/{}_log_triplet.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/triplet_losses_{}.png".format(
                    model_architecture))
        except Exception as e:
            print(e)

        # Evaluation pass on LFW dataset
        if flag_validate_lfw:
            best_distances = validate_lfw(
                model=model,
                lfw_dataloader=lfw_dataloader,
                model_architecture=model_architecture,
                epoch=epoch,
                epochs=epochs)

        # Save model checkpoint
        state = {
            'epoch': epoch + 1,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict()
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_lfw:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state,
            'Model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch + 1))
# ---- Esempio n. 9 ----
def main():
    """Train a whale-identification embedding model with triplet loss.

    Reads all configuration from the module-level ``args`` and
    ``data_files``: loads the training metadata (optionally augmented with
    pseudo-labels), builds the model/dataset/sampler/optimizer/scheduler,
    runs the training loop, saves the final weights under the experiment
    folder, and computes predictions.
    """
    writer = get_summary_writer(args)
    data_transform = augmentation(args.image_size, train=True)

    time_id, output_folder = log_experience(args, data_transform)

    data = pd.read_csv(data_files['data'])
    if bool(args.pseudo_label):
        # Append pseudo-labelled rows and renumber file ids so each row
        # keeps a unique identifier after concatenation.
        bootstrapped_data = pd.read_csv(data_files['pseudo_labels'])
        data = pd.concat([data, bootstrapped_data], axis=0)
        data['file_id'] = data.index.tolist()

    # Lookup tables used by the sampler and dataset below.
    mapping_filename_path = dict(zip(data['filename'], data['full_path']))
    classes = data.folder.unique()
    mapping_label_id = dict(zip(classes, range(len(classes))))

    num_classes = data.folder.nunique()
    mapping_files_to_global_id = dict(
        zip(data.full_path.tolist(), data.file_id.tolist()))
    paths = data.full_path.tolist()
    labels_to_samples = data.groupby('folder').agg(list)['filename'].to_dict()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model_params = {
        'embedding_dim': args.embedding_dim,
        'num_classes': num_classes,
        'image_size': args.image_size,
        'archi': args.archi,
        'pretrained': bool(args.pretrained),
        'dropout': args.dropout,
        'alpha': args.alpha,
        'gap': args.gap
    }

    model = model_factory.get_model(**model_params)

    if args.weights is not None:
        print('loading pre-trained weights and changing input size ...')
        weights = torch.load(args.weights)
        if 'state_dict' in weights:
            weights = weights['state_dict']
        if args.archi != 'densenet121':

            if bool(args.pop_fc):
                weights.pop('model.fc.weight')
                weights.pop('model.fc.bias')

            # Only catch the expected failure (checkpoint has no classifier
            # head); a bare except would have hidden unrelated errors too.
            try:
                weights.pop('model.classifier.weight')
                weights.pop('model.classifier.bias')
            except KeyError:
                print('no classifier. skipping.')
        model.load_state_dict(weights, strict=False)
    model.to(device)

    dataset = WhalesData(paths=paths,
                         bbox=data_files['bbox_train'],
                         mapping_label_id=mapping_label_id,
                         transform=data_transform,
                         crop=bool(args.crop))

    sampler = get_sampler(args, data_files, dataset, classes,
                          labels_to_samples, mapping_files_to_global_id,
                          mapping_filename_path)

    # Batch size is P identities x K samples each (PK sampling).
    dataloader = DataLoader(dataset,
                            batch_size=args.p * args.k,
                            sampler=sampler,
                            drop_last=True,
                            num_workers=args.num_workers)

    # Define loss: margin == -1 selects the soft-margin variant.
    if args.margin == -1:
        criterion = TripletLoss(margin='soft', sample=False)
    else:
        criterion = TripletLoss(margin=args.margin, sample=False)

    # Define optimizer. `and` (not the bitwise `&`) is the correct boolean
    # conjunction here.
    if (args.weights is not None) and bool(args.load_optim):
        optimizer = torch.load(args.weights)['optimizer']
    else:
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = Adam(parameters, lr=args.lr, weight_decay=args.wd)

    # Define scheduler
    scheduler = get_scheduler(args, optimizer)
    model.train()

    # Start training loop
    for epoch in tqdm(range(args.start_epoch, args.epochs + args.start_epoch)):
        params = {
            'model': model,
            'dataloader': dataloader,
            'optimizer': optimizer,
            'criterion': criterion,
            'logging_step': args.logging_step,
            'epoch': epoch,
            'epochs': args.epochs,
            'writer': writer,
            'time_id': time_id,
            'scheduler': scheduler,
            'output_folder': output_folder
        }
        _ = train(**params)

        scheduler.step()

    # Persist the final weights and produce predictions for this run.
    state = {'state_dict': model.state_dict()}
    torch.save(state, os.path.join(output_folder, f'{time_id}.pth'))

    compute_predictions(args, data_files, model, mapping_label_id, time_id,
                        output_folder)
def main():
    """Train a face-embedding model with triplet loss and validate on LFW.

    All configuration is read from the module-level ``args``. Builds the
    training and LFW dataloaders, instantiates the selected backbone,
    optionally resumes from a checkpoint, then runs the training loop:
    hard-negative triplet mining, periodic LFW evaluation, and per-epoch
    checkpointing.
    """
    dataroot = args.dataroot
    lfw_dataroot = args.lfw
    dataset_csv = args.dataset_csv
    lfw_batch_size = args.lfw_batch_size
    lfw_validation_epoch_interval = args.lfw_validation_epoch_interval
    model_architecture = args.model
    epochs = args.epochs
    training_triplets_path = args.training_triplets_path
    num_triplets_train = args.num_triplets_train
    resume_path = args.resume_path
    batch_size = args.batch_size
    num_workers = args.num_workers
    embedding_dimension = args.embedding_dim
    pretrained = args.pretrained
    optimizer = args.optimizer
    learning_rate = args.lr
    margin = args.margin
    start_epoch = 0

    # Define image data pre-processing transforms
    #   ToTensor() normalizes pixel values between [0, 1]
    #   Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) normalizes pixel values between [-1, 1]

    #  Size 182x182 RGB image -> Random crop size 160x160 RGB image for more model generalization
    data_transforms = transforms.Compose([
        transforms.RandomCrop(size=160),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Size 160x160 RGB image
    lfw_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Set dataloaders
    train_dataloader = torch.utils.data.DataLoader(dataset=TripletFaceDataset(
        root_dir=dataroot,
        csv_name=dataset_csv,
        num_triplets=num_triplets_train,
        training_triplets_path=training_triplets_path,
        transform=data_transforms),
                                                   batch_size=batch_size,
                                                   num_workers=num_workers,
                                                   shuffle=False)

    lfw_dataloader = torch.utils.data.DataLoader(dataset=LFWDataset(
        dir=lfw_dataroot,
        pairs_path='datasets/LFW_pairs.txt',
        transform=lfw_transforms),
                                                 batch_size=lfw_batch_size,
                                                 num_workers=num_workers,
                                                 shuffle=False)

    # Instantiate model
    if model_architecture == "resnet18":
        model = Resnet18Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet34":
        model = Resnet34Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet50":
        model = Resnet50Triplet(embedding_dimension=embedding_dimension,
                                pretrained=pretrained)
    elif model_architecture == "resnet101":
        model = Resnet101Triplet(embedding_dimension=embedding_dimension,
                                 pretrained=pretrained)
    elif model_architecture == "inceptionresnetv2":
        model = InceptionResnetV2Triplet(
            embedding_dimension=embedding_dimension, pretrained=pretrained)
    else:
        # Fail fast; previously an unknown name caused a NameError on
        # `model` much further below.
        raise ValueError(
            "Unsupported model architecture: {}".format(model_architecture))
    print("Using {} model architecture.".format(model_architecture))

    # Load model to GPU or multiple GPUs if available
    flag_train_gpu = torch.cuda.is_available()
    flag_train_multi_gpu = False

    if flag_train_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model.cuda()
        flag_train_multi_gpu = True
        print('Using multi-gpu training.')
    elif flag_train_gpu and torch.cuda.device_count() == 1:
        model.cuda()
        print('Using single-gpu training.')

    # Set optimizers
    if optimizer == "sgd":
        optimizer_model = torch.optim.SGD(model.parameters(), lr=learning_rate)

    elif optimizer == "adagrad":
        optimizer_model = torch.optim.Adagrad(model.parameters(),
                                              lr=learning_rate)

    elif optimizer == "rmsprop":
        optimizer_model = torch.optim.RMSprop(model.parameters(),
                                              lr=learning_rate)

    elif optimizer == "adam":
        optimizer_model = torch.optim.Adam(model.parameters(),
                                           lr=learning_rate)

    else:
        # Fail fast; previously an unknown name caused a NameError on
        # `optimizer_model` during the resume/training phase.
        raise ValueError("Unsupported optimizer: {}".format(optimizer))

    # Optionally resume from a checkpoint
    if resume_path:

        if os.path.isfile(resume_path):
            print("\nLoading checkpoint {} ...".format(resume_path))

            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']

            # In order to load state dict for optimizers correctly, model has to be loaded to gpu first
            if flag_train_multi_gpu:
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])

            optimizer_model.load_state_dict(
                checkpoint['optimizer_model_state_dict'])

            print(
                "\nCheckpoint loaded: start epoch from checkpoint = {}\nRunning for {} epochs.\n"
                .format(start_epoch, epochs - start_epoch))
        else:
            print(
                "WARNING: No checkpoint found at {}!\nTraining from scratch.".
                format(resume_path))

    # Start Training loop
    print(
        "\nTraining using triplet loss on {} triplets starting for {} epochs:\n"
        .format(num_triplets_train, epochs - start_epoch))

    total_time_start = time.time()
    end_epoch = start_epoch + epochs
    # Loop-invariant modules: construct once instead of per epoch/batch.
    l2_distance = PairwiseDistance(2).cuda()
    triplet_loss_fn = TripletLoss(margin=margin)

    for epoch in range(start_epoch, end_epoch):
        epoch_time_start = time.time()

        # Validate on LFW every `lfw_validation_epoch_interval` epochs and
        # on the final configured epoch.
        flag_validate_lfw = (
            (epoch + 1) % lfw_validation_epoch_interval == 0
            or (epoch + 1) % epochs == 0)
        triplet_loss_sum = 0
        num_valid_training_triplets = 0

        # Training pass
        model.train()
        progress_bar = enumerate(tqdm(train_dataloader))

        for batch_idx, batch_sample in progress_bar:

            anc_img = batch_sample['anc_img'].cuda()
            pos_img = batch_sample['pos_img'].cuda()
            neg_img = batch_sample['neg_img'].cuda()

            # Forward pass - compute embeddings
            anc_embedding = model(anc_img)
            pos_embedding = model(pos_img)
            neg_embedding = model(neg_img)

            # Hard-negative mining: keep only triplets violating the margin
            pos_dist = l2_distance.forward(anc_embedding, pos_embedding)
            neg_dist = l2_distance.forward(anc_embedding, neg_embedding)

            # Renamed from `all`, which shadowed the builtin of that name.
            hard_mask = (neg_dist - pos_dist < margin).cpu().numpy().flatten()

            hard_triplets = np.where(hard_mask == 1)
            if len(hard_triplets[0]) == 0:
                # No margin violations in this batch: nothing to learn from.
                continue

            anc_hard_embedding = anc_embedding[hard_triplets].cuda()
            pos_hard_embedding = pos_embedding[hard_triplets].cuda()
            neg_hard_embedding = neg_embedding[hard_triplets].cuda()

            # Calculate triplet loss on the hard triplets only
            triplet_loss = triplet_loss_fn.forward(
                anchor=anc_hard_embedding,
                positive=pos_hard_embedding,
                negative=neg_hard_embedding).cuda()

            # Accumulate epoch statistics
            triplet_loss_sum += triplet_loss.item()
            num_valid_training_triplets += len(anc_hard_embedding)

            # Backward pass
            optimizer_model.zero_grad()
            triplet_loss.backward()
            optimizer_model.step()

        # Model only trains on hard negative triplets; guard the average
        # against an epoch where no triplet violated the margin.
        avg_triplet_loss = (
            0 if num_valid_training_triplets == 0
            else triplet_loss_sum / num_valid_training_triplets)
        epoch_time_end = time.time()

        # Print training statistics and append them to the log file
        print(
            'Epoch {}:\tAverage Triplet Loss: {:.4f}\tEpoch Time: {:.3f} hours\tNumber of valid training triplets in epoch: {}'
            .format(epoch + 1, avg_triplet_loss,
                    (epoch_time_end - epoch_time_start) / 3600,
                    num_valid_training_triplets))
        with open('logs/{}_log_triplet.txt'.format(model_architecture),
                  'a') as f:
            val_list = [
                epoch + 1, avg_triplet_loss, num_valid_training_triplets
            ]
            log = '\t'.join(str(value) for value in val_list)
            # write() is correct for a single string; writelines() would
            # merely iterate it character by character.
            f.write(log + '\n')

        try:
            # Plot triplet losses (best effort; a plotting failure must not
            # abort training).
            plot_triplet_losses(
                log_dir="logs/{}_log_triplet.txt".format(model_architecture),
                epochs=epochs,
                figure_name="plots/triplet_losses_{}.png".format(
                    model_architecture))
        except Exception as e:
            print(e)

        # Evaluation pass on LFW dataset
        if flag_validate_lfw:

            model.eval()
            with torch.no_grad():
                distances, labels = [], []

                print("Validating on LFW! ...")
                progress_bar = enumerate(tqdm(lfw_dataloader))

                for batch_index, (data_a, data_b, label) in progress_bar:
                    data_a, data_b, label = data_a.cuda(), data_b.cuda(
                    ), label.cuda()

                    output_a, output_b = model(data_a), model(data_b)
                    distance = l2_distance.forward(
                        output_a, output_b)  # Euclidean distance

                    distances.append(distance.cpu().detach().numpy())
                    labels.append(label.cpu().detach().numpy())

                # Flatten the per-batch arrays into single 1-D arrays.
                labels = np.array(
                    [sublabel for label in labels for sublabel in label])
                distances = np.array([
                    subdist for distance in distances for subdist in distance
                ])

                true_positive_rate, false_positive_rate, precision, recall, accuracy, roc_auc, best_distances, \
                    tar, far = evaluate_lfw(
                        distances=distances,
                        labels=labels
                    )

                # Print statistics and add to log
                print(
                    "Accuracy on LFW: {:.4f}+-{:.4f}\tPrecision {:.4f}+-{:.4f}\tRecall {:.4f}+-{:.4f}\tROC Area Under Curve: {:.4f}\tBest distance threshold: {:.2f}+-{:.2f}\tTAR: {:.4f}+-{:.4f} @ FAR: {:.4f}"
                    .format(np.mean(accuracy), np.std(accuracy),
                            np.mean(precision), np.std(precision),
                            np.mean(recall), np.std(recall), roc_auc,
                            np.mean(best_distances), np.std(best_distances),
                            np.mean(tar), np.std(tar), np.mean(far)))
                with open(
                        'logs/lfw_{}_log_triplet.txt'.format(
                            model_architecture), 'a') as f:
                    val_list = [
                        epoch + 1,
                        np.mean(accuracy),
                        np.std(accuracy),
                        np.mean(precision),
                        np.std(precision),
                        np.mean(recall),
                        np.std(recall), roc_auc,
                        np.mean(best_distances),
                        np.std(best_distances),
                        np.mean(tar)
                    ]
                    log = '\t'.join(str(value) for value in val_list)
                    f.write(log + '\n')

            try:
                # Plot ROC curve (best effort)
                plot_roc_lfw(
                    false_positive_rate=false_positive_rate,
                    true_positive_rate=true_positive_rate,
                    figure_name="plots/roc_plots/roc_{}_epoch_{}_triplet.png".
                    format(model_architecture, epoch + 1))
                # Plot LFW accuracies plot
                plot_accuracy_lfw(
                    log_dir="logs/lfw_{}_log_triplet.txt".format(
                        model_architecture),
                    epochs=epochs,
                    figure_name="plots/lfw_accuracies_{}_triplet.png".format(
                        model_architecture))
            except Exception as e:
                print(e)

        # Save model checkpoint
        state = {
            'epoch': epoch + 1,
            'embedding_dimension': embedding_dimension,
            'batch_size_training': batch_size,
            'model_state_dict': model.state_dict(),
            'model_architecture': model_architecture,
            'optimizer_model_state_dict': optimizer_model.state_dict()
        }

        # For storing data parallel model's state dictionary without 'module' parameter
        if flag_train_multi_gpu:
            state['model_state_dict'] = model.module.state_dict()

        # For storing best euclidean distance threshold during LFW validation
        if flag_validate_lfw:
            state['best_distance_threshold'] = np.mean(best_distances)

        # Save model checkpoint
        torch.save(
            state,
            'Model_training_checkpoints/model_{}_triplet_epoch_{}.pt'.format(
                model_architecture, epoch + 1))

    # Training loop end
    total_time_end = time.time()
    total_time_elapsed = total_time_end - total_time_start
    print("\nTraining finished: total time elapsed: {:.2f} hours.".format(
        total_time_elapsed / 3600))