def test_mperclass_sampler_with_batch_size(self):
    """Exercise MPerClassSampler with an explicit batch size over a grid.

    For every combination of batch size, ``m``, label count and
    ``length_before_new_iter``, 10000 random labels are drawn. The
    constructor is expected to raise ``AssertionError`` when any of the
    following holds:

    * ``length_before_new_iter`` is smaller than the batch size,
    * ``m * num_labels`` is smaller than the batch size,
    * the batch size is not a multiple of ``m``.

    Otherwise 1000 batches are pulled from the sampler and each one must
    contain exactly ``batch_size // m`` distinct labels, each occurring
    ``m`` times.
    """
    for batch_size in [4, 50, 99, 100, 1024]:
        for m in [1, 5, 10, 17, 50]:
            for num_labels in [2, 10, 55]:
                for length_before_new_iter in [100, 999, 10000]:
                    labels = torch.randint(
                        low=0, high=num_labels, size=(10000,)
                    )

                    config_is_valid = (
                        length_before_new_iter >= batch_size
                        and m * num_labels >= batch_size
                        and batch_size % m == 0
                    )
                    if not config_is_valid:
                        # Invalid combinations must be rejected up front.
                        with self.assertRaises(AssertionError):
                            MPerClassSampler(
                                labels, m, batch_size, length_before_new_iter
                            )
                        continue

                    sampler = MPerClassSampler(
                        labels, m, batch_size, length_before_new_iter
                    )
                    index_stream = iter(sampler)
                    classes_per_batch = batch_size // m

                    for _batch in range(1000):
                        batch_indices = []
                        for _slot in range(batch_size):
                            # The helper hands back the iterator to keep
                            # using along with the next index; presumably
                            # it restarts the iterator once exhausted --
                            # confirm against c_f.
                            index_stream, next_index = (
                                c_f.try_next_on_generator(
                                    index_stream, sampler
                                )
                            )
                            batch_indices.append(next_index)

                        batch_labels = labels[batch_indices]
                        distinct, occurrences = torch.unique(
                            batch_labels, return_counts=True
                        )
                        self.assertTrue(len(distinct) == classes_per_batch)
                        self.assertTrue(torch.all(occurrences == m))
def test_mperclass_sampler(self):
    """Check MPerClassSampler's length and per-batch label composition.

    Builds a sampler over 10000 random labels (no batch size given) and
    verifies that:

    * ``len(sampler)`` equals ``length_before_new_iter`` rounded down to
      a multiple of ``m * num_labels``;
    * each of the first 10 consecutive groups of ``batch_size`` indices
      contains ``batch_size // m`` distinct labels, each exactly ``m``
      times.
    """
    num_labels = 100
    m = 5
    batch_size = 100
    length_before_new_iter = 9999

    labels = torch.randint(low=0, high=num_labels, size=(10000,))
    sampler = MPerClassSampler(
        labels=labels,
        m=m,
        length_before_new_iter=length_before_new_iter,
    )

    samples_per_pass = m * num_labels
    expected_length = samples_per_pass * (
        length_before_new_iter // samples_per_pass
    )
    self.assertTrue(len(sampler) == expected_length)

    index_stream = iter(sampler)
    classes_per_batch = batch_size // m
    for _batch in range(10):
        batch_indices = []
        for _slot in range(batch_size):
            batch_indices.append(next(index_stream))

        batch_labels = labels[batch_indices]
        distinct, occurrences = torch.unique(
            batch_labels, return_counts=True
        )
        self.assertTrue(len(distinct) == classes_per_batch)
        self.assertTrue(torch.all(occurrences == m))
def test_metric_loss_only(self):
    """End-to-end check of the MetricLossOnly trainer and logging hooks.

    The test clones a CIFAR ResNet repository, loads its pretrained
    ``resnet20`` checkpoint, trains for three epochs on a 10000-sample
    subset of CIFAR100 with ``NTXentLoss`` and an ``MPerClassSampler``,
    and then cross-checks the values reported by the hook container
    (best epoch/accuracy, accuracy history, loss history, record group
    names). This is repeated for every dtype in ``TEST_DTYPES`` and for
    two ``splits_to_eval`` settings.

    All temporary folders are removed in a ``finally`` clause, so a
    failing assertion does not leave stale folders behind that would
    make the next run trip over the "folder must not exist"
    preconditions.

    Requires network access (git clone, CIFAR100 download) and ``git``
    on the PATH.
    """
    cifar_resnet_folder = "temp_cifar_resnet_for_pytorch_metric_learning_test"
    dataset_folder = "temp_dataset_for_pytorch_metric_learning_test"
    model_folder = "temp_saved_models_for_pytorch_metric_learning_test"
    logs_folder = "temp_logs_for_pytorch_metric_learning_test"
    tensorboard_folder = "temp_tensorboard_for_pytorch_metric_learning_test"

    # Preconditions are checked before anything is created (and outside
    # the try/finally) so that a pre-existing folder is reported rather
    # than deleted. unittest assertions are used instead of bare
    # ``assert`` so the checks survive ``python -O``.
    for folder in (
        dataset_folder,
        model_folder,
        logs_folder,
        tensorboard_folder,
    ):
        self.assertFalse(
            os.path.isdir(folder),
            "{} already exists; remove it before running this test".format(
                folder
            ),
        )

    try:
        os.system(
            "git clone https://github.com/akamaster/pytorch_resnet_cifar10.git {}".format(
                cifar_resnet_folder
            )
        )
        # The package only exists once the clone has finished, so it
        # cannot be imported at module level.
        from temp_cifar_resnet_for_pytorch_metric_learning_test import resnet

        loss_fn = NTXentLoss()

        normalize_transform = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        train_transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize_transform,
            ]
        )
        eval_transform = transforms.Compose(
            [transforms.ToTensor(), normalize_transform]
        )

        subset_idx = np.arange(10000)

        train_dataset = datasets.CIFAR100(
            dataset_folder,
            train=True,
            download=True,
            transform=train_transform,
        )
        train_dataset_for_eval = datasets.CIFAR100(
            dataset_folder,
            train=True,
            download=True,
            transform=eval_transform,
        )
        val_dataset = datasets.CIFAR100(
            dataset_folder,
            train=False,
            download=True,
            transform=eval_transform,
        )

        train_dataset = torch.utils.data.Subset(train_dataset, subset_idx)
        train_dataset_for_eval = torch.utils.data.Subset(
            train_dataset_for_eval, subset_idx
        )
        val_dataset = torch.utils.data.Subset(val_dataset, subset_idx)

        for dtype in TEST_DTYPES:
            for splits_to_eval in [
                None,
                [("train", ["train", "val"]), ("val", ["train", "val"])],
            ]:
                model = torch.nn.DataParallel(resnet.resnet20())
                checkpoint = torch.load(
                    "{}/pretrained_models/resnet20-12fca82f.th".format(
                        cifar_resnet_folder
                    ),
                    map_location=TEST_DEVICE,
                )
                model.load_state_dict(checkpoint["state_dict"])
                # Swap the final linear layer for an identity module.
                model.module.linear = c_f.Identity()
                if TEST_DEVICE == torch.device("cpu"):
                    # On CPU, use the wrapped module directly.
                    model = model.module
                model = model.to(TEST_DEVICE).type(dtype)

                optimizer = torch.optim.Adam(
                    model.parameters(),
                    lr=0.0002,
                    weight_decay=0.0001,
                    eps=1e-04,
                )

                batch_size = 32
                # A single iteration per epoch when evaluating on custom
                # splits; otherwise leave it to the trainer's default.
                iterations_per_epoch = None if splits_to_eval is None else 1
                model_dict = {"trunk": model}
                optimizer_dict = {"trunk_optimizer": optimizer}
                loss_fn_dict = {"metric_loss": loss_fn}
                sampler = MPerClassSampler(
                    np.array(train_dataset.dataset.targets)[subset_idx],
                    m=4,
                    batch_size=batch_size,
                    length_before_new_iter=len(train_dataset),
                )
                record_keeper, _, _ = logging_presets.get_record_keeper(
                    logs_folder, tensorboard_folder
                )
                hooks = logging_presets.get_hook_container(
                    record_keeper, primary_metric="precision_at_1"
                )
                dataset_dict = {
                    "train": train_dataset_for_eval,
                    "val": val_dataset,
                }

                tester = GlobalEmbeddingSpaceTester(
                    end_of_testing_hook=hooks.end_of_testing_hook,
                    accuracy_calculator=accuracy_calculator.AccuracyCalculator(
                        include=("precision_at_1", "AMI"), k=1
                    ),
                    data_device=TEST_DEVICE,
                    dtype=dtype,
                    dataloader_num_workers=32,
                )

                end_of_epoch_hook = hooks.end_of_epoch_hook(
                    tester,
                    dataset_dict,
                    model_folder,
                    test_interval=1,
                    patience=1,
                    splits_to_eval=splits_to_eval,
                )

                trainer = MetricLossOnly(
                    models=model_dict,
                    optimizers=optimizer_dict,
                    batch_size=batch_size,
                    loss_funcs=loss_fn_dict,
                    mining_funcs={},
                    dataset=train_dataset,
                    sampler=sampler,
                    data_device=TEST_DEVICE,
                    dtype=dtype,
                    dataloader_num_workers=32,
                    iterations_per_epoch=iterations_per_epoch,
                    freeze_trunk_batchnorm=True,
                    end_of_iteration_hook=hooks.end_of_iteration_hook,
                    end_of_epoch_hook=end_of_epoch_hook,
                )

                num_epochs = 3
                trainer.train(num_epochs=num_epochs)

                best_epoch, best_accuracy = hooks.get_best_epoch_and_accuracy(
                    tester, "val"
                )
                if splits_to_eval is None:
                    self.assertTrue(best_epoch == 3)
                    self.assertTrue(best_accuracy > 0.2)

                # The different hook accessors must agree with each other.
                accuracies, primary_metric_key = (
                    hooks.get_accuracies_of_best_epoch(tester, "val")
                )
                accuracies = c_f.sqliteObjToDict(accuracies)
                self.assertTrue(
                    accuracies[primary_metric_key][0] == best_accuracy
                )
                self.assertTrue(primary_metric_key == "precision_at_1_level0")

                best_epoch_accuracies = hooks.get_accuracies_of_epoch(
                    tester, "val", best_epoch
                )
                best_epoch_accuracies = c_f.sqliteObjToDict(
                    best_epoch_accuracies
                )
                self.assertTrue(
                    best_epoch_accuracies[primary_metric_key][0]
                    == best_accuracy
                )

                accuracy_history = hooks.get_accuracy_history(tester, "val")
                self.assertTrue(
                    accuracy_history[primary_metric_key][
                        accuracy_history["epoch"].index(best_epoch)
                    ]
                    == best_accuracy
                )

                loss_history = hooks.get_loss_history()
                if splits_to_eval is None:
                    # One loss entry is expected per training iteration.
                    self.assertTrue(
                        len(loss_history["metric_loss"])
                        == (len(sampler) / batch_size) * num_epochs
                    )

                curr_primary_metric = hooks.get_curr_primary_metric(
                    tester, "val"
                )
                self.assertTrue(
                    curr_primary_metric
                    == accuracy_history[primary_metric_key][-1]
                )

                base_record_group_name = hooks.base_record_group_name(tester)
                self.assertTrue(
                    base_record_group_name
                    == "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0"
                )

                record_group_name = hooks.record_group_name(tester, "val")
                if splits_to_eval is None:
                    self.assertTrue(
                        record_group_name
                        == "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0_VAL_vs_self"
                    )
                else:
                    self.assertTrue(
                        record_group_name
                        == "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0_VAL_vs_TRAIN_and_VAL"
                    )
    finally:
        # Always clean up, even when the test fails part-way through.
        # ignore_errors=True because a folder may not have been created
        # yet at the point of failure, and a cleanup error must not mask
        # the original exception.
        for folder in (
            model_folder,
            logs_folder,
            tensorboard_folder,
            cifar_resnet_folder,
            dataset_folder,
        ):
            shutil.rmtree(folder, ignore_errors=True)