Пример #1
0
    def process_batch_base_class_classification_task(self, is_train):
        """Run one classification step on the current batch.

        Reads the "images_test", "labels_test" and "Kids" entries of
        ``self.tensors``, flattens the labels, and delegates the
        forward/backward work to ``cls_utils.object_classification``.

        Args:
            is_train: Forwarded unchanged to
                ``cls_utils.object_classification``.

        Returns:
            Whatever ``cls_utils.object_classification`` returns (named
            ``record`` in this code).
        """
        images = self.tensors["images_test"]
        labels = self.tensors["labels_test"]
        Kids = self.tensors["Kids"]
        # With no base classes there are no base ids to pass along.
        base_ids = None if (self.nKbase
                            == 0) else Kids[:, :self.nKbase].contiguous()

        assert images.dim() == 5 and labels.dim() == 2
        images = utils.convert_from_5d_to_4d(images)
        labels = labels.view(-1)

        # Look the optimizer up once with .get(): a missing entry is then
        # handled exactly like an explicit None instead of passing the check
        # below and raising KeyError when the call arguments are built.
        feature_extractor_optimizer = self.optimizers.get("feature_extractor")
        if feature_extractor_optimizer is None:
            # No optimizer for the feature extractor: keep it in eval mode.
            self.networks["feature_extractor"].eval()

        record = cls_utils.object_classification(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=feature_extractor_optimizer,
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers["classifier"],
            images=images,
            labels=labels,
            is_train=is_train,
            base_ids=base_ids,
        )

        return record
Пример #2
0
    def process_batch_base_class_classification_task(self, is_train):
        """Run the classification step, optionally with patch auxiliary tasks.

        The patch-based routine of ``loc_utils`` is used only when
        ``is_train`` is truthy and at least one of the two patch loss
        coefficients is positive; otherwise the plain routine of
        ``cls_utils`` is used.

        Args:
            is_train: Forwarded to the selected routine; also gates the
                auxiliary patch tasks.

        Returns:
            The record returned by the selected routine.
        """
        tensors = self.tensors
        images = tensors["images_test"]
        labels = tensors["labels_test"]
        Kids = tensors["Kids"]
        base_ids = Kids[:, :self.num_base].contiguous()
        assert images.dim() == 5 and labels.dim() == 2

        images = utils.convert_from_5d_to_4d(images)
        if self.standardize_image:
            images = utils.standardize_image(images)
        labels = labels.view(-1)

        # The patch tensors are fetched unconditionally, even if unused below.
        patches = tensors["patches"]
        labels_patches = tensors["labels_patches"]

        use_auxiliary_tasks = is_train and (
            self.patch_location_loss_coef > 0.0
            or self.patch_classification_loss_coef > 0.0)

        # Arguments common to both routines.
        shared_kwargs = dict(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=self.optimizers["feature_extractor"],
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers["classifier"],
            images=images,
            labels=labels,
            is_train=is_train,
            base_ids=base_ids,
        )

        if not use_auxiliary_tasks:
            return cls_utils.object_classification(**shared_kwargs)

        return loc_utils.object_classification_with_patch_location_selfsupervision(
            location_classifier=self.networks.get("classifier_loc"),
            location_classifier_optimizer=self.optimizers.get(
                "classifier_loc"),
            patch_classifier=self.networks.get("patch_classifier"),
            patch_classifier_optimizer=self.optimizers.get(
                "patch_classifier"),
            patches=patches,
            labels_patches=labels_patches,
            patch_location_loss_coef=self.patch_location_loss_coef,
            patch_classification_loss_coef=self.patch_classification_loss_coef,
            combine=self.combine_patches,
            standardize_patches=self.standardize_patches,
            **shared_kwargs,
        )
Пример #3
0
    def process_batch_base_class_classification_task(self,
                                                     auxiliary_rotation_task,
                                                     is_train):
        """Run the classification step, optionally with the rotation task.

        Args:
            auxiliary_rotation_task: When truthy, the routine with rotation
                self-supervision is used; otherwise the rotation-invariant
                classification routine is used. Must be exactly ``True``
                when ``self.semi_supervised`` and ``is_train`` are both set.
            is_train: Forwarded to the selected ``rot_utils`` routine.

        Returns:
            The record returned by the selected ``rot_utils`` routine.
        """
        images = self.tensors["images_test"]
        labels = self.tensors["labels_test"]
        Kids = self.tensors["Kids"]
        assert images.dim() == 5 and labels.dim() == 2
        images = utils.convert_from_5d_to_4d(images)
        labels = labels.view(-1)

        # Unlabeled images are only used for semi-supervised training.
        images_unlabeled = None
        if self.semi_supervised and is_train:
            images_unlabeled = self.tensors["images_unlabeled"]
            assert images_unlabeled.dim() == 4
            assert auxiliary_rotation_task is True

        # Arguments common to both rot_utils routines.
        shared_kwargs = dict(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=self.optimizers["feature_extractor"],
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers["classifier"],
            images=images,
            labels=labels,
            is_train=is_train,
            rotation_invariant_classifier=self.rotation_invariant_classifier,
            random_rotation=self.random_rotation,
            base_ids=Kids[:, :self.num_base].contiguous(),
            feature_name=self.feature_name,
        )

        if not auxiliary_rotation_task:
            return rot_utils.object_classification_rotation_invariant(
                **shared_kwargs)

        return rot_utils.object_classification_with_rotation_selfsupervision(
            classifier_rot=self.networks["classifier_aux"],
            classifier_rot_optimizer=self.optimizers["classifier_aux"],
            alpha=self.auxiliary_rotation_task_coef,
            images_unlabeled=images_unlabeled,
            **shared_kwargs,
        )
Пример #4
0
    def add_novel_categories(self, nove_cat_training_data):
        """Register the novel-category training data with the model.

        Puts both networks in eval mode, preprocesses the given data, then
        computes classification weights from the exemplar features and
        stores a detached copy in ``self.tensors["clsWeights"]``.

        Args:
            nove_cat_training_data: Passed as-is to
                ``self.preprocess_novel_training_data``.
        """
        extractor = self.networks["feature_extractor"]
        cls_head = self.networks["classifier"]
        extractor.eval()
        cls_head.eval()

        self.preprocess_novel_training_data(nove_cat_training_data)

        exemplars = self.tensors["images_train"].detach()
        exemplar_labels_1hot = self.tensors["labels_train_1hot"].detach()
        Kids = self.tensors["Kids"].detach()
        if self.num_base == 0:
            base_ids = None
        else:
            base_ids = Kids[:, :self.num_base].contiguous()

        with torch.no_grad():
            # Extract features from the exemplars. The leading dimension is
            # recorded before flattening so it can be restored afterwards.
            num_episodes = exemplars.size(0)
            flat_exemplars = utils.convert_from_5d_to_4d(exemplars)
            exemplar_features = extractor(flat_exemplars)
            exemplar_features = utils.add_dimension(exemplar_features,
                                                    num_episodes)

            # Ask the classifier for its classification weights, given the
            # base ids (if any) and the exemplar features with 1-hot labels.
            weights = cls_head.get_classification_weights(
                base_ids=base_ids,
                features_train=exemplar_features,
                labels_train=exemplar_labels_1hot,
            )

        self.tensors["clsWeights"] = weights.clone().detach()
Пример #5
0
def fewshot_classification_with_patch_location_selfsupervision(
    feature_extractor,
    feature_extractor_optimizer,
    classifier,
    classifier_optimizer,
    location_classifier,
    location_classifier_optimizer,
    patch_classifier,
    patch_classifier_optimizer,
    images_train,
    patches_train,
    labels_train,
    labels_train_1hot,
    images_test,
    patches_test,
    labels_test,
    is_train,
    base_ids=None,
    patch_location_loss_coef=1.0,
    patch_classification_loss_coef=1.0,
    combine="average",
    standardize_patches=True,
):
    """Few-shot forward/backward pass with patch-based auxiliary tasks.

    Besides the few-shot classification loss, two optional terms are added:
    a relative patch location loss (when patch_location_loss_coef > 0) and a
    patch-based few-shot classification loss (when
    patch_classification_loss_coef > 0).

    Args:
    images_train, images_test: 5D tensors that agree on every dimension
        except dimension 1.
    patches_train, patches_test: 6D tensors whose first two dimensions match
        the corresponding image tensor and whose dimension 2 has size 9.
    labels_train, labels_test: 2D tensors.
    labels_train_1hot, base_ids: Forwarded to
        fewshot_utils.few_shot_feature_classification.
    is_train: If true, gradients are zeroed up front and a backward pass plus
        optimizer steps are applied at the end.
    combine: Forwarded to combine_multiple_patch_features.
    standardize_patches: If true, utils.standardize_image is applied to the
        flattened patches.

    Returns:
    record: A dictionary with 'loss_cls', optionally 'loss_loc' and
        'loss_patch_cls', the entries added by
        fewshot_utils.compute_accuracy_metrics and, when the location task
        is active, 'AccuracyLoc'.
    """

    # ------------------------------------------------------------------
    # Shape checks.
    assert images_train.dim() == 5
    assert images_test.dim() == 5
    for axis in (0, 2, 3, 4):
        assert images_train.size(axis) == images_test.size(axis)
    assert labels_train.dim() == 2
    assert labels_test.dim() == 2
    assert labels_train.size(0) == labels_test.size(0)
    assert labels_train.size(0) == images_train.size(0)

    for patch_set, image_set in ((patches_train, images_train),
                                 (patches_test, images_test)):
        assert patch_set.dim() == 6
        assert patch_set.size(0) == image_set.size(0)
        assert patch_set.size(1) == image_set.size(1)
        assert patch_set.size(2) == 9

    num_episodes = images_train.size(0)
    train_per_episode = images_train.size(1)
    test_per_episode = images_test.size(1)

    with_location_task = patch_location_loss_coef > 0.0
    with_patch_cls_task = patch_classification_loss_coef > 0.0

    # Optimizers that take part in this step, in the order they are used.
    active_optimizers = [feature_extractor_optimizer, classifier_optimizer]
    if with_location_task:
        active_optimizers.append(location_classifier_optimizer)
    if with_patch_cls_task:
        active_optimizers.append(patch_classifier_optimizer)

    if is_train:
        for optimizer in active_optimizers:
            optimizer.zero_grad()

    record = {}

    # ------------------------------------------------------------------
    # Input preparation (no gradients needed).
    with torch.no_grad():
        images_train = utils.convert_from_5d_to_4d(images_train)
        images_test = utils.convert_from_5d_to_4d(images_test)
        labels_test = labels_test.view(-1)
        all_images = torch.cat([images_train, images_test], dim=0)

        n_train = images_train.size(0)
        n_total = all_images.size(0)
        assert n_train == num_episodes * train_per_episode
        assert n_total == num_episodes * (train_per_episode +
                                          test_per_episode)

        patches_train = utils.convert_from_6d_to_4d(patches_train)
        patches_test = utils.convert_from_6d_to_4d(patches_test)
        if standardize_patches:
            patches_train = utils.standardize_image(patches_train)
            patches_test = utils.standardize_image(patches_test)
        all_patches = torch.cat([patches_train, patches_test], dim=0)

        assert patches_train.size(0) == n_train * 9
        assert all_patches.size(0) == n_total * 9

    # ------------------------------------------------------------------
    # Forward pass and losses.
    with torch.set_grad_enabled(is_train):
        image_features = feature_extractor(all_images)
        patch_features = feature_extractor(all_patches)

        # Few-shot classification on the image features.
        train_features = utils.add_dimension(image_features[:n_train],
                                             num_episodes)
        test_features = utils.add_dimension(image_features[n_train:n_total],
                                            num_episodes)
        cls_scores, cls_loss = fewshot_utils.few_shot_feature_classification(
            classifier,
            test_features,
            train_features,
            labels_train_1hot,
            labels_test,
            base_ids,
        )
        record["loss_cls"] = cls_loss.item()
        total_loss = cls_loss

        # Relative patch location prediction.
        if with_location_task:
            loc_scores, loc_loss, loc_labels = patch_location_task(
                location_classifier, patch_features)
            record["loss_loc"] = loc_loss.item()
            total_loss = total_loss + loc_loss * patch_location_loss_coef

        # Few-shot classification on the combined patch features.
        if with_patch_cls_task:
            patch_features = add_patch_dimension(patch_features)
            assert patch_features.size(0) == n_total
            assert patch_features.size(1) == 9
            patch_features = combine_multiple_patch_features(
                patch_features, combine)

            patch_train_features = utils.add_dimension(
                patch_features[:n_train], num_episodes)
            patch_test_features = utils.add_dimension(
                patch_features[n_train:n_total], num_episodes)

            patch_scores, patch_loss = fewshot_utils.few_shot_feature_classification(
                patch_classifier,
                patch_test_features,
                patch_train_features,
                labels_train_1hot,
                labels_test,
                base_ids,
            )
            record["loss_patch_cls"] = patch_loss.item()
            total_loss = total_loss + patch_loss * patch_classification_loss_coef

        # The summed loss is halved, whichever auxiliary terms were added.
        total_loss = total_loss * 0.5

    # ------------------------------------------------------------------
    # Accuracy metrics.
    with torch.no_grad():
        num_base = 0 if base_ids is None else base_ids.size(1)
        record = fewshot_utils.compute_accuracy_metrics(
            cls_scores, labels_test, num_base, record)
        if with_location_task:
            record["AccuracyLoc"] = utils.top1accuracy(loc_scores, loc_labels)
        if with_patch_cls_task:
            record = fewshot_utils.compute_accuracy_metrics(
                patch_scores,
                labels_test,
                num_base,
                record,
                string_id="Patch")

    # ------------------------------------------------------------------
    # Backward pass and parameter updates.
    if is_train:
        total_loss.backward()
        for optimizer in active_optimizers:
            optimizer.step()

    return record
Пример #6
0
def object_classification_with_patch_location_selfsupervision(
    feature_extractor,
    feature_extractor_optimizer,
    classifier,
    classifier_optimizer,
    location_classifier,
    location_classifier_optimizer,
    patch_classifier,
    patch_classifier_optimizer,
    images,
    labels,
    patches,
    labels_patches,
    is_train,
    patch_location_loss_coef=1.0,
    patch_classification_loss_coef=1.0,
    combine="average",
    base_ids=None,
    standardize_patches=True,
):
    """Classification forward/backward pass with patch-based auxiliary tasks.

    Besides the object classification loss, two optional terms are added: a
    relative patch location loss (when patch_location_loss_coef > 0) and a
    patch classification loss (when patch_classification_loss_coef > 0).

    Args:
    images: A 4D tensor; its first dimension matches that of labels.
    labels: A 1D tensor.
    patches: A 5D tensor whose dimension 1 has size 9 and whose first
        dimension matches that of labels_patches.
    labels_patches: Labels used by the patch classification task.
    is_train: If true, gradients are zeroed up front and a backward pass plus
        optimizer steps are applied at the end.
    combine: Forwarded to patch_classification_task.
    base_ids: Optional tensor whose first dimension must have size 1;
        forwarded to cls_utils.classification_task.
    standardize_patches: If true, utils.standardize_image is applied to the
        flattened patches.

    Returns:
    record: A dictionary with 'loss_cls' and 'Accuracy' and, depending on
        the active auxiliary tasks, 'loss_loc', 'AccuracyLoc',
        'loss_patch_cls' and 'AccuracyPatch'.
    """

    if base_ids is not None:
        assert base_ids.size(0) == 1

    assert images.dim() == 4
    assert labels.dim() == 1
    assert images.size(0) == labels.size(0)

    assert patches.dim() == 5 and patches.size(1) == 9
    assert patches.size(0) == labels_patches.size(0)
    patches = utils.convert_from_5d_to_4d(patches)
    if standardize_patches:
        patches = utils.standardize_image(patches)

    with_location_task = patch_location_loss_coef > 0.0
    with_patch_cls_task = patch_classification_loss_coef > 0.0

    # Optimizers that take part in this step, in the order they are used.
    active_optimizers = [feature_extractor_optimizer, classifier_optimizer]
    if with_location_task:
        active_optimizers.append(location_classifier_optimizer)
    if with_patch_cls_task:
        active_optimizers.append(patch_classifier_optimizer)

    if is_train:
        for optimizer in active_optimizers:
            optimizer.zero_grad()

    record = {}
    with torch.set_grad_enabled(is_train):
        image_features = feature_extractor(images)
        patch_features = feature_extractor(patches)

        # Object classification on the image features.
        cls_scores, cls_loss = cls_utils.classification_task(
            classifier, image_features, labels, base_ids)
        record["loss_cls"] = cls_loss.item()
        total_loss = cls_loss

        # Relative patch location prediction.
        if with_location_task:
            loc_scores, loc_loss, loc_labels = patch_location_task(
                location_classifier, patch_features)
            record["loss_loc"] = loc_loss.item()
            total_loss = total_loss + loc_loss * patch_location_loss_coef

        # Patch classification.
        if with_patch_cls_task:
            patch_scores, patch_loss = patch_classification_task(
                patch_classifier, patch_features, labels_patches, combine)
            record["loss_patch_cls"] = patch_loss.item()
            total_loss = total_loss + patch_loss * patch_classification_loss_coef

        # The summed loss is halved, whichever auxiliary terms were added.
        total_loss = total_loss * 0.5

    with torch.no_grad():
        record["Accuracy"] = utils.top1accuracy(cls_scores, labels)
        if with_location_task:
            record["AccuracyLoc"] = utils.top1accuracy(loc_scores, loc_labels)
        if with_patch_cls_task:
            record["AccuracyPatch"] = utils.top1accuracy(
                patch_scores, labels_patches)

    if is_train:
        # Backpropagate, then update every participating optimizer.
        total_loss.backward()
        for optimizer in active_optimizers:
            optimizer.step()

    return record
Пример #7
0
def fewshot_classification_with_rotation_selfsupervision(
    feature_extractor,
    feature_extractor_optimizer,
    classifier,
    classifier_optimizer,
    classifier_rot,
    classifier_rot_optimizer,
    images_train,
    labels_train,
    labels_train_1hot,
    images_test,
    labels_test,
    is_train,
    alpha=1.0,
    base_ids=None,
    feature_name=None,
):
    """Forward-backward routine of a few-shot model with auxiliary rotation
    prediction task.

    Given as input a mini-batch of few-shot episodes, it applies the
    forward and (optionally) backward propagation routines of the few-shot
    classification task. Each episode consists of (1) num_train_examples number
    of training examples for the novel classes of the few-shot episode, (2) the
    labels of training examples of the novel classes, (3) num_test_examples
    number of test examples of the few-shot episode (note that the test
    examples can be from both base and novel classes), and (4) the labels of the
    test examples. Each mini-batch consists of meta_batch_size number of
    few-shot episodes. The code assumes that the few-shot classification model
    is divided into a feature extractor network and a classification head
    network. Also, the code applies the auxiliary self-supervised task of
    predicting image rotations. The rotation prediction task is applied on both
    the test and train examples of the few-shot episodes.

    Args:
    feature_extractor: The feature extractor neural network.
    feature_extractor_optimizer: The parameter optimizer of the feature
        extractor. If None, no zero_grad / step is issued for the feature
        extractor, so its parameters are not updated during training.
    classifier: The classification head applied on the output of the feature
        extractor.
    classifier_optimizer: The parameter optimizer of the classification head.
    classifier_rot: The rotation prediction head applied on the output of the
        feature extractor.
    classifier_rot_optimizer: The parameter optimizer of the rotation prediction
        head.
    images_train: A 5D tensor with shape
        [meta_batch_size x num_train_examples x channels x height x width] that
        represents a mini-batch of meta_batch_size number of few-shot episodes,
        each with num_train_examples number of training examples.
    labels_train: A 2D tensor with shape
        [meta_batch_size x num_train_examples] that represents the discrete
        labels of the training examples of each few-shot episode in the batch.
    labels_train_1hot: A 3D tensor with shape
        [meta_batch_size x num_train_examples x num_novel] that represents
        the 1hot labels of the training examples of the novel classes of each
        few-shot episode in the batch. num_novel is the number of novel classes
        per few-shot episode.
    images_test: A 5D tensor with shape
        [meta_batch_size x num_test_examples x channels x height x width] that
        represents a mini-batch of meta_batch_size number of few-shot episodes,
        each with num_test_examples number of test examples.
    labels_test: A 2D tensor with shape
        [meta_batch_size x num_test_examples] that represents the discrete
        labels of the test examples of each few-shot episode in the mini-batch.
    is_train: Boolean value that indicates if this mini-batch will be
        used for training or testing. If is_train is False, then the code does
        not apply the backward propagation step and does not update the
        parameter optimizers.
    alpha: (optional) The loss weight of the rotation prediction task.
    base_ids: (optional) A 2D tensor with shape [meta_batch_size x num_base],
        where base_ids[m] are the indices of the base categories that are
        being used in the m-th few-shot episode. num_base is the number of
        base classes per few-shot episode. If None, num_base is taken as 0.
    feature_name: (optional) A string or list of strings with the name of
        feature level(s) from which the feature extractor will extract features
        for the classification task.

    Returns:
    record: A dictionary of scalar values with the following items:
        'loss_cls': The cross entropy loss of the few-shot classification task.
        'loss_rot': The rotation prediction loss.
        'loss_total': The total loss, i.e., loss_cls + alpha * loss_rot.
        'AccuracyNovel': The classification accuracy of the test examples among
            only the novel classes.
        'AccuracyBase': (optional) The classification accuracy of the test
            examples among only the base classes. Applicable, only if there are
            test examples from base classes in the mini-batch.
        'AccuracyBoth': (optional) The classification accuracy of the test
            examples among both the base and novel classes. Applicable, only if
            there are test examples from base classes in the mini-batch.
            NOTE(review): the key name is assumed; the accuracy entries are
            written by fewshot_utils.compute_accuracy_metrics -- confirm there.
        'AccuracyRot': The accuracy of the rotation prediction task.
    """

    assert images_train.dim() == 5
    assert images_test.dim() == 5
    assert images_train.size(0) == images_test.size(0)
    assert images_train.size(2) == images_test.size(2)
    assert images_train.size(3) == images_test.size(3)
    assert images_train.size(4) == images_test.size(4)
    assert labels_train.dim() == 2
    assert labels_test.dim() == 2
    assert labels_train.size(0) == labels_test.size(0)
    assert labels_train.size(0) == images_train.size(0)

    meta_batch_size = images_train.size(0)
    num_train = images_train.size(1)
    num_test = images_test.size(1)

    if is_train:  # zero the gradients
        # The feature extractor optimizer is optional (None == frozen).
        if feature_extractor_optimizer is not None:
            feature_extractor_optimizer.zero_grad()
        classifier_optimizer.zero_grad()
        classifier_rot_optimizer.zero_grad()

    record = {}
    with torch.no_grad():
        images_train = utils.convert_from_5d_to_4d(images_train)
        images_test = utils.convert_from_5d_to_4d(images_test)
        labels_test = labels_test.view(-1)
        images = torch.cat([images_train, images_test], dim=0)

        batch_size_train = images_train.size(0)
        batch_size_train_test = images.size(0)
        assert batch_size_train == meta_batch_size * num_train
        assert batch_size_train_test == meta_batch_size * (num_train +
                                                           num_test)

        # Create the 4 rotated version of the images; this step increases
        # the batch size by a multiple of 4.
        images = create_4rotations_images(images)
        labels_rotation = create_rotations_labels(batch_size_train_test,
                                                  images.device)

    with torch.set_grad_enabled(is_train):
        # Extract features from the train and test images.
        features = cls_utils.extract_features(feature_extractor,
                                              images,
                                              feature_name=feature_name)

        # Apply the few-shot classification head. Only the first
        # batch_size_train_test feature rows are used here; the rotation
        # head below consumes the full feature tensor.
        features_train = features[:batch_size_train]
        features_test = features[batch_size_train:batch_size_train_test]
        features_train = utils.add_dimension(features_train, meta_batch_size)
        features_test = utils.add_dimension(features_test, meta_batch_size)
        (
            classification_scores,
            loss_classsification,
        ) = fewshot_utils.few_shot_feature_classification(
            classifier,
            features_test,
            features_train,
            labels_train_1hot,
            labels_test,
            base_ids,
        )
        record["loss_cls"] = loss_classsification.item()

        # Apply the rotation prediction head.
        scores_rotation, loss_rotation = rotation_task(classifier_rot,
                                                       features,
                                                       labels_rotation)
        record["loss_rot"] = loss_rotation.item()

        # Compute total loss.
        loss_total = loss_classsification + alpha * loss_rotation
        record["loss_total"] = loss_total.item()

    with torch.no_grad():
        num_base = base_ids.size(1) if (base_ids is not None) else 0
        record = fewshot_utils.compute_accuracy_metrics(
            classification_scores, labels_test, num_base, record)
        record["AccuracyRot"] = utils.top1accuracy(scores_rotation,
                                                   labels_rotation)

    if is_train:
        loss_total.backward()
        # Skip the update of the feature extractor when it has no optimizer.
        if feature_extractor_optimizer is not None:
            feature_extractor_optimizer.step()
        classifier_optimizer.step()
        classifier_rot_optimizer.step()

    return record
Пример #8
0
def fewshot_classification(
    feature_extractor,
    feature_extractor_optimizer,
    classifier,
    classifier_optimizer,
    images_train,
    labels_train,
    labels_train_1hot,
    images_test,
    labels_test,
    is_train,
    base_ids=None,
    feature_name=None,
    classification_coef=1.0,
):
    """Forward-backward propagation routine of the few-shot classification task.

    Given as input a mini-batch of few-shot episodes, it applies the
    forward and (optionally) backward propagation routines of the few-shot
    classification task. Each episode consists of (1) num_train_examples
    training examples for the novel classes of the episode, (2) the labels of
    those training examples, (3) num_test_examples test examples (which can be
    from both base and novel classes), and (4) the labels of the test
    examples. Each mini-batch consists of meta_batch_size few-shot episodes.
    The model is assumed to be split into a feature extractor network and a
    classification head network.

    Args:
        feature_extractor: The feature extractor neural network.
        feature_extractor_optimizer: The parameter optimizer of the feature
            extractor. If None, the feature extractor stays frozen: features
            are computed without gradient tracking and detached.
        classifier: The classification head applied on the output of the
            feature extractor.
        classifier_optimizer: The parameter optimizer of the classification
            head. If None, the classification head is not updated.
        images_train: A 5D tensor with shape
            [meta_batch_size x num_train_examples x channels x height x width]
            holding the training examples of each few-shot episode.
        labels_train: A 2D tensor with shape
            [meta_batch_size x num_train_examples] with the discrete labels of
            the training examples. It is only shape-checked here; the
            classification head receives `labels_train_1hot` instead.
        labels_train_1hot: A 3D tensor with shape
            [meta_batch_size x num_train_examples x num_novel] with the 1-hot
            labels of the training examples, where num_novel is the number of
            novel classes per few-shot episode.
        images_test: A 5D tensor with shape
            [meta_batch_size x num_test_examples x channels x height x width]
            holding the test examples of each few-shot episode.
        labels_test: A 2D tensor with shape
            [meta_batch_size x num_test_examples] with the discrete labels of
            the test examples.
        is_train: Boolean value that indicates if this mini-batch is used for
            training or testing. If False, no backward propagation is done
            and no optimizer is stepped.
        base_ids: (optional) A 2D tensor with shape
            [meta_batch_size x num_base], where base_ids[m] are the indices of
            the base categories used in the m-th few-shot episode. If None,
            the number of base classes is taken to be 0.
        feature_name: (optional) A string, or a list/tuple with at most one
            string, naming the feature level from which the feature extractor
            extracts features for the classification task.
        classification_coef: (optional) The weight applied to the few-shot
            classification loss before back-propagation.

    Returns:
        record: A dictionary of scalar values with the following items:
            'loss': The (unweighted) classification loss of this mini-batch.
            'AccuracyNovel': The classification accuracy of the test examples
                among only the novel classes.
            'AccuracyBase': (optional) The classification accuracy of the
                test examples among only the base classes. Applicable only if
                there are test examples from base classes in the mini-batch.
            'AccuracyBoth': (optional) The classification accuracy of the
                test examples among both the base and novel classes.
                Applicable only if there are test examples from base classes
                in the mini-batch. NOTE(review): the accuracy keys are written
                by `compute_accuracy_metrics`; confirm this key name there.

    Raises:
        AssertionError: If the input tensors have inconsistent shapes or if
            `feature_name` lists more than one feature level.
    """

    assert images_train.dim() == 5
    assert images_test.dim() == 5
    assert images_train.size(0) == images_test.size(0)
    assert images_train.size(2) == images_test.size(2)
    assert images_train.size(3) == images_test.size(3)
    assert images_train.size(4) == images_test.size(4)
    assert labels_train.dim() == 2
    assert labels_test.dim() == 2
    assert labels_train.size(0) == labels_test.size(0)
    assert labels_train.size(0) == images_train.size(0)

    assert not (isinstance(feature_name,
                           (list, tuple)) and len(feature_name) > 1)

    meta_batch_size = images_train.size(0)

    # A network is updated only during training and only when an optimizer
    # was supplied for it; the same explicit `is not None` test is used for
    # both networks.
    train_feature_extractor = is_train and (feature_extractor_optimizer
                                            is not None)
    train_classifier = is_train and (classifier_optimizer is not None)

    if train_feature_extractor:
        feature_extractor_optimizer.zero_grad()
    if train_classifier:
        classifier_optimizer.zero_grad()

    record = {}
    with torch.no_grad():
        # Fold the episode dimension into the batch dimension so that train
        # and test images go through the feature extractor in one pass.
        images_train = utils.convert_from_5d_to_4d(images_train)
        images_test = utils.convert_from_5d_to_4d(images_test)
        labels_test = labels_test.view(-1)
        batch_size_train = images_train.size(0)
        images = torch.cat([images_train, images_test], dim=0)

    with torch.set_grad_enabled(train_feature_extractor):
        # Extract features from the train and test images.
        features = cls_utils.extract_features(feature_extractor,
                                              images,
                                              feature_name=feature_name)

    if not train_feature_extractor:
        # Make sure that no gradients are back-propagated to the feature
        # extractor when the feature extraction model is frozen.
        features = features.detach()

    with torch.set_grad_enabled(is_train):
        # Split the features back into train / test parts and restore the
        # episode (meta-batch) dimension.
        features_train = features[:batch_size_train]
        features_test = features[batch_size_train:]
        features_train = utils.add_dimension(features_train, meta_batch_size)
        features_test = utils.add_dimension(features_test, meta_batch_size)

        # Apply the classification head of the few-shot classification model.
        classification_scores, loss = few_shot_feature_classification(
            classifier,
            features_test,
            features_train,
            labels_train_1hot,
            labels_test,
            base_ids,
        )
        # The record keeps the unweighted loss; only the back-propagated
        # loss is scaled by classification_coef.
        record["loss"] = loss.item()
        loss_total = loss * classification_coef

    with torch.no_grad():
        num_base = base_ids.size(1) if (base_ids is not None) else 0
        record = compute_accuracy_metrics(classification_scores, labels_test,
                                          num_base, record)

    if is_train:
        loss_total.backward()
        if train_feature_extractor:
            feature_extractor_optimizer.step()
        if train_classifier:
            classifier_optimizer.step()

    return record
Пример #9
0
    def process_batch_fewshot_classification_task(self, is_train):
        """Process one mini-batch of few-shot episodes and return its record.

        The episode tensors are read from ``self.tensors``. When
        ``self.standardize_image`` is set, the train and test images are
        standardized first. During training, if either patch loss coefficient
        is positive, the batch is handled by the few-shot routine with patch
        based auxiliary losses from ``loc_utils``; otherwise the plain
        ``fs_utils.fewshot_classification`` routine is used. When not
        training, the record is additionally passed through
        ``fs_utils.compute_95confidence_intervals`` for the "AccuracyNovel"
        metric, which also refreshes ``self.accuracies``.

        Args:
            is_train: Truthy if the mini-batch is used for training.

        Returns:
            The record returned by the selected routine (after the confidence
            interval step when not training).
        """
        kids = self.tensors["Kids"]
        if self.num_base == 0:
            base_ids = None
        else:
            # The first num_base columns of Kids are passed on as base_ids.
            base_ids = kids[:, :self.num_base].contiguous()

        train_images = self.tensors["images_train"]
        test_images = self.tensors["images_test"]

        if self.standardize_image:
            assert train_images.dim() == 5 and test_images.dim() == 5
            assert train_images.size(0) == test_images.size(0)
            num_episodes = train_images.size(0)

            # Standardization is applied on the 4D view of the images; the
            # leading episode dimension is restored afterwards. Each stage is
            # run on the train batch first and then on the test batch.
            train_images, test_images = [
                utils.convert_from_5d_to_4d(batch)
                for batch in (train_images, test_images)
            ]
            train_images, test_images = [
                utils.standardize_image(batch)
                for batch in (train_images, test_images)
            ]
            train_images, test_images = [
                utils.add_dimension(batch, num_episodes)
                for batch in (train_images, test_images)
            ]

        # Auxiliary patch losses are only ever used while training.
        with_patch_losses = is_train and (
            self.patch_location_loss_coef > 0.0
            or self.patch_classification_loss_coef > 0.0)

        # Keyword arguments common to both classification routines.
        shared_inputs = dict(
            feature_extractor=self.networks["feature_extractor"],
            classifier=self.networks["classifier"],
            images_train=train_images,
            labels_train=self.tensors["labels_train"],
            labels_train_1hot=self.tensors["labels_train_1hot"],
            images_test=test_images,
            labels_test=self.tensors["labels_test"],
            is_train=is_train,
            base_ids=base_ids,
        )

        if with_patch_losses:
            record = (
                loc_utils.
                fewshot_classification_with_patch_location_selfsupervision(
                    feature_extractor_optimizer=self.
                    optimizers["feature_extractor"],
                    classifier_optimizer=self.optimizers["classifier"],
                    location_classifier=self.networks.get("classifier_loc"),
                    location_classifier_optimizer=self.optimizers.get(
                        "classifier_loc"),
                    patch_classifier=self.networks.get("patch_classifier"),
                    patch_classifier_optimizer=self.optimizers.get(
                        "patch_classifier"),
                    patches_train=self.tensors["patches_train"],
                    patches_test=self.tensors["patches_test"],
                    patch_location_loss_coef=self.patch_location_loss_coef,
                    patch_classification_loss_coef=self.
                    patch_classification_loss_coef,
                    combine=self.combine_patches,
                    standardize_patches=self.standardize_patches,
                    **shared_inputs,
                ))
        else:
            record = fs_utils.fewshot_classification(
                feature_extractor_optimizer=self.optimizers.get(
                    "feature_extractor"),
                classifier_optimizer=self.optimizers.get("classifier"),
                **shared_inputs,
            )

        if is_train:
            return record

        # Evaluation only: fold this episode into the running accuracies.
        record, self.accuracies = fs_utils.compute_95confidence_intervals(
            record,
            episode=self.biter,
            num_episodes=self.bnumber,
            store_accuracies=self.accuracies,
            metrics=["AccuracyNovel"],
        )
        return record