Exemplo n.º 1
0
    def __init__(self, reflections, experiments, params):
        """Prepare reflections and experiments for indexing.

        Steps, in order:

        1. Keep the indexing sub-parameters (``params.indexing``) alongside
           the full parameter object.
        2. Choose the index-assignment strategy (local or global).
        3. Pick a default refinement outlier algorithm when it is set to
           auto.
        4. Make later experiments share detector/goniometer/beam/scan models
           with the first experiment when they are similar.
        5. Restrict the reflections to those flagged ``strong``.
        6. Run ``_setup_symmetry`` and ``setup_indexing``.

        Args:
            reflections: Reflection table to index.
            experiments: Experiment list to index.
            params: Full parameter object; ``params.indexing`` holds the
                indexing options. Note that
                ``params.refinement.reflections.outlier.algorithm`` may be
                overwritten in place.
        """
        self.reflections = reflections
        self.experiments = experiments

        self.params = params.indexing
        self.all_params = params
        self.refined_experiments = None
        self.hkl_offset = None

        assignment = self.params.index_assignment
        if assignment.method == "local":
            local_opts = assignment.local
            self._assign_indices = assign_indices.AssignIndicesLocal(
                epsilon=local_opts.epsilon,
                delta=local_opts.delta,
                l_min=local_opts.l_min,
                nearest_neighbours=local_opts.nearest_neighbours,
            )
        else:
            self._assign_indices = assign_indices.AssignIndicesGlobal(
                tolerance=assignment.simple.hkl_tolerance)

        outlier = self.all_params.refinement.reflections.outlier
        if outlier.algorithm in ("auto", libtbx.Auto):
            # Different default to dials.refine: tukey is faster and more
            # appropriate at the indexing step; sauter_poon is used when the
            # first experiment has no goniometer.
            if self.experiments[0].goniometer is None:
                outlier.algorithm = "sauter_poon"
            else:
                outlier.algorithm = "tukey"

        reference = self.experiments[0]
        for expt in self.experiments[1:]:
            if expt.detector.is_similar_to(reference.detector):
                expt.detector = reference.detector
            # Beam and scan are only shared when the goniometer is shared.
            if expt.goniometer is None or not expt.goniometer.is_similar_to(
                    reference.goniometer):
                continue
            expt.goniometer = reference.goniometer
            # can only share a beam if we share a goniometer?
            if expt.beam.is_similar_to(reference.beam):
                expt.beam = reference.beam
            if self.params.combine_scans and expt.scan == reference.scan:
                expt.scan = reference.scan

        # Keep only strong reflections when any are flagged; otherwise flag
        # every reflection as strong.
        has_strong = False
        if "flags" in self.reflections:
            strong_sel = self.reflections.get_flags(
                self.reflections.flags.strong)
            has_strong = strong_sel.count(True) > 0
        if has_strong:
            self.reflections = self.reflections.select(strong_sel)
        else:
            # backwards compatibility for testing
            self.reflections.set_flags(
                flex.size_t_range(len(self.reflections)),
                self.reflections.flags.strong)

        self._setup_symmetry()
        self.d_min = None

        self.setup_indexing()
def _count_index_agreement(expected, assigned, sel):
    """Count misindexed and correctly indexed reflections.

    Compares ``expected`` and ``assigned`` Miller indices element-wise and
    keeps only the rows where ``sel`` is True.

    Returns:
        tuple: ``(n_misindexed, n_correct)``.
    """
    # Compute the comparison once instead of once per count.
    matches = (expected == assigned).select(sel)
    return matches.count(False), matches.count(True)


def _plot_correctly_indexed(shift_x, shift_y, frac_global, frac_local,
                            filename):
    """Scatter-plot the fraction correctly indexed against beam-centre shift.

    Draws one panel for global and one for local index assignment. Points are
    coloured by the fraction (0-1) of scored reflections indexed correctly
    for that shift. The figure is saved to ``filename`` and then closed.
    """
    import matplotlib

    matplotlib.use("Agg")
    from matplotlib import pyplot as plt

    fig, axes = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
    panels = (("global", frac_global), ("local", frac_local))
    for ax, (title, frac) in zip(axes, panels):
        sc = ax.scatter(
            shift_x,
            shift_y,
            vmin=0,
            vmax=1,
            c=frac,
            cmap="viridis",
        )
        ax.set_title(title)
        ax.set_aspect("equal")
        ax.set_xlabel("beam centre shift (mm)")
    axes[0].set_ylabel("beam centre shift (mm)")

    cbar = plt.colorbar(sc, ax=axes, shrink=0.5)
    cbar.set_label("Fraction correctly indexed")
    fig.savefig(filename)
    # Release the figure; pyplot otherwise keeps it alive.
    plt.close(fig)


def run(
    experiments,
    reflections,
    random_seed=42,
    n_trials=100,
    shift_sigma=2,
    output_filename="correctly_indexed.png",
):
    """Compare global and local index assignment under beam-centre errors.

    The beam centre of the first experiment is shifted by ``n_trials`` random
    (fast, slow) offsets drawn from a normal distribution (mean 0, standard
    deviation ``shift_sigma`` mm). For each shift the indexed reflections are
    mapped back to reciprocal space. They are then re-indexed with
    ``AssignIndicesGlobal`` and ``AssignIndicesLocal``, and the new Miller
    indices are compared with the ones already in ``reflections``.

    Per-trial counts and timing reports are printed. A summary plot is written
    to ``output_filename``.

    Args:
        experiments: Experiment list. The beam and detector of the first
            experiment are passed to ``set_slow_fast_beam_centre_mm`` on every
            trial. Nothing restores them afterwards; presumably they are left
            at the last shifted centre.
        reflections: Reflection table. Only reflections flagged ``indexed``
            are used, and their ``miller_index`` column is the expected
            answer. The caller's table itself is not modified.
        random_seed: Seed applied to both ``scitbx.random`` and ``random``.
        n_trials: Number of random beam-centre shifts to test.
        shift_sigma: Standard deviation (mm) of the random shifts.
        output_filename: Path of the image written at the end.
    """
    scitbx.random.set_random_seed(random_seed)
    random.seed(random_seed)

    # Select first, then add "id" to the selection, so the caller's table is
    # not mutated.
    reflections = reflections.select(
        reflections.get_flags(reflections.flags.indexed))
    reflections["id"] = flex.int(len(reflections), 0)

    beam = experiments[0].beam
    detector = experiments[0].detector
    p_id, (x, y) = detector.get_ray_intersection(beam.get_s0())

    g = scitbx.random.variate(
        scitbx.random.normal_distribution(mean=0, sigma=shift_sigma))

    shift_x = g(n_trials)
    shift_y = g(n_trials)

    expected_miller_indices = reflections["miller_index"]
    # Rows whose expected index is (0, 0, 0) are excluded from scoring.
    non_zero_sel = expected_miller_indices != (0, 0, 0)
    n_scored = non_zero_sel.count(True)

    misindexed_global = flex.size_t()
    correct_global = flex.size_t()
    misindexed_local = flex.size_t()
    correct_local = flex.size_t()

    global_timer = time_log("global")
    local_timer = time_log("local")

    for d_x, d_y in zip(shift_x, shift_y):
        # Shifts are relative to the original centre (x, y), so they do not
        # accumulate between trials.
        set_slow_fast_beam_centre_mm(detector, beam, (y + d_y, x + d_x), p_id)

        refl = Indexer.map_centroids_to_reciprocal_space(
            experiments, reflections)

        refl_global = copy.deepcopy(refl)
        refl_global["id"] = flex.int(len(refl), -1)
        global_timer.start()
        assign_indices.AssignIndicesGlobal()(refl_global, experiments)
        global_timer.stop()

        n_wrong, n_right = _count_index_agreement(
            expected_miller_indices, refl_global["miller_index"], non_zero_sel)
        misindexed_global.append(n_wrong)
        correct_global.append(n_right)

        refl_local = copy.deepcopy(refl)
        refl_local["id"] = flex.int(len(refl), -1)
        local_timer.start()
        assign_indices.AssignIndicesLocal()(refl_local, experiments)
        local_timer.stop()

        n_wrong, n_right = _count_index_agreement(
            expected_miller_indices, refl_local["miller_index"], non_zero_sel)
        misindexed_local.append(n_wrong)
        correct_local.append(n_right)

        print("Beam centre shift: (%.2f, %.2f)" % (d_x, d_y))
        print("Misindexed global: %i" % misindexed_global[-1])
        print("Correct global: %i" % correct_global[-1])
        print("Misindexed local: %i" % misindexed_local[-1])
        print("Correct local: %i" % correct_local[-1])
        print()

    print(global_timer.legend)
    print(global_timer.report())
    print(local_timer.report())

    if n_scored == 0:
        # Nothing to normalise by; avoid a division by zero.
        print("No reflections with a non-zero expected Miller index; "
              "skipping plot.")
        return

    # Normalise by the number of scored reflections so the colour really is
    # the fraction correctly indexed, matching the colour-bar label.
    _plot_correctly_indexed(
        shift_x,
        shift_y,
        correct_global.as_double() / n_scored,
        correct_local.as_double() / n_scored,
        output_filename,
    )