Beispiel #1
0
def main() -> None:
    """
    Check the metric learning pipeline with different triplet samplers.

    Runs ``run_ml_pipeline`` once with ``data.AllTripletsSampler`` and once
    with ``data.HardTripletsSampler``, in that order. It asserts that each
    returned CMC score is strictly greater than ``cmc_score_th``.

    Raises:
        AssertionError: if a run's score is not above the threshold. The
            message names the sampler that failed and reports its score.
    """
    cmc_score_th = 0.9

    # Note! cmc_score should be > 0.97 after 600 epochs. Please check that
    # manually, to avoid wasting time on the CI pod; this function only
    # enforces the looser ``cmc_score_th`` above.

    all_sampler = data.AllTripletsSampler(max_output_triplets=512)
    hard_sampler = data.HardTripletsSampler(norm_required=False)

    for sampler_name, sampler in (
        ("AllTripletsSampler", all_sampler),
        ("HardTripletsSampler", hard_sampler),
    ):
        # Run the pipeline outside the assert so it still executes under
        # `python -O`, and so the score can be reported on failure.
        cmc_score = run_ml_pipeline(sampler)
        assert cmc_score > cmc_score_th, (
            f"{sampler_name}: cmc score {cmc_score} is not above "
            f"threshold {cmc_score_th}"
        )
Beispiel #2
0
def test_reid_pipeline():
    """Train a small MNIST embedding model and check ReID CMC evaluation.

    Builds a class-balanced training loader and a query/gallery validation
    loader. It then trains ``models.MnistSimpleNet`` with a triplet loss for
    10 epochs using ``ReIDCustomRunner``, with ``dl.ReidCMCScoreCallback``
    attached to the "valid" loader.

    After training, it asserts that ``runner.loader_metrics`` contains
    ``cmc01`` and that the value exceeds 0.7.
    """
    with TemporaryDirectory() as logdir:
        # Data: shared preprocessing for both datasets.
        mnist_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
        data_root = os.getcwd()

        train_dataset = MnistMLDataset(
            root=data_root, download=True, transform=mnist_transform
        )
        balanced_sampler = data.BatchBalanceClassSampler(
            labels=train_dataset.get_labels(),
            num_classes=3,
            num_samples=10,
            num_batches=20,
        )
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_sampler=balanced_sampler,
            num_workers=0,
        )

        # Validation set is created without download=True; presumably it
        # reuses the files fetched by the training dataset under data_root.
        valid_dataset = MnistReIDQGDataset(
            root=data_root, transform=mnist_transform, gallery_fraq=0.2
        )
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=1024)

        # Model and optimizer.
        model = models.MnistSimpleNet(out_features=16)
        optimizer = Adam(model.parameters(), lr=0.001)

        # Triplet loss; the in-batch sampler is capped at 1000 output triplets.
        triplet_sampler = data.AllTripletsSampler(max_output_triplets=1000)
        criterion = nn.TripletMarginLossWithSampler(
            margin=0.5, sampler_inbatch=triplet_sampler
        )

        # Callbacks: loss only on "train", CMC scoring only on "valid".
        train_loss_callback = dl.ControlFlowCallback(
            dl.CriterionCallback(
                input_key="embeddings", target_key="targets", metric_key="loss"
            ),
            loaders="train",
        )
        valid_cmc_callback = dl.ControlFlowCallback(
            dl.ReidCMCScoreCallback(
                embeddings_key="embeddings",
                pids_key="targets",
                cids_key="cids",
                is_query_key="is_query",
                topk_args=[1],
            ),
            loaders="valid",
        )
        # valid=2 presumably means "run the valid loader every 2nd epoch".
        periodic_valid_callback = dl.PeriodicLoaderCallback(
            valid_loader_key="valid",
            valid_metric_key="cmc01",
            minimize=False,
            valid=2,
        )

        runner = ReIDCustomRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            callbacks=[
                train_loss_callback,
                valid_cmc_callback,
                periodic_valid_callback,
            ],
            loaders=OrderedDict([("train", train_loader), ("valid", valid_loader)]),
            verbose=False,
            logdir=logdir,
            valid_loader="valid",
            valid_metric="cmc01",
            minimize_valid_metric=False,
            num_epochs=10,
        )

        final_metrics = runner.loader_metrics
        assert "cmc01" in final_metrics
        assert final_metrics["cmc01"] > 0.7