def test_metric_not_compatible_with_data(self):
     """A ONE_TO_MANY scoped metric given a 4D pattern array and a 2D
     template array should make `_pattern_match` raise an OSError.
     """
     self.dummy_metric.scope = MetricScope.ONE_TO_MANY
     experimental = np.zeros((2, 2, 2, 2))
     simulated = np.zeros((2, 2))
     with pytest.raises(OSError):
         _pattern_match(experimental, simulated, metric=self.dummy_metric)
    def test_pattern_match_phase_name(self):
        """Check that matching gives the same simulation indices when
        `phase_name` is omitted, a non-empty string or an empty string.
        """
        exp = nickel_ebsd_small().data
        # Use the experimental patterns themselves as the dictionary,
        # with the navigation dimensions collapsed into one
        sim = exp.reshape((-1,) + exp.shape[-2:])
        expected_indices = [0, 3, 6, 4, 7, 1, 8, 5, 2]

        phase_name_kwargs = [{}, {"phase_name": "a"}, {"phase_name": ""}]
        # Run every match before asserting anything
        results = [
            _pattern_match(exp, sim, n_slices=2, **kwargs)
            for kwargs in phase_name_kwargs
        ]

        for sim_idx, _scores in results:
            assert np.allclose(sim_idx[0], expected_indices)
 def test_pattern_match_compute_true(self, n_slices):
     """Match four patterns against five templates held in a dask
     array, keeping only the best match, and check the pattern that is
     identical to one of the templates.
     """
     # Four patterns in a (2, 2) navigation grid; the one at [1, 0] is
     # identical to template 1
     patterns = np.array(
         [
             [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
             [[[9, 8], [1, 7]], [[5, 2], [2, 7]]],
         ],
         dtype=np.int8,
     )
     # Five templates
     templates = np.array(
         [
             [[5, 3], [2, 7]],
             [[9, 8], [1, 7]],
             [[10, 2], [5, 3]],
             [[8, 4], [6, 12]],
             [[43, 0], [5, 3]],
         ],
         dtype=np.int8,
     )
     lazy_templates = da.from_array(templates)
     match_result = _pattern_match(
         patterns, lazy_templates, n_slices=n_slices, keep_n=1
     )
     template_indices = match_result[0]
     scores = match_result[1]
     # Index of the identical template
     assert template_indices[2] == 1
     # Score of the perfect match
     assert np.allclose(scores[2], 1.0)
    def __call__(
        self,
        signal,
        metric: Union[str, SimilarityMetric] = "ncc",
        keep_n: int = 50,
        n_slices: int = 1,
        return_merged_crystal_map: bool = False,
        get_orientation_similarity_map: bool = False,
    ) -> Union[CrystalMap, List[CrystalMap]]:
        """Index experimental patterns against the simulated
        dictionaries.

        Every experimental pattern is compared with all simulated
        patterns, of known crystal orientations, in each pre-computed
        dictionary :cite:`chen2015dictionary,jackson2019dictionary` in
        order to determine its phase and orientation.

        The normalized cross-correlation
        (:func:`~kikuchipy.indexing.similarity_metrics.ncc`) is the
        default similarity metric, but a valid user-defined similarity
        metric may be passed instead.

        One :class:`~orix.crystal_map.crystal_map.CrystalMap` is
        created per dictionary, each with "scores" and
        "simulation_indices" as properties.

        Parameters
        ----------
        signal : EBSD
            EBSD signal with experimental patterns.
        metric : str or SimilarityMetric, optional
            Similarity metric, by default "ncc" (normalized
            cross-correlation).
        keep_n : int, optional
            Number of best matches to keep, by default 50. It is capped
            at the size of the smallest dictionary.
        n_slices : int, optional
            Number of simulation slices to process sequentially, by
            default 1 (no slicing).
        return_merged_crystal_map : bool, optional
            Whether to also return a merged crystal map, built from the
            best matches according to the similarity scores, after the
            single phase maps. Only has an effect when there is more
            than one dictionary. By default False.
        get_orientation_similarity_map : bool, optional
            Whether to add an orientation similarity map to each
            returned crystal map as the property "osm". By default
            False.

        Returns
        -------
        xmaps : :class:`~orix.crystal_map.crystal_map.CrystalMap` or \
                list of \
                :class:`~orix.crystal_map.crystal_map.CrystalMap`
            A crystal map for each dictionary, followed by the merged
            map if one was requested and created. A single crystal map,
            not a list, is returned when only one map was created.
            None is returned if the user answers anything but "y" to
            the memory prompt.

        Notes
        -----
        Merging of crystal maps and calculation of orientation
        similarity maps can also be done afterwards with
        :func:`~kikuchipy.indexing.merge_crystal_maps` and
        :func:`~kikuchipy.indexing.orientation_similarity_map`,
        respectively.

        See Also
        --------
        ~kikuchipy.indexing.similarity_metrics.make_similarity_metric
        ~kikuchipy.indexing.similarity_metrics.ndp
        """
        # This check needs a rework before being sent to a cluster, and
        # possibly more automatic slicing with dask
        n_simulations = max(
            d.axes_manager.navigation_size for d in self.dictionaries
        )
        good_number = 13500
        if (n_simulations // n_slices) > good_number:
            suggested_n_slices = n_simulations // good_number
            prompt = (
                "You should probably increase n_slices depending on your "
                f"available memory, try above {suggested_n_slices}."
                " Do you want to proceed? [y/n]"
            )
            if input(prompt) != "y":
                return None

        # Use the optimized metric registered under this key if there
        # is one, otherwise keep what the caller passed
        metric = _SIMILARITY_METRICS.get(metric, metric)

        # Spatial coordinates and scan unit for the new crystal maps
        am = signal.axes_manager
        spatial_arrays = _get_spatial_arrays(
            shape=am.navigation_shape, extent=am.navigation_extent,
        )
        nav_dim = am.navigation_dimension
        if nav_dim == 0:
            xmap_kwargs = {}
        else:
            scan_unit = am.navigation_axes[0].units
            if nav_dim == 1:
                xmap_kwargs = {"x": spatial_arrays, "scan_unit": scan_unit}
            else:  # 2D
                xmap_kwargs = {
                    "x": spatial_arrays[0],
                    "y": spatial_arrays[1],
                    "scan_unit": scan_unit,
                }

        # Cannot keep more matches than the smallest dictionary holds
        keep_n = min([keep_n, *(d.xmap.size for d in self.dictionaries)])

        # Naively match against one dictionary at a time; a combined
        # dask compute might perform better and should be tried later
        xmaps = []
        patterns = signal.data
        for dictionary in self.dictionaries:
            dict_xmap = dictionary.xmap
            phases = dict_xmap.phases_in_data
            simulation_indices, scores = _pattern_match(
                patterns,
                dictionary.data,
                metric=metric,
                keep_n=keep_n,
                n_slices=n_slices,
                phase_name=phases.names[0],
            )
            properties = {
                "scores": scores,
                "simulation_indices": simulation_indices,
            }
            xmaps.append(
                CrystalMap(
                    rotations=dict_xmap.rotations[simulation_indices],
                    phase_list=phases,
                    prop=properties,
                    **xmap_kwargs,
                )
            )

        # Merge the single phase maps using the best metric result
        # across all dictionaries
        if return_merged_crystal_map and len(self.dictionaries) > 1:
            xmaps.append(merge_crystal_maps(xmaps, metric=metric))

        # Attach an orientation similarity map to every crystal map
        if get_orientation_similarity_map:
            for xmap in xmaps:
                xmap.prop["osm"] = orientation_similarity_map(
                    xmap, n_best=keep_n
                ).flatten()

        return xmaps[0] if len(xmaps) == 1 else xmaps
 def test_pattern_match_slices_compute_false(self):
     """Requesting slicing (`n_slices=2`) together with
     `compute=False` should raise a NotImplementedError.
     """
     patterns = np.arange(16).reshape(2, 2, 2, 2)
     templates = np.arange(8).reshape(2, 2, 2)
     with pytest.raises(NotImplementedError):
         _pattern_match(patterns, templates, n_slices=2, compute=False)
 def test_pattern_match_compute_false(self):
     """With `compute=False` the match result should hold two dask
     arrays.
     """
     patterns = np.arange(16).reshape(2, 2, 2, 2)
     templates = np.arange(8).reshape(2, 2, 2)
     lazy_result = _pattern_match(patterns, templates, compute=False)
     assert len(lazy_result) == 2
     assert isinstance(lazy_result[0], da.Array)
     assert isinstance(lazy_result[1], da.Array)
 def test_mismatching_signal_shapes(self):
     """A (2, 2) pattern matched against a (3, 3) template should
     raise an OSError.
     """
     self.dummy_metric.scope = MetricScope.MANY_TO_MANY
     pattern = np.zeros((2, 2))
     template = np.zeros((3, 3))
     with pytest.raises(OSError):
         _pattern_match(pattern, template, metric=self.dummy_metric)
 def test_not_recognized_metric(self):
     """The metric string "not_recognized" should raise a ValueError."""
     pattern = np.zeros((2, 2))
     template = np.zeros((2, 2))
     with pytest.raises(ValueError):
         _pattern_match(pattern, template, metric="not_recognized")
 def test_pattern_match_one_to_one(self):
     """Matching a single random (3, 3) pattern against itself should
     give simulation index 0.
     """
     pattern = np.random.random((3, 3))
     match_result = _pattern_match(pattern, pattern)
     simulation_indices = match_result[0]
     assert simulation_indices[0] == 0