Example #1
0
    def test_equivalence(self, compute_inv):
        """Check that toggling ``use_lambda`` does not change the predictions.

        Two models are built from identical parameters that differ only in
        ``use_lambda``. numpy is reseeded with the same value before each
        build. Both models receive the same random input, and every output
        must agree within ``np.allclose`` tolerance.

        Parameters
        ----------
        compute_inv : bool
            When true, the models are fed two inputs and three losses/weights
            are configured, so three outputs are compared instead of two.
        """
        n_outputs = 3 if compute_inv else 2
        params = dict(
            start_filters=(2,),
            downsample_filters=(2, 3),
            middle_filters=(2,),
            upsample_filters=(2, 3),
            end_filters=(),
            compute_inv=compute_inv,
            losses=("mse",) * n_outputs,
            losses_weights=(1,) * n_outputs,
        )

        # Reseed before each build so any numpy-driven initialisation matches.
        # NOTE(review): if the weights are drawn from the backend RNG rather
        # than numpy, this seed may not make them identical — confirm.
        np.random.seed(1337)
        model_with = supervised_model_factory(use_lambda=True, **params)
        np.random.seed(1337)
        model_without = supervised_model_factory(use_lambda=False, **params)

        x = np.random.random((1, 320, 456, 2))
        inputs = [x, x] if compute_inv else x
        pred_with = model_with.predict(inputs)
        pred_without = model_without.predict(inputs)

        # Index explicitly so a missing output still fails loudly.
        for i in range(n_outputs):
            assert np.allclose(pred_with[i], pred_without[i])
Example #2
0
 def test_down_up_samples(self):
     """Check that invalid down/upsampling configurations raise ``ValueError``.

     Two configurations are tried:

     * the downsampling and upsampling stacks have different lengths;
     * both stacks have seven layers.

     NOTE(review): in the second case the lengths match, so it presumably
     fails for another reason (e.g. too many downsamplings for the default
     input size) — confirm against ``supervised_model_factory``.
     """
     bad_configs = (
         dict(downsample_filters=(2,), upsample_filters=(2, 3)),
         dict(downsample_filters=(2,) * 7, upsample_filters=(2,) * 7),
     )
     for config in bad_configs:
         with pytest.raises(ValueError):
             supervised_model_factory(**config)
Example #3
0
    def test_use_lambda(self, use_lambda, compute_inv):
        """Check that ``use_lambda`` controls whether Lambda layers are present.

        When ``use_lambda`` is truthy, the built model must contain at least
        one ``keras.layers.Lambda`` layer; otherwise it must contain none.
        ``compute_inv`` only changes how many losses/weights are supplied
        (three instead of two).
        """
        n_outputs = 3 if compute_inv else 2
        model = supervised_model_factory(
            losses=("mse",) * n_outputs,
            losses_weights=(1,) * n_outputs,
            compute_inv=compute_inv,
            use_lambda=use_lambda,
        )
        has_lambda = any(
            isinstance(layer, keras.layers.Lambda) for layer in model.layers
        )
        assert has_lambda == bool(use_lambda)
Example #4
0
    def test_default_construction(self):
        """Check that the factory builds a ``keras.Model`` with all-default arguments."""
        assert isinstance(supervised_model_factory(), keras.Model)

    def test_compute_external_metrics(
        self, monkeypatch, tmpdir, random_state, return_inverse
    ):
        """Check ``MLFlowCallback.compute_external_metrics`` on a small generator.

        ``evaluate_single`` is replaced by a fake that records the arguments it
        receives and returns a constant ``pd.Series``. ``annotation_volume``,
        ``nissl_volume`` (a zero array) and ``segmentation_collapsing_labels``
        are replaced by ``Mock`` objects. The test then checks that:

        * one row per validation sample is produced, indexed by ``image_id``;
        * ``evaluate_single`` is called once per validation sample;
        * each call receives the ground-truth deltas, the image divided by
          255, ``p`` and the inverse deltas, which are present only when
          ``return_inverse`` is true and ``None`` otherwise.
        """
        evaluate_cache = []

        def fake_evaluate(*args, **kwargs):
            # Record only the arguments inspected by the assertions below.
            evaluate_cache.append(
                {
                    "deltas_true": args[0],
                    "img_mov": args[2],
                    "p": kwargs["p"],
                    "deltas_true_inv": kwargs["deltas_true_inv"],
                }
            )

            return pd.Series([2, 3])

        monkeypatch.setattr(
            "atlalign.ml_utils.callbacks.evaluate_single",
            Mock(side_effect=fake_evaluate),
        )
        monkeypatch.setattr("atlalign.ml_utils.callbacks.annotation_volume", Mock())
        monkeypatch.setattr(
            "atlalign.ml_utils.io.nissl_volume",
            Mock(return_value=np.zeros((528, 320, 456, 1))),
        )
        monkeypatch.setattr(
            "atlalign.ml_utils.callbacks.segmentation_collapsing_labels", Mock()
        )

        n_samples = 10
        n_val_samples = 4
        h5_path = pathlib.Path(str(tmpdir)) / "temp.h5"
        self.create_h5(h5_path, n_samples, random_state)

        # NOTE(review): this draws from the global numpy RNG, not from the
        # ``random_state`` fixture, so the validation subset is not
        # reproducible across runs — consider deriving it from the fixture.
        val_indexes = list(np.random.choice(n_samples, n_val_samples, replace=False))

        val_gen = SupervisedGenerator(
            h5_path,
            indexes=val_indexes,
            shuffle=False,
            batch_size=1,
            return_inverse=return_inverse,
        )
        losses = ["mse", "mse", "mse"] if return_inverse else ["mse", "mse"]
        losses_weights = [1, 1, 1] if return_inverse else [1, 1]

        model = supervised_model_factory(
            compute_inv=return_inverse,
            losses=losses,
            losses_weights=losses_weights,
            start_filters=(2,),
            downsample_filters=(4, 2),
            middle_filters=(2,),
            upsample_filters=(2, 4),
        )

        df = MLFlowCallback.compute_external_metrics(model, val_gen)

        assert len(df) == len(val_indexes)
        assert np.allclose(
            df.index.values, load_dataset_in_memory(h5_path, "image_id")[val_indexes]
        )
        assert len(evaluate_cache) == len(val_indexes)

        # Load each dataset once, instead of re-reading the whole dataset from
        # the HDF5 file for every validation sample inside the loop.
        all_deltas = load_dataset_in_memory(h5_path, "deltas_xy")
        all_deltas_inv = load_dataset_in_memory(h5_path, "inv_deltas_xy")
        all_images = load_dataset_in_memory(h5_path, "img")
        all_p = load_dataset_in_memory(h5_path, "p")

        # Assumes evaluate_single is called in the same order as val_indexes
        # (the generator is built with shuffle=False).
        for ecache, val_index in zip(evaluate_cache, val_indexes):
            assert np.allclose(all_deltas[val_index], ecache["deltas_true"])
            assert np.allclose(all_images[val_index] / 255, ecache["img_mov"])
            assert np.allclose(all_p[val_index], ecache["p"])

            if return_inverse:
                assert np.allclose(all_deltas_inv[val_index], ecache["deltas_true_inv"])
            else:
                assert ecache["deltas_true_inv"] is None  # they are not streamed