Esempio n. 1
0
 def _setup_symmetry(self):
     """Create ``self._symmetry_handler`` from the ``known_symmetry`` params.

     The known unit cell and ``max_delta`` are forwarded unchanged. The known
     space group, when present, is converted with ``.group()`` before being
     handed over; when absent, ``None`` is forwarded instead.
     """
     known_symmetry = self.params.known_symmetry
     cell = known_symmetry.unit_cell
     sg_info = known_symmetry.space_group
     group = None if sg_info is None else sg_info.group()
     self._symmetry_handler = SymmetryHandler(
         unit_cell=cell,
         space_group=group,
         max_delta=known_symmetry.max_delta,
     )
Esempio n. 2
0
class Indexer(object):
    """Base class for DIALS indexers.

    Holds the reflections and experiments to be indexed together with the
    parameter scope, and implements the macro-cycle loop of index assignment
    and refinement in ``index``. The lattice search itself is left to
    subclasses, which override ``find_lattices``. Use ``from_parameters`` to
    construct the appropriate subclass for a given set of inputs.
    """

    def __init__(self, reflections, experiments, params):
        """Prepare reflections and experiments for indexing.

        Args:
            reflections: reflection table of spots to index. If it has a
                ``flags`` column and some reflections are flagged ``strong``,
                only those are kept; otherwise every reflection is flagged
                ``strong``.
            experiments: experiment list. Models of later experiments that are
                similar to those of the first experiment are replaced by the
                first experiment's models so that they are shared.
            params: full parameter scope. ``params.indexing`` is stored as
                ``self.params`` and the whole scope as ``self.all_params``.

        Raises:
            DialsIndexError: via ``setup_indexing`` if no reflections remain.
        """
        self.reflections = reflections
        self.experiments = experiments

        self.params = params.indexing
        self.all_params = params
        self.refined_experiments = None
        # Optional (h, k, l) offset; index_reflections applies it after the
        # next index assignment and then resets it to None.
        self.hkl_offset = None

        # Choose the Miller index assignment strategy.
        if self.params.index_assignment.method == "local":
            self._assign_indices = assign_indices.AssignIndicesLocal(
                epsilon=self.params.index_assignment.local.epsilon,
                delta=self.params.index_assignment.local.delta,
                l_min=self.params.index_assignment.local.l_min,
                nearest_neighbours=self.params.index_assignment.local.
                nearest_neighbours,
            )
        else:
            self._assign_indices = assign_indices.AssignIndicesGlobal(
                tolerance=self.params.index_assignment.simple.hkl_tolerance)

        # Resolve an "auto" outlier algorithm for refinement. This modifies
        # the shared parameter scope in place.
        if self.all_params.refinement.reflections.outlier.algorithm in (
                "auto",
                libtbx.Auto,
        ):
            if self.experiments[0].goniometer is None:
                self.all_params.refinement.reflections.outlier.algorithm = "sauter_poon"
            else:
                # different default to dials.refine
                # tukey is faster and more appropriate at the indexing step
                self.all_params.refinement.reflections.outlier.algorithm = "tukey"

        # Share models with the first experiment where they are similar: the
        # detector always; beam and (with combine_scans) an identical scan
        # only when the goniometer is also shared.
        for expt in self.experiments[1:]:
            if expt.detector.is_similar_to(self.experiments[0].detector):
                expt.detector = self.experiments[0].detector
            if expt.goniometer is not None and expt.goniometer.is_similar_to(
                    self.experiments[0].goniometer):
                expt.goniometer = self.experiments[0].goniometer
                # can only share a beam if we share a goniometer?
                if expt.beam.is_similar_to(self.experiments[0].beam):
                    expt.beam = self.experiments[0].beam
                if self.params.combine_scans and expt.scan == self.experiments[
                        0].scan:
                    expt.scan = self.experiments[0].scan

        # Keep only strong spots if any are flagged. The second condition
        # relies on short-circuiting: strong_sel only exists when "flags" does.
        if "flags" in self.reflections:
            strong_sel = self.reflections.get_flags(
                self.reflections.flags.strong)
            if strong_sel.count(True) > 0:
                self.reflections = self.reflections.select(strong_sel)
        if "flags" not in self.reflections or strong_sel.count(True) == 0:
            # backwards compatibility for testing
            self.reflections.set_flags(
                flex.size_t_range(len(self.reflections)),
                self.reflections.flags.strong)

        self._setup_symmetry()
        self.d_min = None

        self.setup_indexing()

    @staticmethod
    def from_parameters(reflections,
                        experiments,
                        known_crystal_models=None,
                        params=None):
        """Construct the appropriate ``Indexer`` subclass for the inputs.

        If ``known_crystal_models`` is given, an ``IndexerKnownOrientation``
        is returned. If no space group was specified, the space group of the
        first known model is written into ``params`` first.

        Otherwise the experiments are classified as sequences (imageset is an
        ``ImageSequence``) or stills; mixing both is an error.
        ``indexing.stills.indexer`` may override that choice. The indexer
        class is then chosen by matching ``indexing.method`` against the
        names of the ``dials.index.basis_vector_search`` and
        ``dials.index.lattice_search`` entry points.

        Note that ``params`` is modified in place (the ``max_refine`` default
        is resolved). When the stills indexer is used, each experiment is
        also modified: its imageset is replaced by a plain ``ImageSet``, and
        its scan and goniometer are removed.

        Raises:
            ValueError: if stills and sequences are mixed, if
                ``indexing.stills.indexer`` has an unrecognised value, or if
                ``indexing.method`` matches no registered entry point.
        """

        if known_crystal_models is not None:
            from dials.algorithms.indexing.known_orientation import (
                IndexerKnownOrientation, )

            if params.indexing.known_symmetry.space_group is None:
                params.indexing.known_symmetry.space_group = (
                    known_crystal_models[0].get_space_group().info())
            idxr = IndexerKnownOrientation(reflections, experiments, params,
                                           known_crystal_models)
        else:
            has_stills = False
            has_sequences = False
            for expt in experiments:
                if isinstance(expt.imageset, ImageSequence):
                    has_sequences = True
                else:
                    has_stills = True

            if has_stills and has_sequences:
                raise ValueError(
                    "Please provide only stills or only sequences, not both")

            use_stills_indexer = has_stills

            if not (params.indexing.stills.indexer is libtbx.Auto
                    or params.indexing.stills.indexer.lower() == "auto"):
                if params.indexing.stills.indexer == "stills":
                    use_stills_indexer = True
                elif params.indexing.stills.indexer == "sequences":
                    use_stills_indexer = False
                else:
                    # An explicit exception, unlike `assert False`, is not
                    # stripped when Python runs with -O.
                    raise ValueError(
                        "Unrecognised indexing.stills.indexer value: %s" %
                        params.indexing.stills.indexer)

            if params.indexing.basis_vector_combinations.max_refine is libtbx.Auto:
                if use_stills_indexer:
                    params.indexing.basis_vector_combinations.max_refine = 5
                else:
                    params.indexing.basis_vector_combinations.max_refine = 50

            if use_stills_indexer:
                # Ensure the indexer and downstream applications treat this as set of stills
                from dxtbx.imageset import ImageSet

                for experiment in experiments:
                    experiment.imageset = ImageSet(
                        experiment.imageset.data(),
                        experiment.imageset.indices())
                    experiment.imageset.set_scan(None)
                    experiment.imageset.set_goniometer(None)
                    experiment.scan = None
                    experiment.goniometer = None

            IndexerType = None
            for entry_point in pkg_resources.iter_entry_points(
                    "dials.index.basis_vector_search"):
                if params.indexing.method == entry_point.name:
                    # method is a registered basis vector search
                    if use_stills_indexer:
                        from dials.algorithms.indexing.stills_indexer import (
                            StillsIndexerBasisVectorSearch as IndexerType, )
                    else:
                        from dials.algorithms.indexing.lattice_search import (
                            BasisVectorSearch as IndexerType, )

            if IndexerType is None:
                for entry_point in pkg_resources.iter_entry_points(
                        "dials.index.lattice_search"):
                    if params.indexing.method == entry_point.name:
                        # method is a registered lattice search
                        if use_stills_indexer:
                            from dials.algorithms.indexing.stills_indexer import (
                                StillsIndexerLatticeSearch as IndexerType, )
                        else:
                            from dials.algorithms.indexing.lattice_search import (
                                LatticeSearch as IndexerType, )

            if IndexerType is None:
                raise ValueError("No indexer found for indexing.method=%s" %
                                 params.indexing.method)

            idxr = IndexerType(reflections, experiments, params=params)

        return idxr

    def _setup_symmetry(self):
        """Create ``self._symmetry_handler`` for the target symmetry.

        Uses the known unit cell, the known space group (converted with
        ``.group()``) and ``max_delta``. When no space group is given, a
        default-constructed ``sgtbx.space_group()`` (P 1) is used.
        """
        target_unit_cell = self.params.known_symmetry.unit_cell
        target_space_group = self.params.known_symmetry.space_group
        if target_space_group is not None:
            target_space_group = target_space_group.group()
        else:
            target_space_group = sgtbx.space_group()
        self._symmetry_handler = SymmetryHandler(
            unit_cell=target_unit_cell,
            space_group=target_space_group,
            max_delta=self.params.known_symmetry.max_delta,
        )
        return

    def setup_indexing(self):
        """Derive the per-reflection quantities needed before indexing.

        The steps, in order:

        1. Fill ``imageset_id`` from ``id`` if absent.
        2. Convert centroids to mm and map them to reciprocal space (``rlp``).
        3. Compute entering flags.
        4. Determine ``max_cell``.
        5. Optionally override the phi variance from ``sigma_phi_deg``.
        6. Optionally write debug PDB output.
        7. Mark every reflection unindexed (``id`` = -1).

        Raises:
            DialsIndexError: if there are no reflections.
        """
        if len(self.reflections) == 0:
            raise DialsIndexError("No reflections left to index!")

        if "imageset_id" not in self.reflections:
            self.reflections["imageset_id"] = self.reflections["id"]
        self.reflections.centroid_px_to_mm(self.experiments)
        self.reflections.map_centroids_to_reciprocal_space(self.experiments)
        self.reflections.calculate_entering_flags(self.experiments)

        self.find_max_cell()

        if self.params.sigma_phi_deg is not None:
            # Replace the third (phi) variance component with a constant
            # derived from sigma_phi_deg, converted to radians and squared.
            var_x, var_y, _ = self.reflections["xyzobs.mm.variance"].parts()
            var_phi_rad = flex.double(
                var_x.size(), (math.pi / 180 * self.params.sigma_phi_deg)**2)
            self.reflections["xyzobs.mm.variance"] = flex.vec3_double(
                var_x, var_y, var_phi_rad)

        if self.params.debug:
            self._debug_write_reciprocal_lattice_points_as_pdb()

        self.reflections["id"] = flex.int(len(self.reflections), -1)

    def index(self):
        """Run lattice searches and index/refine macro-cycles.

        Each pass of the outer loop calls ``find_lattices``. A
        ``DialsIndexError`` from the first search propagates; from later
        searches it is logged and ends the search. Each pass then runs
        ``refinement_protocol.n_macro_cycles`` cycles of index assignment
        followed by refinement, optionally extending ``self.d_min`` by
        ``d_min_step`` each cycle.

        Searching for further lattices stops when any of these holds:

        - ``max_lattices`` is reached;
        - too few unindexed reflections remain;
        - no new lattice is found;
        - a new lattice is too similar to an existing one;
        - refinement with an additional lattice fails.

        On return, ``self.refined_experiments`` and
        ``self.refined_reflections`` hold the final models and reflections
        (with ``xyzcal.px`` filled in).

        Raises:
            DialsIndexError: if no lattice could be found.
            DialsIndexRefineError: if refinement fails with a single
                experiment, or no experiment could be refined at all.
        """
        experiments = ExperimentList()
        # Reflections from the most recent successful refinement. This stays
        # None until a refinement succeeds, so the failure handler below must
        # not assume it exists. (It previously raised UnboundLocalError when
        # the first refinement failed with more than one experiment present.)
        refined_reflections = None

        had_refinement_error = False
        have_similar_crystal_models = False

        while True:
            if had_refinement_error or have_similar_crystal_models:
                break
            max_lattices = self.params.multiple_lattice_search.max_lattices
            if max_lattices is not None and len(experiments) >= max_lattices:
                break
            if len(experiments) > 0:
                # Stop if the unindexed reflections at lower resolution than
                # the highest-resolution indexed reflection are fewer than the
                # recycle cutoff fraction.
                cutoff_fraction = (self.params.multiple_lattice_search.
                                   recycle_unindexed_reflections_cutoff)
                d_spacings = 1 / self.reflections["rlp"].norms()
                d_min_indexed = flex.min(
                    d_spacings.select(self.indexed_reflections))
                min_reflections_for_indexing = cutoff_fraction * len(
                    self.reflections.select(d_spacings > d_min_indexed))
                crystal_ids = self.reflections.select(
                    d_spacings > d_min_indexed)["id"]
                if (crystal_ids
                        == -1).count(True) < min_reflections_for_indexing:
                    logger.info(
                        "Finish searching for more lattices: %i unindexed reflections remaining."
                        % ((crystal_ids == -1).count(True)))
                    break

            n_lattices_previous_cycle = len(experiments)

            if self.d_min is None:
                self.d_min = self.params.refinement_protocol.d_min_start

            if len(experiments) == 0:
                new_expts = self.find_lattices()
                generate_experiment_identifiers(new_expts)
                experiments.extend(new_expts)
            else:
                try:
                    new = self.find_lattices()
                    generate_experiment_identifiers(new)
                    experiments.extend(new)
                except DialsIndexError:
                    logger.info("Indexing remaining reflections failed")

            if self.params.refinement_protocol.d_min_step is libtbx.Auto:
                # Spread the resolution extension from d_min to the highest
                # resolution present evenly over the macro-cycles.
                n_cycles = self.params.refinement_protocol.n_macro_cycles
                if self.d_min is None or n_cycles == 1:
                    self.params.refinement_protocol.d_min_step = 0
                else:
                    d_spacings = 1 / self.reflections["rlp"].norms()
                    d_min_all = flex.min(d_spacings)
                    self.params.refinement_protocol.d_min_step = (
                        self.d_min - d_min_all) / (n_cycles - 1)
                    logger.info("Using d_min_step %.1f" %
                                self.params.refinement_protocol.d_min_step)

            if len(experiments) == 0:
                raise DialsIndexError("No suitable lattice could be found.")
            elif len(experiments) == n_lattices_previous_cycle:
                # no more lattices found
                break

            for i_cycle in range(
                    self.params.refinement_protocol.n_macro_cycles):
                if (i_cycle > 0 and self.d_min is not None
                        and self.params.refinement_protocol.d_min_step > 0):
                    d_min = self.d_min - self.params.refinement_protocol.d_min_step
                    d_min = max(d_min, 0)
                    if self.params.refinement_protocol.d_min_final is not None:
                        d_min = max(
                            d_min, self.params.refinement_protocol.d_min_final)
                    if d_min >= 0:
                        self.d_min = d_min
                        logger.info("Increasing resolution to %.2f Angstrom" %
                                    d_min)

                # reset reflection lattice flags
                # the lattice a given reflection belongs to: a value of -1 indicates
                # that a reflection doesn't belong to any lattice so far
                self.reflections["id"] = flex.int(len(self.reflections), -1)

                self.index_reflections(experiments, self.reflections)

                if i_cycle == 0 and self.params.known_symmetry.space_group is not None:
                    self._apply_symmetry_post_indexing(
                        experiments, self.reflections,
                        n_lattices_previous_cycle)

                logger.info("\nIndexed crystal models:")
                self.show_experiments(experiments,
                                      self.reflections,
                                      d_min=self.d_min)

                if self._check_have_similar_crystal_models(experiments):
                    have_similar_crystal_models = True
                    break

                logger.info("")
                logger.info("#" * 80)
                logger.info("Starting refinement (macro-cycle %i)" %
                            (i_cycle + 1))
                logger.info("#" * 80)
                logger.info("")
                self.indexed_reflections = self.reflections["id"] > -1

                # Clear the indexed flag on reflections beyond d_min and on
                # reflections that were not assigned to any lattice.
                sel = flex.bool(len(self.reflections), False)
                lengths = 1 / self.reflections["rlp"].norms()
                if self.d_min is not None:
                    isel = (lengths <= self.d_min).iselection()
                    sel.set_selected(isel, True)
                sel.set_selected(self.reflections["id"] == -1, True)
                self.reflections.unset_flags(sel,
                                             self.reflections.flags.indexed)
                self.unindexed_reflections = self.reflections.select(sel)

                reflections_for_refinement = self.reflections.select(
                    self.indexed_reflections)
                if self.params.refinement_protocol.mode == "repredict_only":
                    # No refinement: just predict positions from the current models
                    refined_experiments, refined_reflections = (
                        experiments,
                        reflections_for_refinement,
                    )
                    from dials.algorithms.refinement.prediction.managed_predictors import (
                        ExperimentsPredictorFactory, )

                    ref_predictor = ExperimentsPredictorFactory.from_experiments(
                        experiments,
                        spherical_relp=self.all_params.refinement.
                        parameterisation.spherical_relp_model,
                    )
                    ref_predictor(refined_reflections)
                else:
                    try:
                        refined_experiments, refined_reflections = self.refine(
                            experiments, reflections_for_refinement)
                    except (DialsRefineConfigError,
                            DialsRefineRuntimeError) as e:
                        if len(experiments) == 1:
                            raise DialsIndexRefineError(str(e))
                        had_refinement_error = True
                        logger.info("Refinement failed:")
                        logger.info(e)
                        del experiments[-1]

                        # remove experiment id from the reflections associated
                        # with this deleted experiment - indexed flag removed
                        # below. Only possible if an earlier refinement
                        # produced reflections; otherwise nothing has been
                        # refined yet and self.refined_experiments stays None.
                        if refined_reflections is not None:
                            last = len(experiments)
                            sel = refined_reflections["id"] == last
                            logger.info("Removing %d reflections with id %d" %
                                        (sel.count(True), last))
                            refined_reflections["id"].set_selected(sel, -1)

                        break

                self._unit_cell_volume_sanity_check(experiments,
                                                    refined_experiments)

                self.refined_reflections = refined_reflections
                self.refined_reflections.unset_flags(
                    self.refined_reflections["id"] < 0,
                    self.refined_reflections.flags.indexed,
                )

                # Copy refined experimental models back onto the input
                # experiments for each imageset, and share the imagesets
                # with the refined experiments.
                for i, expt in enumerate(self.experiments):
                    ref_sel = self.refined_reflections.select(
                        self.refined_reflections["imageset_id"] == i)
                    ref_sel = ref_sel.select(ref_sel["id"] >= 0)
                    for i_expt in set(ref_sel["id"]):
                        refined_expt = refined_experiments[i_expt]
                        expt.detector = refined_expt.detector
                        expt.beam = refined_expt.beam
                        expt.goniometer = refined_expt.goniometer
                        expt.scan = refined_expt.scan
                        refined_expt.imageset = expt.imageset

                if not (self.all_params.refinement.parameterisation.beam.fix
                        == "all" and self.all_params.refinement.
                        parameterisation.detector.fix == "all"):
                    # Experimental geometry may have changed - re-map centroids to
                    # reciprocal space
                    self.reflections.map_centroids_to_reciprocal_space(
                        self.experiments)

                # update for next cycle
                experiments = refined_experiments
                self.refined_experiments = refined_experiments

                logger.info("\nRefined crystal models:")
                self.show_experiments(self.refined_experiments,
                                      self.reflections,
                                      d_min=self.d_min)

                if (i_cycle >= 2 and self.d_min
                        == self.params.refinement_protocol.d_min_final):
                    logger.info(
                        "Target d_min_final reached: finished with refinement")
                    break

        if self.refined_experiments is None:
            raise DialsIndexRefineError(
                "None of the experiments could refine.")

        if len(self.refined_experiments) > 1:
            from dials.algorithms.indexing.compare_orientation_matrices import (
                rotation_matrix_differences, )

            logger.info(
                rotation_matrix_differences(
                    self.refined_experiments.crystals()))

        self._xyzcal_mm_to_px(self.refined_experiments,
                              self.refined_reflections)

    def _unit_cell_volume_sanity_check(self, original_experiments,
                                       refined_experiments):
        """Reject unrealistic unit cell volume changes during refinement.

        Compares the crystals pairwise (by position). Unless
        ``disable_unit_cell_volume_sanity_check`` is set, this raises if the
        relative volume change exceeds 50%.

        Raises:
            DialsIndexError: if any crystal's volume changed by more than 50%.
        """
        # sanity check for unrealistic unit cell volume increase during refinement
        # usually this indicates too many parameters are being refined given the
        # number of observations provided.
        if not self.params.refinement_protocol.disable_unit_cell_volume_sanity_check:
            for orig_expt, refined_expt in zip(original_experiments,
                                               refined_experiments):
                uc1 = orig_expt.crystal.get_unit_cell()
                uc2 = refined_expt.crystal.get_unit_cell()
                # Absolute relative change, so a large decrease is also caught
                # (although the message below speaks of an increase).
                volume_change = abs(uc1.volume() - uc2.volume()) / uc1.volume()
                cutoff = 0.5
                if volume_change > cutoff:
                    msg = ("\n".join((
                        "Unrealistic unit cell volume increase during refinement of %.1f%%.",
                        "Please try refining fewer parameters, either by enforcing symmetry",
                        "constraints (space_group=) and/or disabling experimental geometry",
                        "refinement (detector.fix=all and beam.fix=all). To disable this",
                        "sanity check set disable_unit_cell_volume_sanity_check=True.",
                    )) % (100 * volume_change))
                    raise DialsIndexError(msg)

    def _apply_symmetry_post_indexing(self, experiments, reflections,
                                      n_lattices_previous_cycle):
        """Impose the target space group on newly found crystals.

        Only crystals at positions ``>= n_lattices_previous_cycle`` (those
        added by the latest lattice search) are processed. Each one is
        symmetrised by the symmetry handler, transformed by the returned
        change-of-basis operator, and given the known space group. The
        Miller indices of its reflections are re-indexed with the same
        operator when that operator is not the identity. Crystals and
        ``reflections`` are modified in place.
        """
        # now apply the space group symmetry only after the first indexing
        # need to make sure that the symmetrized orientation is similar to the P1 model
        for cryst in experiments.crystals()[n_lattices_previous_cycle:]:
            new_cryst, cb_op = self._symmetry_handler.apply_symmetry(cryst)
            new_cryst = new_cryst.change_basis(cb_op)
            cryst.update(new_cryst)
            cryst.set_space_group(
                self.params.known_symmetry.space_group.group())
            for i_expt, expt in enumerate(experiments):
                if expt.crystal is not cryst:
                    continue
                if not cb_op.is_identity_op():
                    miller_indices = reflections["miller_index"].select(
                        reflections["id"] == i_expt)
                    miller_indices = cb_op.apply(miller_indices)
                    reflections["miller_index"].set_selected(
                        reflections["id"] == i_expt, miller_indices)

    def _check_have_similar_crystal_models(self, experiments):
        """
        Checks for similar crystal models.

        Checks whether the most recently added crystal model is similar to previously
        found crystal models, and if so, deletes the last crystal model from the
        experiment list.

        Two models are similar when the rotation between them is smaller than
        ``multiple_lattice_search.minimum_angular_separation`` (degrees).

        Returns:
            bool: True if the last experiment was found similar and removed.
        """
        have_similar_crystal_models = False
        cryst_b = experiments.crystals()[-1]
        for i_a, cryst_a in enumerate(experiments.crystals()[:-1]):
            R_ab, axis, angle, cb_op_ab = difference_rotation_matrix_axis_angle(
                cryst_a, cryst_b)
            min_angle = self.params.multiple_lattice_search.minimum_angular_separation
            if abs(angle) < min_angle:  # degrees
                logger.info(
                    "Crystal models too similar, rejecting crystal %i:" %
                    (len(experiments)))
                logger.info(
                    "Rotation matrix to transform crystal %i to crystal %i" %
                    (i_a + 1, len(experiments)))
                logger.info(R_ab)
                logger.info("Rotation of %.3f degrees" % angle +
                            " about axis (%.3f, %.3f, %.3f)" % axis)
                have_similar_crystal_models = True
                del experiments[-1]
                break
        return have_similar_crystal_models

    def _xyzcal_mm_to_px(self, experiments, reflections):
        """Fill ``reflections["xyzcal.px"]`` from ``xyzcal.mm``.

        The column is first initialised for all reflections. For experiment
        ``i``, reflections with ``imageset_id == i`` are then converted:

        - x/y go through the detector panel given by the ``panel`` column;
        - z is converted to an array index through the scan;
        - without a scan (stills), z is copied unchanged.

        NOTE(review): this assumes experiment ``i`` corresponds to imageset
        ``i``. Reflections of an imageset with no experiment at that index
        keep the initial value. Confirm for multi-lattice/multi-imageset
        inputs.
        """
        # set xyzcal.px field in reflections
        reflections["xyzcal.px"] = flex.vec3_double(len(reflections))
        for i, expt in enumerate(experiments):
            imgset_sel = reflections["imageset_id"] == i
            refined_reflections = reflections.select(imgset_sel)
            panel_numbers = flex.size_t(refined_reflections["panel"])
            xyzcal_mm = refined_reflections["xyzcal.mm"]
            x_mm, y_mm, z_rad = xyzcal_mm.parts()
            xy_cal_mm = flex.vec2_double(x_mm, y_mm)
            xy_cal_px = flex.vec2_double(len(xy_cal_mm))
            for i_panel in range(len(expt.detector)):
                panel = expt.detector[i_panel]
                sel = panel_numbers == i_panel
                xy_cal_px.set_selected(
                    sel, panel.millimeter_to_pixel(xy_cal_mm.select(sel)))
            x_px, y_px = xy_cal_px.parts()
            if expt.scan is not None:
                z_px = expt.scan.get_array_index_from_angle(z_rad, deg=False)
            else:
                # must be a still image, z centroid not meaningful
                z_px = z_rad
            xyzcal_px = flex.vec3_double(x_px, y_px, z_px)
            reflections["xyzcal.px"].set_selected(imgset_sel, xyzcal_px)

    def show_experiments(self, experiments, reflections, d_min=None):
        """Log each crystal model and a per-imageset indexing-rate table.

        Args:
            experiments: experiments whose crystals are logged. Reflections
                are attributed to experiment ``i`` via ``id == i``.
            reflections: reflection table; rows of the table are grouped by
                ``imageset_id``.
            d_min: if given, only reflections with d-spacing greater than
                ``d_min`` are counted.
        """
        if d_min is not None:
            reciprocal_lattice_points = reflections["rlp"]
            d_spacings = 1 / reciprocal_lattice_points.norms()
            reflections = reflections.select(d_spacings > d_min)
        for i_expt, expt in enumerate(experiments):
            logger.info("model %i (%i reflections):" %
                        (i_expt + 1,
                         (reflections["id"] == i_expt).count(True)))
            logger.info(expt.crystal)

        indexed_flags = reflections.get_flags(reflections.flags.indexed)
        imageset_id = reflections["imageset_id"]
        # Guard against an empty table: there is no maximum imageset id.
        n_imagesets = flex.max(imageset_id) + 1 if len(imageset_id) else 0
        rows = [["Imageset", "# indexed", "# unindexed", "% indexed"]]
        for i in range(n_imagesets):
            imageset_indexed_flags = indexed_flags.select(imageset_id == i)
            indexed_count = imageset_indexed_flags.count(True)
            unindexed_count = imageset_indexed_flags.count(False)
            total = indexed_count + unindexed_count
            # An imageset may have no reflections left after the d_min cut;
            # avoid dividing by zero in that case.
            rows.append([
                str(i),
                str(indexed_count),
                str(unindexed_count),
                "{:.1%}".format(indexed_count / total) if total else "n/a",
            ])
        logger.info(dials.util.tabulate(rows, headers="firstrow"))

    def find_max_cell(self):
        """Resolve ``self.params.max_cell`` when it is ``Auto``.

        With a known unit cell, the value is ``max_cell_estimation.multiplier``
        times the longest edge of the primitive target cell. Otherwise it is
        estimated from the reflections with the module-level
        ``find_max_cell``.
        """
        params = self.params.max_cell_estimation
        if self.params.max_cell is libtbx.Auto:
            if self.params.known_symmetry.unit_cell is not None:
                uc_params = (self._symmetry_handler.target_symmetry_primitive.
                             unit_cell().parameters())
                self.params.max_cell = params.multiplier * max(uc_params[:3])
                logger.info("Using max_cell: %.1f Angstrom" %
                            (self.params.max_cell))
            else:
                # Refers to the module-level function, not this method.
                self.params.max_cell = find_max_cell(
                    self.reflections,
                    max_cell_multiplier=params.multiplier,
                    step_size=params.step_size,
                    nearest_neighbor_percentile=params.
                    nearest_neighbor_percentile,
                    histogram_binning=params.histogram_binning,
                    nn_per_bin=params.nn_per_bin,
                    max_height_fraction=params.max_height_fraction,
                    filter_ice=params.filter_ice,
                    filter_overlaps=params.filter_overlaps,
                    overlaps_border=params.overlaps_border,
                ).max_cell
                logger.info("Found max_cell: %.1f Angstrom" %
                            (self.params.max_cell))

    def index_reflections(self, experiments, reflections):
        """Assign Miller indices to ``reflections`` up to ``self.d_min``.

        Any pending non-zero ``self.hkl_offset`` is applied to the new
        indices and then cleared, so it takes effect only once.
        """
        self._assign_indices(reflections, experiments, d_min=self.d_min)
        if self.hkl_offset is not None and self.hkl_offset != (0, 0, 0):
            reflections["miller_index"] = apply_hkl_offset(
                reflections["miller_index"], self.hkl_offset)
            self.hkl_offset = None

    def refine(self, experiments, reflections):
        """Refine ``experiments`` against ``reflections``.

        ``reflections`` is updated in place:

        - outliers get ``id`` = -1;
        - ``xyzcal.mm`` and ``entering`` are replaced by the refiner's
          predictions;
        - the ``centroid_outlier`` flags are replaced by the predicted ones;
        - ``used_in_refinement`` is set for the reflections the refiner used.

        Returns:
            tuple: (refined experiments, the updated ``reflections``).
        """
        from dials.algorithms.indexing.refinement import refine

        refiner, refined, outliers = refine(self.all_params, reflections,
                                            experiments)
        if outliers is not None:
            reflections["id"].set_selected(outliers, -1)
        predicted = refiner.predict_for_indexed()
        reflections["xyzcal.mm"] = predicted["xyzcal.mm"]
        reflections["entering"] = predicted["entering"]
        reflections.unset_flags(flex.bool(len(reflections), True),
                                reflections.flags.centroid_outlier)
        assert (reflections.get_flags(
            reflections.flags.centroid_outlier).count(True) == 0)
        reflections.set_flags(
            predicted.get_flags(predicted.flags.centroid_outlier),
            reflections.flags.centroid_outlier,
        )
        reflections.set_flags(
            refiner.selection_used_for_refinement(),
            reflections.flags.used_in_refinement,
        )
        return refiner.get_experiments(), reflections

    def _debug_write_reciprocal_lattice_points_as_pdb(
            self, file_name="reciprocal_lattice.pdb"):
        """Write reciprocal lattice points as carbon atoms in a PDB file.

        The points are placed in a P1 1000 A cubic cell. For multi-panel
        detectors, one file ``reciprocal_lattice_<panel>.pdb`` is written per
        panel and ``file_name`` is ignored.
        """
        from cctbx import crystal, xray

        cs = crystal.symmetry(unit_cell=(1000, 1000, 1000, 90, 90, 90),
                              space_group="P1")
        for i_panel in range(len(self.experiments[0].detector)):
            if len(self.experiments[0].detector) > 1:
                file_name = "reciprocal_lattice_%i.pdb" % i_panel
            # as_pdb_file() returns text, so the file must be opened in text
            # mode (binary mode fails on Python 3).
            with open(file_name, "w") as f:
                xs = xray.structure(crystal_symmetry=cs)
                reflections = self.reflections.select(
                    self.reflections["panel"] == i_panel)
                for site in reflections["rlp"]:
                    xs.add_scatterer(xray.scatterer("C", site=site))
                xs.sites_mod_short()
                f.write(xs.as_pdb_file())

    def export_as_json(self,
                       experiments,
                       file_name="indexed_experiments.json",
                       compact=False):
        """Write ``experiments`` to ``file_name`` after checking consistency.

        ``compact`` is accepted for interface compatibility but is not used.
        """
        assert experiments.is_consistent()
        experiments.as_file(file_name)

    def export_reflections(self, reflections, file_name="reflections.pickle"):
        """Write ``reflections`` to ``file_name``."""
        reflections.as_file(file_name)

    def find_lattices(self):
        """Search for new lattices; must be implemented by subclasses.

        ``index`` passes the result to ``generate_experiment_identifiers``
        and appends it to an ``ExperimentList``. A ``DialsIndexError`` from
        this method is tolerated only after the first search.
        """
        raise NotImplementedError()