def test_merging_refined_maps(self):
        """Merge two single-phase maps carrying scores (one per point) and
        simulation indices (ten per point).

        Without ``simulation_indices_prop`` the simulation indices are not
        carried into the merged map. Requesting them raises a
        ``ValueError`` for these maps; presumably this is because there are
        more simulation indices per point than scores (confirm against
        ``merge_crystal_maps``).
        """
        ny, nx = (3, 3)
        nav_size = ny * nx
        r = Rotation.from_euler(np.ones((nav_size, 3)))
        # Row-major map coordinates: x runs along a row (nx values) and y
        # advances once per row (ny rows).
        x = np.tile(np.arange(nx), ny)
        y = np.repeat(np.arange(ny), nx)

        # Simulation indices: several per point, unlike the scores below
        n_sim_indices = 10
        sim_shape = (nav_size, n_sim_indices)
        sim_indices1 = np.random.randint(low=0, high=1000, size=sim_shape)
        sim_indices2 = np.random.randint(low=0, high=1000, size=sim_shape)

        # Scores: one per point. The first point has a higher score in map 1
        # than in map 2; every other point is higher in map 2.
        scores1 = np.ones(nav_size)
        scores1[0] = 3
        scores2 = 2 * np.ones(nav_size)

        xmap1 = CrystalMap(
            rotations=r,
            phase_id=np.zeros(nav_size),
            phase_list=PhaseList(Phase(name="a")),
            x=x,
            y=y,
            prop={
                "simulation_indices": sim_indices1,
                "scores": scores1
            },
        )
        xmap2 = CrystalMap(
            rotations=r,
            phase_id=np.ones(nav_size),
            phase_list=PhaseList(Phase(name="b")),
            x=x,
            y=y,
            prop={
                "simulation_indices": sim_indices2,
                "scores": scores2
            },
        )
        xmap_merged = merge_crystal_maps(crystal_maps=[xmap1, xmap2])

        # Simulation indices were not requested, so neither the original
        # nor the "merged_" property should appear in the merged map
        assert "simulation_indices" not in xmap_merged.prop.keys()
        assert "merged_simulation_indices" not in xmap_merged.prop.keys()

        with pytest.raises(ValueError, match="Cannot merge maps with more"):
            _ = merge_crystal_maps(
                crystal_maps=[xmap1, xmap2],
                simulation_indices_prop="simulation_indices",
            )
    def test_warning_merge_maps_with_same_phase(self, get_single_phase_xmap,
                                                phase_names,
                                                desired_phase_names):
        """Merging maps whose phases share a name emits a ``UserWarning``
        about duplicates, and the merged phase list contains every name in
        ``desired_phase_names``.

        ``phase_names`` and ``desired_phase_names`` are presumably supplied
        by a parametrize decorator outside this view.
        """
        n_phases = len(phase_names)
        scores_prop = "scores"
        sim_idx_prop = "simulated_indices"
        map_shape = (5, 6)
        rot_per_point = 5

        xmaps = []
        xmap_args = (map_shape, rot_per_point, [scores_prop, sim_idx_prop])
        phase_ids = np.arange(n_phases)
        for i in range(n_phases):
            xmap = get_single_phase_xmap(*xmap_args, phase_names[i],
                                         phase_ids[i])
            # Raise the scores at point (i, i) so every map has at least one
            # point where it holds the best score
            xmap[i, i].scores += i + 1
            xmaps.append(xmap)

        with pytest.warns(
                UserWarning,
                match=f"There are duplicates of phase {phase_names[0]}"):
            merged_xmap = merge_crystal_maps(
                crystal_maps=xmaps,
                scores_prop=scores_prop,
                simulation_indices_prop=sim_idx_prop,
            )

        merged_names = merged_xmap.phases.names
        missing = [name for name in desired_phase_names
                   if name not in merged_names]
        assert not missing, f"Phases missing from merged map: {missing}"
    def test_property_names(self, get_single_phase_xmap, scores_prop,
                            sim_idx_prop):
        """Explicit scores and simulation-index property names show up in
        the merged map, together with their ``merged_``-prefixed
        counterparts sized (points, rotations per point * number of maps).
        """
        nav_shape = (5, 6)
        n_rot = 50

        xmap_a, xmap_b = (
            get_single_phase_xmap(nav_shape, n_rot,
                                  [scores_prop, sim_idx_prop], name, pid)
            for name, pid in (("a", 0), ("b", 1))
        )

        # Give one point in the second map a score of 2
        xmap_b[3, 3].prop[scores_prop] = 2
        merged = merge_crystal_maps(
            crystal_maps=[xmap_a, xmap_b],
            greater_is_better=True,
            scores_prop=scores_prop,
            simulation_indices_prop=sim_idx_prop,
        )

        for prop_name in (scores_prop, sim_idx_prop):
            assert prop_name in merged.prop.keys()

        expected_shape = (np.prod(nav_shape), 2 * n_rot)
        for prop_name in (scores_prop, sim_idx_prop):
            assert merged.prop[f"merged_{prop_name}"].shape == expected_shape
    def test_merging_maps_different_number_of_scores_raises(
            self, get_single_phase_xmap):
        """Maps with a different number of rotations (and hence scores)
        per point, here 3 vs. 4, cannot be merged.
        """
        shape = (2, 3)
        maps = [
            get_single_phase_xmap(shape, n_rot, name=label)
            for n_rot, label in ((3, "a"), (4, "b"))
        ]
        # Change the score at one point of the second map, presumably so
        # that both maps would contribute points to a merge
        maps[1][0, 1].scores = 2.0

        with pytest.raises(ValueError, match="All crystal maps must have the"):
            _ = merge_crystal_maps(maps)
    def test_mean_n_best_varying_scores(self, get_single_phase_xmap):
        """Check which map wins a point as the number of scores per point
        considered for the comparison (``mean_n_best``) changes.
        """
        nav_shape, n_rot = (2, 3), 3
        xmap_a = get_single_phase_xmap(nav_shape, n_rot, name="a")
        xmap_b = get_single_phase_xmap(nav_shape, n_rot, name="b")

        first_point = (0, 0)
        xmap_a[first_point].scores = [1, 2, 2.1]
        xmap_b[first_point].scores = [1, 1.9, 3]
        # Make the second map win point (0, 1) regardless of mean_n_best
        xmap_b[0, 1].scores = 2.0

        # NOTE(review): the expected results below are consistent with
        # averaging the first n stored scores at (0, 0): n=2 gives
        # 1.5 (a) vs. 1.45 (b), n=3 gives 1.7 (a) vs. ~1.97 (b)
        maps = [xmap_a, xmap_b]
        merged_two = merge_crystal_maps(maps, mean_n_best=2)
        merged_three = merge_crystal_maps(maps, mean_n_best=3)

        assert np.allclose(merged_two.phase_id, [0, 1, 0, 0, 0, 0])
        assert np.allclose(merged_three.phase_id, [1, 1, 0, 0, 0, 0])
    def test_lower_is_better(self, get_single_phase_xmap):
        """With ``greater_is_better=False`` the point given a lower score in
        the second map takes that map's phase; all other points keep the
        first map's phase.
        """
        map_shape = (5, 6)
        rot_per_point = 5
        scores_prop = "scores"
        sim_idx_prop = "simulation_indices"

        xmap1 = get_single_phase_xmap(map_shape, rot_per_point,
                                      [scores_prop, sim_idx_prop], "a", 0)
        xmap2 = get_single_phase_xmap(map_shape, rot_per_point,
                                      [scores_prop, sim_idx_prop], "b", 1)

        # Point (0, 3), i.e. flat index 3 in row-major order, gets a score
        # of 0 in map 2 (presumably below every fixture-generated score)
        xmap2[0, 3].prop[scores_prop] = 0
        desired_phase_id = np.zeros(np.prod(map_shape))
        desired_phase_id[3] = 1

        # Pass scores_prop explicitly rather than relying on the default
        merged_xmap = merge_crystal_maps(
            crystal_maps=[xmap1, xmap2],
            greater_is_better=False,
            scores_prop=scores_prop,
            simulation_indices_prop=sim_idx_prop,
        )

        assert np.allclose(merged_xmap.phase_id, desired_phase_id)
    def test_mean_n_best(
        self,
        get_single_phase_xmap,
        nav_shape,
        rot_per_point,
        mean_n_best,
        desired_merged_scores,
        desired_merged_sim_idx,
    ):
        """Ensure that the merged (sorted) scores and simulation index
        properties in the merged map have the expected values.

        The number of maps is derived from the last axis of
        ``desired_merged_scores``, which holds ``rot_per_point`` columns
        per map.
        """
        prop_names = ["scores", "simulation_indices"]
        n_phases = np.shape(desired_merged_scores)[-1] // rot_per_point
        xmaps = []
        for i in range(n_phases):
            xmap = get_single_phase_xmap(nav_shape,
                                         rot_per_point,
                                         name=str(i),
                                         prop_names=prop_names)
            xmap[i].scores += i
            xmaps.append(xmap)

        # The simulation indices should be the same in all maps. Compare
        # every consecutive pair; a plain sum of differences could cancel.
        all_sim_idx = np.dstack([xmap.simulation_indices for xmap in xmaps])
        assert np.all(np.diff(all_sim_idx) == 0)

        merged_xmap = merge_crystal_maps(
            crystal_maps=xmaps,
            mean_n_best=mean_n_best,
            simulation_indices_prop=prop_names[1],
        )

        assert merged_xmap.phases.size == n_phases
        assert np.allclose(merged_xmap.merged_scores, desired_merged_scores)
        assert np.allclose(merged_xmap.merged_simulation_indices,
                           desired_merged_sim_idx)
 def test_merging_maps_different_shapes_raises(self, get_single_phase_xmap):
     xmap1 = get_single_phase_xmap((4, 3))
     xmap2 = get_single_phase_xmap((3, 4))
     with pytest.raises(ValueError, match="All crystal maps must have the"):
         _ = merge_crystal_maps([xmap1, xmap2])
    def test_merge_crystal_maps_1d(self, map_shape, rot_per_point, phase_names,
                                   get_single_phase_xmap):
        """Crystal maps with a 1D navigation shape can be merged
        successfully and yield the expected output.

        Map ``i`` gets its score raised at point ``i``, so it should supply
        that point's phase, score, simulation indices and rotation in the
        merged map.
        """
        n_phases = len(phase_names)
        scores_prop, sim_idx_prop = "scores", "sim_idx"

        map_size = np.prod(map_shape)
        data_shape = (map_size, )
        if rot_per_point > 1:
            data_shape += (rot_per_point, )

        desired_phase_ids = np.zeros(map_size)
        desired_scores = np.ones(data_shape)
        desired_idx = np.arange(np.prod(data_shape)).reshape(data_shape)

        xmaps = []
        xmap_kwargs = dict(
            nav_shape=map_shape,
            rotations_per_point=rot_per_point,
            prop_names=[scores_prop, sim_idx_prop],
            step_sizes=(1, ),
        )
        phase_ids = np.arange(n_phases)
        for i in range(n_phases):
            xmap = get_single_phase_xmap(name=phase_names[i],
                                         phase_id=phase_ids[i],
                                         **xmap_kwargs)
            # All maps have at least one point with the best score
            xmap[i].prop[scores_prop] += i + 1
            xmaps.append(xmap)

            desired_phase_ids[i] = i
            desired_scores[i] = xmap[i].prop[scores_prop]
            desired_idx[i] = xmap[i].prop[sim_idx_prop]

            if i == 0:
                # Copy so later per-point writes can never alias the first
                # map's rotation data
                desired_rot = xmap.rotations.data.copy()
            else:
                desired_rot[i] = xmap[i].rotations.data

        merged_xmap = merge_crystal_maps(
            crystal_maps=xmaps,
            scores_prop=scores_prop,
            simulation_indices_prop=sim_idx_prop,
        )

        assert merged_xmap.shape == xmaps[0].shape
        assert merged_xmap.size == xmaps[0].size
        for v1, v2 in zip(merged_xmap._coordinates.values(),
                          xmaps[0]._coordinates.values()):
            if v1 is None:
                assert v1 is v2
            else:
                assert np.allclose(v1, v2)

        assert np.allclose(merged_xmap.phase_id, desired_phase_ids)
        assert np.allclose(merged_xmap.prop[scores_prop], desired_scores)
        assert np.allclose(merged_xmap.prop[sim_idx_prop], desired_idx)
        assert np.allclose(merged_xmap.rotations.data, desired_rot)

        desired_merged_shapes = (map_size, rot_per_point * n_phases)
        assert merged_xmap.prop[
            f"merged_{scores_prop}"].shape == desired_merged_shapes
        assert merged_xmap.prop[
            f"merged_{sim_idx_prop}"].shape == desired_merged_shapes
    def test_merge_crystal_maps_2d(
        self,
        get_single_phase_xmap,
        map_shape,
        rot_per_point,
        phase_names,
        mean_n_best,
    ):
        """Crystal maps with a 2D navigation shape can be merged
        successfully and yield the expected output.

        Map ``i`` gets its score raised at diagonal point ``(i, i)``, so it
        should supply that point's phase, score, simulation indices and
        rotation in the merged map.
        """
        n_phases = len(phase_names)
        scores_prop, sim_idx_prop = "scores", "sim_idx"

        map_size = np.prod(map_shape)
        data_shape = (map_size, )
        if rot_per_point > 1:
            data_shape += (rot_per_point, )

        desired_phase_ids = np.zeros(map_size)
        desired_scores = np.ones(data_shape)
        desired_idx = np.arange(np.prod(data_shape)).reshape(data_shape)

        xmaps = []
        xmap_args = (map_shape, rot_per_point, [scores_prop, sim_idx_prop])
        phase_ids = np.arange(n_phases)
        _, nx = map_shape
        for i in range(n_phases):
            xmap = get_single_phase_xmap(*xmap_args, phase_names[i],
                                         phase_ids[i])
            # All maps have at least one point with the best score along
            # the map diagonal
            idx = (i, i)
            xmap[idx].prop[scores_prop] += i + 1
            xmaps.append(xmap)

            # Row-major flat index of (i, i): i * nx + i
            j = i * (1 + nx)
            desired_phase_ids[j] = i
            desired_scores[j] = xmap[idx].prop[scores_prop]
            desired_idx[j] = xmap[idx].prop[sim_idx_prop]

            if i == 0:
                # Copy so later per-point writes can never alias the first
                # map's rotation data
                desired_rot = xmap.rotations.data.copy()
            else:
                desired_rot[j] = xmap[idx].rotations.data

        merged_xmap = merge_crystal_maps(
            crystal_maps=xmaps,
            mean_n_best=mean_n_best,
            scores_prop=scores_prop,
            simulation_indices_prop=sim_idx_prop,
        )

        assert merged_xmap.shape == xmaps[0].shape
        assert merged_xmap.size == xmaps[0].size
        for v1, v2 in zip(merged_xmap._coordinates.values(),
                          xmaps[0]._coordinates.values()):
            if v1 is None:
                assert v1 is v2
            else:
                assert np.allclose(v1, v2)

        assert np.allclose(merged_xmap.phase_id, desired_phase_ids)
        assert np.allclose(merged_xmap.prop[scores_prop], desired_scores)
        assert np.allclose(merged_xmap.prop[sim_idx_prop], desired_idx)
        assert np.allclose(merged_xmap.rotations.data, desired_rot)

        desired_merged_shapes = (map_size, rot_per_point * n_phases)
        assert merged_xmap.prop[
            f"merged_{scores_prop}"].shape == desired_merged_shapes
        assert merged_xmap.prop[
            f"merged_{sim_idx_prop}"].shape == desired_merged_shapes