def process_batch_base_class_classification_task(self, is_train):
    images = self.tensors["images_test"]
    labels = self.tensors["labels_test"]
    Kids = self.tensors["Kids"]
    base_ids = None if (self.nKbase == 0) else Kids[:, :self.nKbase].contiguous()

    assert images.dim() == 5 and labels.dim() == 2
    images = utils.convert_from_5d_to_4d(images)
    labels = labels.view(-1)

    if self.optimizers.get("feature_extractor") is None:
        self.networks["feature_extractor"].eval()

    record = cls_utils.object_classification(
        feature_extractor=self.networks["feature_extractor"],
        feature_extractor_optimizer=self.optimizers["feature_extractor"],
        classifier=self.networks["classifier"],
        classifier_optimizer=self.optimizers["classifier"],
        images=images,
        labels=labels,
        is_train=is_train,
        base_ids=base_ids,
    )

    return record
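# -----------------------------------------------------------------------------
# Illustrative sketches, not the repository's own utils implementations: the
# reshaping helpers used throughout this file merge the meta-batch and
# per-episode dimensions (convert_from_5d_to_4d) and split them back out
# (add_dimension). The names below carry a _sketch suffix to mark them as
# assumptions about the layout [meta_batch, num_examples, C, H, W]; they assume
# torch is imported at module level, as in the rest of this file.
def convert_from_5d_to_4d_sketch(tensor_5d):
    # [meta_batch, num_examples, C, H, W] -> [meta_batch * num_examples, C, H, W]
    return tensor_5d.reshape(-1, *tensor_5d.shape[2:])


def add_dimension_sketch(tensor, dim_size):
    # [dim_size * N, ...] -> [dim_size, N, ...]
    return tensor.reshape(dim_size, -1, *tensor.shape[1:])
# -----------------------------------------------------------------------------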
def process_batch_base_class_classification_task(self, is_train):
    images = self.tensors["images_test"]
    labels = self.tensors["labels_test"]
    Kids = self.tensors["Kids"]
    base_ids = Kids[:, :self.num_base].contiguous()

    assert images.dim() == 5 and labels.dim() == 2
    images = utils.convert_from_5d_to_4d(images)
    if self.standardize_image:
        images = utils.standardize_image(images)
    labels = labels.view(-1)

    patches = self.tensors["patches"]
    labels_patches = self.tensors["labels_patches"]

    auxiliary_tasks = is_train and (
        self.patch_location_loss_coef > 0.0
        or self.patch_classification_loss_coef > 0.0)

    if auxiliary_tasks:
        record = loc_utils.object_classification_with_patch_location_selfsupervision(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=self.optimizers["feature_extractor"],
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers["classifier"],
            location_classifier=self.networks.get("classifier_loc"),
            location_classifier_optimizer=self.optimizers.get("classifier_loc"),
            patch_classifier=self.networks.get("patch_classifier"),
            patch_classifier_optimizer=self.optimizers.get("patch_classifier"),
            images=images,
            labels=labels,
            patches=patches,
            labels_patches=labels_patches,
            is_train=is_train,
            patch_location_loss_coef=self.patch_location_loss_coef,
            patch_classification_loss_coef=self.patch_classification_loss_coef,
            combine=self.combine_patches,
            base_ids=base_ids,
            standardize_patches=self.standardize_patches,
        )
    else:
        record = cls_utils.object_classification(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=self.optimizers["feature_extractor"],
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers["classifier"],
            images=images,
            labels=labels,
            is_train=is_train,
            base_ids=base_ids,
        )

    return record
def process_batch_base_class_classification_task(self, auxiliary_rotation_task, is_train):
    images = self.tensors["images_test"]
    labels = self.tensors["labels_test"]
    Kids = self.tensors["Kids"]

    assert images.dim() == 5 and labels.dim() == 2
    images = utils.convert_from_5d_to_4d(images)
    labels = labels.view(-1)

    if self.semi_supervised and is_train:
        images_unlabeled = self.tensors["images_unlabeled"]
        assert images_unlabeled.dim() == 4
        assert auxiliary_rotation_task is True
    else:
        images_unlabeled = None

    if auxiliary_rotation_task:
        record = rot_utils.object_classification_with_rotation_selfsupervision(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=self.optimizers["feature_extractor"],
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers["classifier"],
            classifier_rot=self.networks["classifier_aux"],
            classifier_rot_optimizer=self.optimizers["classifier_aux"],
            images=images,
            labels=labels,
            is_train=is_train,
            alpha=self.auxiliary_rotation_task_coef,
            random_rotation=self.random_rotation,
            rotation_invariant_classifier=self.rotation_invariant_classifier,
            base_ids=Kids[:, :self.num_base].contiguous(),
            feature_name=self.feature_name,
            images_unlabeled=images_unlabeled,
        )
    else:
        record = rot_utils.object_classification_rotation_invariant(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=self.optimizers["feature_extractor"],
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers["classifier"],
            images=images,
            labels=labels,
            is_train=is_train,
            rotation_invariant_classifier=self.rotation_invariant_classifier,
            random_rotation=self.random_rotation,
            base_ids=Kids[:, :self.num_base].contiguous(),
            feature_name=self.feature_name,
        )

    return record
def add_novel_categories(self, nove_cat_training_data):
    """Add the training data of the novel categories to the model."""
    feature_extractor = self.networks["feature_extractor"]
    classifier = self.networks["classifier"]
    feature_extractor.eval()
    classifier.eval()

    self.preprocess_novel_training_data(nove_cat_training_data)

    images = self.tensors["images_train"].detach()
    labels_train_1hot = self.tensors["labels_train_1hot"].detach()
    Kids = self.tensors["Kids"].detach()
    base_ids = None if (self.num_base == 0) else Kids[:, :self.num_base].contiguous()

    with torch.no_grad():
        # *******************************************************************
        # ****************** EXTRACT FEATS FROM EXEMPLARS *******************
        meta_batch_size = images.size(0)
        images = utils.convert_from_5d_to_4d(images)
        features_train = feature_extractor(images)
        features_train = utils.add_dimension(features_train, meta_batch_size)
        # *******************************************************************
        # ****************** GET CLASSIFICATION WEIGHTS *********************
        # The following routine returns the classification weight vectors of
        # both the base and the novel categories. For the novel categories,
        # the classification weight vectors are generated using the training
        # features of those novel categories.
        clsWeights = classifier.get_classification_weights(
            base_ids=base_ids,
            features_train=features_train,
            labels_train=labels_train_1hot,
        )
        # *******************************************************************

    self.tensors["clsWeights"] = clsWeights.clone().detach()
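# -----------------------------------------------------------------------------
# Illustrative sketch, not the classifier's actual get_classification_weights
# routine (which may use a learned, attention-based weight generator): a
# minimal prototype-style generator that averages the training features of
# each novel class. It assumes features_train has shape
# [meta_batch_size, num_train, num_channels] and labels_train_1hot has shape
# [meta_batch_size, num_train, num_novel], and that torch is imported at
# module level. The _sketch name is hypothetical.
def average_based_novel_weights_sketch(features_train, labels_train_1hot):
    # Sum the features assigned to each novel class ...
    weights_novel = torch.bmm(labels_train_1hot.transpose(1, 2), features_train)
    # ... and divide by the number of training examples per class.
    counts = labels_train_1hot.sum(dim=1).unsqueeze(2).clamp(min=1)
    return weights_novel / counts  # [meta_batch_size, num_novel, num_channels]
# -----------------------------------------------------------------------------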
def fewshot_classification_with_patch_location_selfsupervision(
    feature_extractor,
    feature_extractor_optimizer,
    classifier,
    classifier_optimizer,
    location_classifier,
    location_classifier_optimizer,
    patch_classifier,
    patch_classifier_optimizer,
    images_train,
    patches_train,
    labels_train,
    labels_train_1hot,
    images_test,
    patches_test,
    labels_test,
    is_train,
    base_ids=None,
    patch_location_loss_coef=1.0,
    patch_classification_loss_coef=1.0,
    combine="average",
    standardize_patches=True,
):
    """Forward-backward propagation routine for the few-shot model extended
    with the auxiliary self-supervised task of predicting the relative
    location of patches."""

    assert images_train.dim() == 5
    assert images_test.dim() == 5
    assert images_train.size(0) == images_test.size(0)
    assert images_train.size(2) == images_test.size(2)
    assert images_train.size(3) == images_test.size(3)
    assert images_train.size(4) == images_test.size(4)
    assert labels_train.dim() == 2
    assert labels_test.dim() == 2
    assert labels_train.size(0) == labels_test.size(0)
    assert labels_train.size(0) == images_train.size(0)
    assert patches_train.dim() == 6
    assert patches_train.size(0) == images_train.size(0)
    assert patches_train.size(1) == images_train.size(1)
    assert patches_train.size(2) == 9
    assert patches_test.dim() == 6
    assert patches_test.size(0) == images_test.size(0)
    assert patches_test.size(1) == images_test.size(1)
    assert patches_test.size(2) == 9

    meta_batch_size = images_train.size(0)
    num_train = images_train.size(1)
    num_test = images_test.size(1)

    if is_train:
        # Zero gradients.
        feature_extractor_optimizer.zero_grad()
        classifier_optimizer.zero_grad()
        if patch_location_loss_coef > 0.0:
            location_classifier_optimizer.zero_grad()
        if patch_classification_loss_coef > 0.0:
            patch_classifier_optimizer.zero_grad()

    record = {}
    with torch.no_grad():
        images_train = utils.convert_from_5d_to_4d(images_train)
        images_test = utils.convert_from_5d_to_4d(images_test)
        labels_test = labels_test.view(-1)
        images = torch.cat([images_train, images_test], dim=0)
        batch_size_train = images_train.size(0)
        batch_size_train_test = images.size(0)
        assert batch_size_train == meta_batch_size * num_train
        assert batch_size_train_test == meta_batch_size * (num_train + num_test)

        patches_train = utils.convert_from_6d_to_4d(patches_train)
        patches_test = utils.convert_from_6d_to_4d(patches_test)
        if standardize_patches:
            patches_train = utils.standardize_image(patches_train)
            patches_test = utils.standardize_image(patches_test)
        patches = torch.cat([patches_train, patches_test], dim=0)
        assert patches_train.size(0) == batch_size_train * 9
        assert patches.size(0) == batch_size_train_test * 9

    with torch.set_grad_enabled(is_train):
        # Extract features from the images.
        features = feature_extractor(images)
        # Extract features from the image patches.
        features_patches = feature_extractor(patches)

        # Perform the object classification task.
        features_train = features[:batch_size_train]
        features_test = features[batch_size_train:batch_size_train_test]
        features_train = utils.add_dimension(features_train, meta_batch_size)
        features_test = utils.add_dimension(features_test, meta_batch_size)
        (
            classification_scores,
            loss_classification,
        ) = fewshot_utils.few_shot_feature_classification(
            classifier,
            features_test,
            features_train,
            labels_train_1hot,
            labels_test,
            base_ids,
        )
        record["loss_cls"] = loss_classification.item()
        loss_total = loss_classification

        # Perform the self-supervised task of relative patch location
        # prediction.
        if patch_location_loss_coef > 0.0:
            scores_location, loss_location, labels_loc = patch_location_task(
                location_classifier, features_patches)
            record["loss_loc"] = loss_location.item()
            loss_total = loss_total + loss_location * patch_location_loss_coef

        # Perform the auxiliary task of classifying patches.
        if patch_classification_loss_coef > 0.0:
            features_patches = add_patch_dimension(features_patches)
            assert features_patches.size(0) == batch_size_train_test
            assert features_patches.size(1) == 9
            features_patches = combine_multiple_patch_features(
                features_patches, combine)

            features_patches_train = utils.add_dimension(
                features_patches[:batch_size_train], meta_batch_size)
            features_patches_test = utils.add_dimension(
                features_patches[batch_size_train:batch_size_train_test],
                meta_batch_size)

            scores_patch, loss_patch = fewshot_utils.few_shot_feature_classification(
                patch_classifier,
                features_patches_test,
                features_patches_train,
                labels_train_1hot,
                labels_test,
                base_ids,
            )
            record["loss_patch_cls"] = loss_patch.item()
            loss_total = loss_total + loss_patch * patch_classification_loss_coef

        # Because the total loss is the sum of multiple individual losses
        # (i.e., up to 3), scale it down by a factor of 0.5.
        loss_total = loss_total * 0.5

    with torch.no_grad():
        num_base = base_ids.size(1) if (base_ids is not None) else 0
        record = fewshot_utils.compute_accuracy_metrics(
            classification_scores, labels_test, num_base, record)
        if patch_location_loss_coef > 0.0:
            record["AccuracyLoc"] = utils.top1accuracy(scores_location, labels_loc)
        if patch_classification_loss_coef > 0.0:
            record = fewshot_utils.compute_accuracy_metrics(
                scores_patch, labels_test, num_base, record, string_id="Patch")

    if is_train:
        # Backpropagate the loss and apply the gradient steps.
        loss_total.backward()
        feature_extractor_optimizer.step()
        classifier_optimizer.step()
        if patch_location_loss_coef > 0.0:
            location_classifier_optimizer.step()
        if patch_classification_loss_coef > 0.0:
            patch_classifier_optimizer.step()

    return record
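# -----------------------------------------------------------------------------
# Illustrative sketches of the patch helpers referenced above; the actual
# add_patch_dimension and combine_multiple_patch_features live elsewhere in
# this module, so these _sketch versions are only assumptions that make the
# data flow explicit. They assume features_patches is a 2D tensor of shape
# [batch_size * 9, num_channels] with the 9 patches of each image contiguous
# along the batch dimension.
def add_patch_dimension_sketch(features_patches, num_patches=9):
    # [batch_size * num_patches, C] -> [batch_size, num_patches, C]
    return features_patches.view(-1, num_patches, features_patches.size(1))


def combine_multiple_patch_features_sketch(features_patches, combine="average"):
    # Merge the per-patch features of each image into a single feature vector.
    if combine == "average":
        return features_patches.mean(dim=1)
    if combine == "concatenate":
        return features_patches.flatten(1)
    raise ValueError("Unknown combine option: {0}".format(combine))
# -----------------------------------------------------------------------------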
def object_classification_with_patch_location_selfsupervision(
    feature_extractor,
    feature_extractor_optimizer,
    classifier,
    classifier_optimizer,
    location_classifier,
    location_classifier_optimizer,
    patch_classifier,
    patch_classifier_optimizer,
    images,
    labels,
    patches,
    labels_patches,
    is_train,
    patch_location_loss_coef=1.0,
    patch_classification_loss_coef=1.0,
    combine="average",
    base_ids=None,
    standardize_patches=True,
):
    """Forward-backward propagation routine for a classification model extended
    with the auxiliary self-supervised task of predicting the relative
    location of patches."""

    if base_ids is not None:
        assert base_ids.size(0) == 1

    assert images.dim() == 4
    assert labels.dim() == 1
    assert images.size(0) == labels.size(0)
    assert patches.dim() == 5 and patches.size(1) == 9
    assert patches.size(0) == labels_patches.size(0)

    patches = utils.convert_from_5d_to_4d(patches)
    if standardize_patches:
        patches = utils.standardize_image(patches)

    if is_train:
        # Zero gradients.
        feature_extractor_optimizer.zero_grad()
        classifier_optimizer.zero_grad()
        if patch_location_loss_coef > 0.0:
            location_classifier_optimizer.zero_grad()
        if patch_classification_loss_coef > 0.0:
            patch_classifier_optimizer.zero_grad()

    record = {}
    with torch.set_grad_enabled(is_train):
        # Extract features from the images.
        features_images = feature_extractor(images)
        # Extract features from the image patches.
        features_patches = feature_extractor(patches)

        # Perform the object classification task.
        scores_classification, loss_classification = cls_utils.classification_task(
            classifier, features_images, labels, base_ids)
        record["loss_cls"] = loss_classification.item()
        loss_total = loss_classification

        # Perform the self-supervised task of relative patch location
        # prediction.
        if patch_location_loss_coef > 0.0:
            scores_location, loss_location, labels_loc = patch_location_task(
                location_classifier, features_patches)
            record["loss_loc"] = loss_location.item()
            loss_total = loss_total + loss_location * patch_location_loss_coef

        # Perform the auxiliary task of classifying individual patches.
        if patch_classification_loss_coef > 0.0:
            scores_patch, loss_patch = patch_classification_task(
                patch_classifier, features_patches, labels_patches, combine)
            record["loss_patch_cls"] = loss_patch.item()
            loss_total = loss_total + loss_patch * patch_classification_loss_coef

        # Because the total loss is the sum of multiple individual losses
        # (i.e., up to 3), scale it down by a factor of 0.5.
        loss_total = loss_total * 0.5

    with torch.no_grad():
        # Compute accuracies.
        record["Accuracy"] = utils.top1accuracy(scores_classification, labels)
        if patch_location_loss_coef > 0.0:
            record["AccuracyLoc"] = utils.top1accuracy(scores_location, labels_loc)
        if patch_classification_loss_coef > 0.0:
            record["AccuracyPatch"] = utils.top1accuracy(scores_patch, labels_patches)

    if is_train:
        # Backpropagate the loss and apply the gradient steps.
        loss_total.backward()
        feature_extractor_optimizer.step()
        classifier_optimizer.step()
        if patch_location_loss_coef > 0.0:
            location_classifier_optimizer.step()
        if patch_classification_loss_coef > 0.0:
            patch_classifier_optimizer.step()

    return record
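# -----------------------------------------------------------------------------
# Illustrative sketch of the relative patch-location task called above; the
# real patch_location_task is defined elsewhere in this module, so everything
# below (the _sketch name, the pairing scheme, the classifier input format) is
# an assumption. The idea follows Doersch et al.: pair the central patch
# (index 4 of the 3x3 grid) with each of its 8 neighbours and classify the
# neighbour's position. Assumes 2D patch features of shape [batch * 9, C] and
# a location_classifier that accepts concatenated pair features.
def patch_location_task_sketch(location_classifier, features_patches, num_patches=9):
    features = features_patches.view(-1, num_patches, features_patches.size(1))
    center = features[:, 4:5, :].expand(-1, num_patches - 1, -1)
    neighbours = torch.cat([features[:, :4, :], features[:, 5:, :]], dim=1)
    # Concatenate (center, neighbour) pairs and flatten to a batch of pairs.
    pairs = torch.cat([center, neighbours], dim=2).flatten(0, 1)
    labels = torch.arange(
        num_patches - 1, device=pairs.device).repeat(features.size(0))
    scores = location_classifier(pairs)
    loss = torch.nn.functional.cross_entropy(scores, labels)
    return scores, loss, labels
# -----------------------------------------------------------------------------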
def fewshot_classification_with_rotation_selfsupervision(
    feature_extractor,
    feature_extractor_optimizer,
    classifier,
    classifier_optimizer,
    classifier_rot,
    classifier_rot_optimizer,
    images_train,
    labels_train,
    labels_train_1hot,
    images_test,
    labels_test,
    is_train,
    alpha=1.0,
    base_ids=None,
    feature_name=None,
):
    """Forward-backward routine of a few-shot model with an auxiliary rotation
    prediction task.

    Given as input a mini-batch of few-shot episodes, it applies the forward
    and (optionally) backward propagation routines of the few-shot
    classification task. Each episode consists of (1) num_train_examples number
    of training examples for the novel classes of the few-shot episode, (2) the
    labels of the training examples of the novel classes, (3) num_test_examples
    number of test examples of the few-shot episode (note that the test
    examples can be from both base and novel classes), and (4) the labels of
    the test examples. Each mini-batch consists of meta_batch_size number of
    few-shot episodes. The code assumes that the few-shot classification model
    is divided into a feature extractor network and a classification head
    network. Also, the code applies the auxiliary self-supervised task of
    predicting image rotations. The rotation prediction task is applied to both
    the test and the train examples of the few-shot episodes.

    Args:
        feature_extractor: The feature extractor neural network.
        feature_extractor_optimizer: The parameter optimizer of the feature
            extractor. If None, then the feature extractor remains frozen
            during training.
        classifier: The classification head applied on the output of the
            feature extractor.
        classifier_optimizer: The parameter optimizer of the classification
            head.
        classifier_rot: The rotation prediction head applied on the output of
            the feature extractor.
        classifier_rot_optimizer: The parameter optimizer of the rotation
            prediction head.
        images_train: A 5D tensor with shape
            [meta_batch_size x num_train_examples x channels x height x width]
            that represents a mini-batch of meta_batch_size number of few-shot
            episodes, each with num_train_examples number of training examples.
        labels_train: A 2D tensor with shape
            [meta_batch_size x num_train_examples] that represents the discrete
            labels of the training examples of each few-shot episode in the
            batch.
        labels_train_1hot: A 3D tensor with shape
            [meta_batch_size x num_train_examples x num_novel] that represents
            the 1hot labels of the training examples of the novel classes of
            each few-shot episode in the batch. num_novel is the number of
            novel classes per few-shot episode.
        images_test: A 5D tensor with shape
            [meta_batch_size x num_test_examples x channels x height x width]
            that represents a mini-batch of meta_batch_size number of few-shot
            episodes, each with num_test_examples number of test examples.
        labels_test: A 2D tensor with shape
            [meta_batch_size x num_test_examples] that represents the discrete
            labels of the test examples of each few-shot episode in the
            mini-batch.
        is_train: Boolean value that indicates if this mini-batch will be used
            for training or testing. If is_train is False, then the code does
            not apply the backward propagation step and does not update the
            parameter optimizers.
        base_ids: A 2D tensor with shape [meta_batch_size x num_base], where
            base_ids[m] are the indices of the base categories that are being
            used in the m-th few-shot episode. num_base is the number of base
            classes per few-shot episode.
        alpha: (optional) The loss weight of the rotation prediction task.
        feature_name: (optional) A string or list of strings with the name of
            the feature level(s) from which the feature extractor will extract
            features for the classification task.

    Returns:
        record: A dictionary of scalar values with the following items:
            'loss_cls': The cross-entropy loss of the few-shot classification
                task.
            'loss_rot': The rotation prediction loss.
            'loss_total': The total loss, i.e., loss_cls + alpha * loss_rot.
            'AccuracyNovel': The classification accuracy of the test examples
                among only the novel classes.
            'AccuracyBase': (optional) The classification accuracy of the test
                examples among only the base classes. Applicable only if there
                are test examples from base classes in the mini-batch.
            'AccuracyBoth': (optional) The classification accuracy of the test
                examples among both the base and novel classes. Applicable only
                if there are test examples from base classes in the mini-batch.
            'AccuracyRot': The accuracy of the rotation prediction task.
    """

    assert images_train.dim() == 5
    assert images_test.dim() == 5
    assert images_train.size(0) == images_test.size(0)
    assert images_train.size(2) == images_test.size(2)
    assert images_train.size(3) == images_test.size(3)
    assert images_train.size(4) == images_test.size(4)
    assert labels_train.dim() == 2
    assert labels_test.dim() == 2
    assert labels_train.size(0) == labels_test.size(0)
    assert labels_train.size(0) == images_train.size(0)

    meta_batch_size = images_train.size(0)
    num_train = images_train.size(1)
    num_test = images_test.size(1)

    if is_train:
        # Zero the gradients.
        feature_extractor_optimizer.zero_grad()
        classifier_optimizer.zero_grad()
        classifier_rot_optimizer.zero_grad()

    record = {}
    with torch.no_grad():
        images_train = utils.convert_from_5d_to_4d(images_train)
        images_test = utils.convert_from_5d_to_4d(images_test)
        labels_test = labels_test.view(-1)
        images = torch.cat([images_train, images_test], dim=0)
        batch_size_train = images_train.size(0)
        batch_size_train_test = images.size(0)
        assert batch_size_train == meta_batch_size * num_train
        assert batch_size_train_test == meta_batch_size * (num_train + num_test)

        # Create the 4 rotated versions of the images; this step increases
        # the batch size by a factor of 4.
        images = create_4rotations_images(images)
        labels_rotation = create_rotations_labels(batch_size_train_test,
                                                  images.device)

    with torch.set_grad_enabled(is_train):
        # Extract features from the train and test images.
        features = cls_utils.extract_features(
            feature_extractor, images, feature_name=feature_name)

        # Apply the few-shot classification head.
        features_train = features[:batch_size_train]
        features_test = features[batch_size_train:batch_size_train_test]
        features_train = utils.add_dimension(features_train, meta_batch_size)
        features_test = utils.add_dimension(features_test, meta_batch_size)
        (
            classification_scores,
            loss_classification,
        ) = fewshot_utils.few_shot_feature_classification(
            classifier,
            features_test,
            features_train,
            labels_train_1hot,
            labels_test,
            base_ids,
        )
        record["loss_cls"] = loss_classification.item()

        # Apply the rotation prediction head.
        scores_rotation, loss_rotation = rotation_task(
            classifier_rot, features, labels_rotation)
        record["loss_rot"] = loss_rotation.item()

        # Compute the total loss.
        loss_total = loss_classification + alpha * loss_rotation
        record["loss_total"] = loss_total.item()

    with torch.no_grad():
        num_base = base_ids.size(1) if (base_ids is not None) else 0
        record = fewshot_utils.compute_accuracy_metrics(
            classification_scores, labels_test, num_base, record)
        record["AccuracyRot"] = utils.top1accuracy(scores_rotation,
                                                   labels_rotation)

    if is_train:
        loss_total.backward()
        feature_extractor_optimizer.step()
        classifier_optimizer.step()
        classifier_rot_optimizer.step()

    return record
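# -----------------------------------------------------------------------------
# Illustrative sketches, not necessarily this module's own definitions, of how
# the rotation helpers used above could be implemented with plain torch ops.
# The _sketch names are assumptions; the only behaviour relied on above is
# that the four rotated copies are concatenated along the batch dimension and
# that the labels follow the same ordering.
def create_4rotations_images_sketch(images):
    # Stack the 0, 90, 180 and 270 degree rotations along the batch dimension.
    rotations = [
        images,
        torch.rot90(images, 1, dims=(2, 3)),
        torch.rot90(images, 2, dims=(2, 3)),
        torch.rot90(images, 3, dims=(2, 3)),
    ]
    return torch.cat(rotations, dim=0)


def create_rotations_labels_sketch(batch_size, device):
    # Label r (0..3) for the r-th rotated copy of the original batch.
    return torch.arange(4, device=device).repeat_interleave(batch_size)
# -----------------------------------------------------------------------------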
def fewshot_classification(
    feature_extractor,
    feature_extractor_optimizer,
    classifier,
    classifier_optimizer,
    images_train,
    labels_train,
    labels_train_1hot,
    images_test,
    labels_test,
    is_train,
    base_ids=None,
    feature_name=None,
    classification_coef=1.0,
):
    """Forward-backward propagation routine of the few-shot classification task.

    Given as input a mini-batch of few-shot episodes, it applies the forward
    and (optionally) backward propagation routines of the few-shot
    classification task. Each episode consists of (1) num_train_examples number
    of training examples for the novel classes of the few-shot episode, (2) the
    labels of the training examples of the novel classes, (3) num_test_examples
    number of test examples of the few-shot episode (note that the test
    examples can be from both base and novel classes), and (4) the labels of
    the test examples. Each mini-batch consists of meta_batch_size number of
    few-shot episodes. The code assumes that the few-shot classification model
    is divided into a feature extractor network and a classification head
    network.

    Args:
        feature_extractor: The feature extractor neural network.
        feature_extractor_optimizer: The parameter optimizer of the feature
            extractor. If None, then the feature extractor remains frozen
            during training.
        classifier: The classification head applied on the output of the
            feature extractor.
        classifier_optimizer: The parameter optimizer of the classification
            head.
        images_train: A 5D tensor with shape
            [meta_batch_size x num_train_examples x channels x height x width]
            that represents a mini-batch of meta_batch_size number of few-shot
            episodes, each with num_train_examples number of training examples.
        labels_train: A 2D tensor with shape
            [meta_batch_size x num_train_examples] that represents the discrete
            labels of the training examples of each few-shot episode in the
            batch.
        labels_train_1hot: A 3D tensor with shape
            [meta_batch_size x num_train_examples x num_novel] that represents
            the 1hot labels of the training examples of the novel classes of
            each few-shot episode in the batch. num_novel is the number of
            novel classes per few-shot episode.
        images_test: A 5D tensor with shape
            [meta_batch_size x num_test_examples x channels x height x width]
            that represents a mini-batch of meta_batch_size number of few-shot
            episodes, each with num_test_examples number of test examples.
        labels_test: A 2D tensor with shape
            [meta_batch_size x num_test_examples] that represents the discrete
            labels of the test examples of each few-shot episode in the
            mini-batch.
        is_train: Boolean value that indicates if this mini-batch will be used
            for training or testing. If is_train is False, then the code does
            not apply the backward propagation step and does not update the
            parameter optimizers.
        base_ids: A 2D tensor with shape [meta_batch_size x num_base], where
            base_ids[m] are the indices of the base categories that are being
            used in the m-th few-shot episode. num_base is the number of base
            classes per few-shot episode.
        feature_name: (optional) A string or list of strings with the name of
            the feature level(s) from which the feature extractor will extract
            features for the classification task.
        classification_coef: (optional) The loss weight of the few-shot
            classification task.

    Returns:
        record: A dictionary of scalar values with the following items:
            'loss': The cross-entropy loss of this mini-batch.
            'AccuracyNovel': The classification accuracy of the test examples
                among only the novel classes.
            'AccuracyBase': (optional) The classification accuracy of the test
                examples among only the base classes. Applicable only if there
                are test examples from base classes in the mini-batch.
            'AccuracyBoth': (optional) The classification accuracy of the test
                examples among both the base and novel classes. Applicable only
                if there are test examples from base classes in the mini-batch.
    """

    assert images_train.dim() == 5
    assert images_test.dim() == 5
    assert images_train.size(0) == images_test.size(0)
    assert images_train.size(2) == images_test.size(2)
    assert images_train.size(3) == images_test.size(3)
    assert images_train.size(4) == images_test.size(4)
    assert labels_train.dim() == 2
    assert labels_test.dim() == 2
    assert labels_train.size(0) == labels_test.size(0)
    assert labels_train.size(0) == images_train.size(0)
    assert not (isinstance(feature_name, (list, tuple)) and len(feature_name) > 1)

    meta_batch_size = images_train.size(0)

    if is_train:
        # Zero the gradients.
        if feature_extractor_optimizer:
            feature_extractor_optimizer.zero_grad()
        classifier_optimizer.zero_grad()

    record = {}
    with torch.no_grad():
        images_train = utils.convert_from_5d_to_4d(images_train)
        images_test = utils.convert_from_5d_to_4d(images_test)
        labels_test = labels_test.view(-1)
        batch_size_train = images_train.size(0)
        # batch_size_test = images_test.size(0)
        images = torch.cat([images_train, images_test], dim=0)

    train_feature_extractor = is_train and (feature_extractor_optimizer is not None)
    with torch.set_grad_enabled(train_feature_extractor):
        # Extract features from the train and test images.
        features = cls_utils.extract_features(
            feature_extractor, images, feature_name=feature_name)
        if not train_feature_extractor:
            # Make sure that no gradients are backpropagated to the feature
            # extractor when the feature extraction model is frozen.
            features = features.detach()

    with torch.set_grad_enabled(is_train):
        features_train = features[:batch_size_train]
        features_test = features[batch_size_train:]
        features_train = utils.add_dimension(features_train, meta_batch_size)
        features_test = utils.add_dimension(features_test, meta_batch_size)

        # Apply the classification head of the few-shot classification model.
        classification_scores, loss = few_shot_feature_classification(
            classifier,
            features_test,
            features_train,
            labels_train_1hot,
            labels_test,
            base_ids,
        )
        record["loss"] = loss.item()
        loss_total = loss * classification_coef
        # *******************************************************************

    with torch.no_grad():
        num_base = base_ids.size(1) if (base_ids is not None) else 0
        record = compute_accuracy_metrics(
            classification_scores, labels_test, num_base, record)

    if is_train:
        loss_total.backward()
        if feature_extractor_optimizer:
            feature_extractor_optimizer.step()
        classifier_optimizer.step()

    return record
def process_batch_fewshot_classification_task(self, is_train):
    Kids = self.tensors["Kids"]
    base_ids = None if (self.num_base == 0) else Kids[:, :self.num_base].contiguous()

    images_train = self.tensors["images_train"]
    images_test = self.tensors["images_test"]

    if self.standardize_image:
        assert images_train.dim() == 5 and images_test.dim() == 5
        assert images_train.size(0) == images_test.size(0)
        meta_batch_size = images_train.size(0)
        images_train = utils.convert_from_5d_to_4d(images_train)
        images_test = utils.convert_from_5d_to_4d(images_test)
        images_train = utils.standardize_image(images_train)
        images_test = utils.standardize_image(images_test)
        images_train = utils.add_dimension(images_train, meta_batch_size)
        images_test = utils.add_dimension(images_test, meta_batch_size)

    auxiliary_tasks = is_train and (
        self.patch_location_loss_coef > 0.0
        or self.patch_classification_loss_coef > 0.0)

    if auxiliary_tasks:
        record = loc_utils.fewshot_classification_with_patch_location_selfsupervision(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=self.optimizers["feature_extractor"],
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers["classifier"],
            location_classifier=self.networks.get("classifier_loc"),
            location_classifier_optimizer=self.optimizers.get("classifier_loc"),
            patch_classifier=self.networks.get("patch_classifier"),
            patch_classifier_optimizer=self.optimizers.get("patch_classifier"),
            images_train=images_train,
            patches_train=self.tensors["patches_train"],
            labels_train=self.tensors["labels_train"],
            labels_train_1hot=self.tensors["labels_train_1hot"],
            images_test=images_test,
            patches_test=self.tensors["patches_test"],
            labels_test=self.tensors["labels_test"],
            is_train=is_train,
            base_ids=base_ids,
            patch_location_loss_coef=self.patch_location_loss_coef,
            patch_classification_loss_coef=self.patch_classification_loss_coef,
            combine=self.combine_patches,
            standardize_patches=self.standardize_patches,
        )
    else:
        record = fs_utils.fewshot_classification(
            feature_extractor=self.networks["feature_extractor"],
            feature_extractor_optimizer=self.optimizers.get("feature_extractor"),
            classifier=self.networks["classifier"],
            classifier_optimizer=self.optimizers.get("classifier"),
            images_train=images_train,
            labels_train=self.tensors["labels_train"],
            labels_train_1hot=self.tensors["labels_train_1hot"],
            images_test=images_test,
            labels_test=self.tensors["labels_test"],
            is_train=is_train,
            base_ids=base_ids,
        )

    if not is_train:
        record, self.accuracies = fs_utils.compute_95confidence_intervals(
            record,
            episode=self.biter,
            num_episodes=self.bnumber,
            store_accuracies=self.accuracies,
            metrics=["AccuracyNovel"],
        )

    return record
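# -----------------------------------------------------------------------------
# Illustrative sketch, not the fs_utils implementation, of the statistic that
# compute_95confidence_intervals reports for "AccuracyNovel": the mean
# per-episode accuracy with a 95% confidence half-width of
# 1.96 * std / sqrt(num_episodes). The _sketch name and numpy usage are
# assumptions made only for illustration.
import numpy as np


def confidence_interval_95_sketch(accuracies):
    accuracies = np.asarray(accuracies, dtype=np.float64)
    mean = accuracies.mean()
    half_width = 1.96 * accuracies.std() / np.sqrt(len(accuracies))
    return mean, half_width
# -----------------------------------------------------------------------------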