def test_equivalence(self, compute_inv):
    """Check that toggling ``use_lambda`` leaves the model predictions unchanged.

    Two models are built from identical hyperparameters, each right after
    reseeding NumPy with the same seed. One uses ``use_lambda=True`` and the
    other ``use_lambda=False``. Their outputs on the same random input must
    agree element-wise.
    """
    n_outputs = 3 if compute_inv else 2
    params = dict(
        start_filters=(2,),
        downsample_filters=(2, 3),
        middle_filters=(2,),
        upsample_filters=(2, 3),
        end_filters=(),
        compute_inv=compute_inv,
        losses=("mse",) * n_outputs,
        losses_weights=(1,) * n_outputs,
    )

    # Reseed before each construction so both models start from the same state.
    np.random.seed(1337)
    model_with = supervised_model_factory(use_lambda=True, **params)
    np.random.seed(1337)
    model_without = supervised_model_factory(use_lambda=False, **params)

    x = np.random.random((1, 320, 456, 2))
    # The inverse-computing model takes two inputs; otherwise a single one.
    inputs = [x, x] if compute_inv else x
    pred_with = model_with.predict(inputs)
    pred_without = model_without.predict(inputs)

    for output_idx in range(n_outputs):
        assert np.allclose(pred_with[output_idx], pred_without[output_idx])
def test_down_up_samples(self):
    """Check that invalid down/upsampling configurations raise ``ValueError``.

    Two configurations are exercised:

    * a different number of downsampling and upsampling layers;
    * a matched but 7-layer-deep configuration, which the factory is also
      expected to reject.
    """
    mismatched = dict(downsample_filters=(2,), upsample_filters=(2, 3))
    # NOTE(review): presumably rejected because it has too many resolution
    # levels for the default input shape; confirm against
    # supervised_model_factory.
    too_deep = dict(downsample_filters=(2,) * 7, upsample_filters=(2,) * 7)

    for factory_kwargs in (mismatched, too_deep):
        with pytest.raises(ValueError):
            supervised_model_factory(**factory_kwargs)
def test_use_lambda(self, use_lambda, compute_inv):
    """Check that ``use_lambda`` controls whether ``keras.layers.Lambda`` layers appear."""
    n_outputs = 3 if compute_inv else 2
    model = supervised_model_factory(
        losses=("mse",) * n_outputs,
        losses_weights=(1,) * n_outputs,
        compute_inv=compute_inv,
        use_lambda=use_lambda,
    )

    has_lambda = any(
        isinstance(layer, keras.layers.Lambda) for layer in model.layers
    )
    # Lambda layers must be present exactly when the flag is set.
    assert has_lambda == bool(use_lambda)
def test_default_construction(self):
    """Check that the factory builds a ``keras.Model`` with all default arguments."""
    assert isinstance(supervised_model_factory(), keras.Model)
def test_compute_external_metrics(
    self, monkeypatch, tmpdir, random_state, return_inverse
):
    """Check ``MLFlowCallback.compute_external_metrics`` on a small HDF5 dataset.

    ``evaluate_single`` is replaced by a fake that records the arguments it
    receives. The test then verifies that the callback:

    * evaluates every validation sample exactly once, in order;
    * passes the ground truth stored in the HDF5 file;
    * passes inverse deltas only when the generator returns them.

    The annotation volume, the Nissl volume and the collapsing-labels helper
    are mocked out.
    """
    evaluate_cache = []

    def fake_evaluate(*args, **kwargs):
        # Record what the callback passed so it can be compared afterwards.
        evaluate_cache.append(
            {
                "deltas_true": args[0],
                "img_mov": args[2],
                "p": kwargs["p"],
                "deltas_true_inv": kwargs["deltas_true_inv"],
            }
        )
        return pd.Series([2, 3])

    monkeypatch.setattr(
        "atlalign.ml_utils.callbacks.evaluate_single",
        Mock(side_effect=fake_evaluate),
    )
    monkeypatch.setattr("atlalign.ml_utils.callbacks.annotation_volume", Mock())
    monkeypatch.setattr(
        "atlalign.ml_utils.io.nissl_volume",
        Mock(return_value=np.zeros((528, 320, 456, 1))),
    )
    monkeypatch.setattr(
        "atlalign.ml_utils.callbacks.segmentation_collapsing_labels", Mock()
    )

    n_samples = 10
    n_val_samples = 4
    h5_path = pathlib.Path(str(tmpdir)) / "temp.h5"
    self.create_h5(h5_path, n_samples, random_state)

    # NOTE(review): this draws from the global NumPy RNG rather than
    # `random_state`, so the chosen validation indexes are not reproducible
    # across runs.
    val_indexes = list(np.random.choice(n_samples, n_val_samples, replace=False))

    val_gen = SupervisedGenerator(
        h5_path,
        indexes=val_indexes,
        shuffle=False,
        batch_size=1,
        return_inverse=return_inverse,
    )

    losses = ["mse", "mse", "mse"] if return_inverse else ["mse", "mse"]
    losses_weights = [1, 1, 1] if return_inverse else [1, 1]
    model = supervised_model_factory(
        compute_inv=return_inverse,
        losses=losses,
        losses_weights=losses_weights,
        start_filters=(2,),
        downsample_filters=(4, 2),
        middle_filters=(2,),
        upsample_filters=(2, 4),
    )

    df = MLFlowCallback.compute_external_metrics(model, val_gen)

    # Read each ground-truth dataset once instead of once per validation
    # sample (the loop below previously re-read four datasets every iteration).
    all_image_ids = load_dataset_in_memory(h5_path, "image_id")
    all_deltas = load_dataset_in_memory(h5_path, "deltas_xy")
    all_deltas_inv = load_dataset_in_memory(h5_path, "inv_deltas_xy")
    all_images = load_dataset_in_memory(h5_path, "img")
    all_p = load_dataset_in_memory(h5_path, "p")

    assert len(df) == len(val_indexes)
    assert np.allclose(df.index.values, all_image_ids[val_indexes])
    assert len(evaluate_cache) == len(val_indexes)

    for ecache, val_index in zip(evaluate_cache, val_indexes):
        expected_deltas = all_deltas[val_index]
        expected_deltas_inv = all_deltas_inv[val_index]
        expected_image = all_images[val_index] / 255
        expected_p = all_p[val_index]

        assert np.allclose(expected_deltas, ecache["deltas_true"])
        assert np.allclose(expected_image, ecache["img_mov"])
        assert np.allclose(expected_p, ecache["p"])

        if return_inverse:
            assert np.allclose(expected_deltas_inv, ecache["deltas_true_inv"])
        else:
            # Inverse deltas are not streamed when return_inverse is False.
            assert ecache["deltas_true_inv"] is None