def main() -> None:
    """Check the metric learning pipeline with different triplets samplers.

    The pipeline is run once per sampler, and the score returned by
    ``run_ml_pipeline`` must exceed the threshold for each of them.
    """
    # NOTE: the CMC score is expected to exceed 0.97 after 600 epochs.
    # Please verify that manually instead of in CI, so that the CI pod's
    # time is not wasted; the automated check uses a looser threshold.
    min_cmc_score = 0.9

    # Both samplers are built up front, then evaluated in the same order.
    samplers = (
        data.AllTripletsSampler(max_output_triplets=512),
        data.HardTripletsSampler(norm_required=False),
    )
    for triplets_sampler in samplers:
        assert run_ml_pipeline(triplets_sampler) > min_cmc_score
def test_reid_pipeline():
    """Check that the ReID pipeline runs and computes metrics.

    Trains a small MNIST embedding model with a triplet loss and verifies
    that ``ReidCMCScoreCallback`` reports a ``cmc01`` metric above 0.7.
    """
    with TemporaryDirectory() as logdir:
        # --- data: balanced train batches, query/gallery validation set ---
        preprocessing = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

        mnist_train = MnistMLDataset(
            root=os.getcwd(), download=True, transform=preprocessing
        )
        balanced_batches = data.BatchBalanceClassSampler(
            labels=mnist_train.get_labels(),
            num_classes=3,
            num_samples=10,
            num_batches=20,
        )
        train_loader = DataLoader(
            dataset=mnist_train, batch_sampler=balanced_batches, num_workers=0
        )

        mnist_valid = MnistReIDQGDataset(
            root=os.getcwd(), transform=preprocessing, gallery_fraq=0.2
        )
        valid_loader = DataLoader(dataset=mnist_valid, batch_size=1024)

        # --- model and optimizer ---
        net = models.MnistSimpleNet(out_features=16)
        adam = Adam(net.parameters(), lr=0.001)

        # --- triplet loss with in-batch triplets sampling ---
        triplets_sampler = data.AllTripletsSampler(max_output_triplets=1000)
        triplet_loss = nn.TripletMarginLossWithSampler(
            margin=0.5, sampler_inbatch=triplets_sampler
        )

        # --- callbacks: loss on "train" only, CMC score on "valid" only ---
        loss_callback = dl.ControlFlowCallback(
            dl.CriterionCallback(
                input_key="embeddings", target_key="targets", metric_key="loss"
            ),
            loaders="train",
        )
        cmc_callback = dl.ControlFlowCallback(
            dl.ReidCMCScoreCallback(
                embeddings_key="embeddings",
                pids_key="targets",
                cids_key="cids",
                is_query_key="is_query",
                topk_args=[1],
            ),
            loaders="valid",
        )
        # Periodic scheduling of the "valid" loader (configured with valid=2).
        periodic_valid = dl.PeriodicLoaderCallback(
            valid_loader_key="valid",
            valid_metric_key="cmc01",
            minimize=False,
            valid=2,
        )

        # --- training with the custom Catalyst runner ---
        runner = ReIDCustomRunner()
        loaders = OrderedDict([("train", train_loader), ("valid", valid_loader)])
        runner.train(
            model=net,
            criterion=triplet_loss,
            optimizer=adam,
            callbacks=[loss_callback, cmc_callback, periodic_valid],
            loaders=loaders,
            verbose=False,
            logdir=logdir,
            valid_loader="valid",
            valid_metric="cmc01",
            minimize_valid_metric=False,
            num_epochs=10,
        )

        # The metric must be reported and clear the quality threshold.
        final_metrics = runner.loader_metrics
        assert "cmc01" in final_metrics
        assert final_metrics["cmc01"] > 0.7