Example #1
0
 def save_in_batches(
     experiments, reflections, exp_name, refl_name, batch_size=1000
 ):
     """Save experiments and their reflections as a series of numbered files.

     The experiment indices are split with ``splitit`` into
     ``len(experiments) // batch_size + 1`` groups (presumably of roughly
     equal size -- confirm against ``splitit``). For group ``i`` the
     experiments and matching reflections are written via
     ``self._save_output`` to ``<exp stem>_<i:03d>.expt`` and
     ``<refl stem>_<i:03d>.refl``.

     When the reflection table carries experiment identifiers, the batch is
     selected by experiment and ``reset_ids()`` is called on it (presumably
     renumbering ``id`` to the batch's experiment order -- confirm).
     Otherwise the ``id`` column is rewritten explicitly to each
     experiment's position within the batch.

     NOTE(review): ``self`` is not a parameter; it is presumably captured
     from an enclosing method -- confirm.
     """
     # Imported locally, as at the other call sites in this file, so this
     # function does not depend on a module-level import of ``splitit``.
     from dxtbx.command_line.image_average import splitit

     n_batches = (len(experiments) // batch_size) + 1
     # None of these change from one batch to the next.
     exp_stem = os.path.splitext(exp_name)[0]
     refl_stem = os.path.splitext(refl_name)[0]
     has_identifiers = bool(reflections.experiment_identifiers().keys())

     for i, indices in enumerate(
         splitit(list(range(len(experiments))), n_batches)
     ):
         batch_expts = ExperimentList()
         for sub_idx in indices:
             batch_expts.append(experiments[sub_idx])

         if has_identifiers:
             batch_refls = reflections.select(batch_expts)
             batch_refls.reset_ids()
         else:
             # No identifier map: renumber ids by hand so that they index
             # into batch_expts.
             batch_refls = flex.reflection_table()
             for sub_id, sub_idx in enumerate(indices):
                 sub_refls = reflections.select(reflections["id"] == sub_idx)
                 sub_refls["id"] = flex.int(len(sub_refls), sub_id)
                 batch_refls.extend(sub_refls)

         self._save_output(
             batch_expts,
             batch_refls,
             exp_stem + "_%03d.expt" % i,
             refl_stem + "_%03d.refl" % i,
         )
        def save_in_batches(
            experiments, reflections, exp_name, refl_name, batch_size=1000
        ):
            """Save experiments and their reflections as numbered batch files.

            The experiment indices are split with ``splitit`` into
            ``len(experiments) // batch_size + 1`` groups (presumably of
            roughly equal size -- confirm against ``splitit``). For group ``i``
            the experiments and matching reflections are written via
            ``self._save_output`` to ``<exp stem>_<i:03d>.expt`` and
            ``<refl stem>_<i:03d>.refl``.

            Reflection tables that carry experiment identifiers are selected
            by experiment and passed through ``reset_ids()``. This keeps the
            ``id`` column and the identifier map in agreement; rewriting only
            the ``id`` column would leave the map pointing at the old ids.
            Tables without identifiers get their ``id`` column rewritten to
            each experiment's position within the batch.

            NOTE(review): ``self`` is not a parameter; it is presumably
            captured from an enclosing method -- confirm.
            """
            from dxtbx.command_line.image_average import splitit

            n_batches = (len(experiments) // batch_size) + 1
            # None of these change from one batch to the next.
            exp_stem = os.path.splitext(exp_name)[0]
            refl_stem = os.path.splitext(refl_name)[0]
            has_identifiers = bool(reflections.experiment_identifiers().keys())

            for i, indices in enumerate(
                splitit(list(range(len(experiments))), n_batches)
            ):
                batch_expts = ExperimentList()
                for sub_idx in indices:
                    batch_expts.append(experiments[sub_idx])

                if has_identifiers:
                    batch_refls = reflections.select(batch_expts)
                    batch_refls.reset_ids()
                else:
                    batch_refls = flex.reflection_table()
                    for sub_id, sub_idx in enumerate(indices):
                        sub_refls = reflections.select(reflections["id"] == sub_idx)
                        sub_refls["id"] = flex.int(len(sub_refls), sub_id)
                        batch_refls.extend(sub_refls)

                self._save_output(
                    batch_expts,
                    batch_refls,
                    exp_stem + "_%03d.expt" % i,
                    refl_stem + "_%03d.refl" % i,
                )
Example #3
0
    def run(self):
        """Execute the script.

        Parses the command line, validates any configured mask files, sets up
        logging, imports the inputs and hands them to ``Processor`` instances.
        Work runs serially, split across MPI ranks (``mp.method == 'mpi'``),
        or split across ``mp.nproc`` local processes via ``easy_mp``.

        With ``dispatch.pre_import`` (or a single input path) every input is
        imported up front and split into one datablock per image. Otherwise
        each file is imported inside the worker and must contain exactly one
        single-image imageset.

        Raises:
            Sorry: if a mask file does not exist, or if an absorption
                correction is requested without
                ``integration.debug.output=True`` and
                ``integration.debug.separate_files=False``.
            Abort: raised from a worker (non-pre-import mode) when a file
                contains more than one imageset or a multi-image imageset.
        """
        from collections import Counter
        from dials.util import log
        from time import time
        from libtbx import easy_mp
        import copy

        # Parse the command line
        params, options, all_paths = self.parser.parse_args(
            show_diff_phil=False, return_unhandled=True, quick_parse=True)

        # Check we have some filenames
        if not all_paths:
            self.parser.print_help()
            return

        # Mask validation
        for mask_path in params.spotfinder.lookup.mask, params.integration.lookup.mask:
            if mask_path is not None and not os.path.isfile(mask_path):
                raise Sorry("Mask %s not found" % mask_path)

        # Save the options
        self.options = options
        self.params = params

        st = time()

        # Configure logging
        log.config(params.verbosity,
                   info='dials.process.log',
                   debug='dials.process.debug.log')

        # Log the diff phil. Test by value: ``is not ''`` compares identity,
        # depends on string interning and is a SyntaxWarning on Python 3.8+.
        diff_phil = self.parser.diff_phil.as_str()
        if diff_phil:
            logger.info('The following parameters have been modified:\n')
            logger.info(diff_phil)

        for abs_params in self.params.integration.absorption_correction:
            if abs_params.apply:
                if not (self.params.integration.debug.output
                        and not self.params.integration.debug.separate_files):
                    raise Sorry('Shoeboxes must be saved to integration intermediates to apply an absorption correction. '
                                'Set integration.debug.output=True, integration.debug.separate_files=False and '
                                'integration.debug.delete_shoeboxes=True to temporarily store shoeboxes.')

        self.load_reference_geometry()
        from dials.command_line.dials_import import ManualGeometryUpdater
        update_geometry = ManualGeometryUpdater(params)

        # Import stuff
        logger.info("Loading files...")
        pre_import = params.dispatch.pre_import or len(all_paths) == 1
        if pre_import:
            # Handle still imagesets by breaking them apart into multiple datablocks
            # Further handle single file still imagesets (like HDF5) by tagging each
            # frame using its index

            datablocks = [do_import(path) for path in all_paths]

            indices = []
            basenames = []
            split_datablocks = []
            for datablock in datablocks:
                for imageset in datablock.extract_imagesets():
                    paths = imageset.paths()
                    # range (not xrange) so this also runs on Python 3.
                    for i in range(len(imageset)):
                        subset = imageset[i:i + 1]
                        split_datablocks.append(
                            DataBlockFactory.from_imageset(subset)[0])
                        indices.append(i)
                        basenames.append(
                            os.path.splitext(os.path.basename(paths[i]))[0])

            # Disambiguate repeated basenames with the frame index. Count once
            # instead of calling list.count per image (quadratic).
            basename_counts = Counter(basenames)
            tags = []
            for i, basename in zip(indices, basenames):
                if basename_counts[basename] > 1:
                    tags.append("%s_%05d" % (basename, i))
                else:
                    tags.append(basename)

            # Wrapper function
            def do_work(i, item_list):
                """Process (tag, datablock) pairs with one Processor."""
                processor = Processor(copy.deepcopy(params),
                                      composite_tag="%04d" % i)

                for tag, datablock in item_list:
                    imagesets = datablock.extract_imagesets()
                    try:
                        for imageset in imagesets:
                            update_geometry(imageset)
                    except RuntimeError as e:
                        logger.warning(
                            "Error updating geometry on item %s, %s" %
                            (str(tag), str(e)))
                        continue

                    if self.reference_detector is not None:
                        from dxtbx.model import Detector
                        # Apply to every imageset of this datablock. The
                        # loop variable is not named ``i``, which would
                        # shadow the worker index.
                        for imageset in imagesets:
                            for index in range(len(imageset)):
                                imageset.set_detector(Detector.from_dict(
                                    self.reference_detector.to_dict()),
                                                      index=index)

                    processor.process_datablock(tag, datablock)
                processor.finalize()

            # A list, not a zip iterator: on Python 3 zip is one-shot, and
            # splitit is given a list at every other call site in this file.
            iterable = list(zip(tags, split_datablocks))

        else:
            basenames = [
                os.path.splitext(os.path.basename(filename))[0]
                for filename in all_paths
            ]
            basename_counts = Counter(basenames)
            tags = []
            for i, basename in enumerate(basenames):
                if basename_counts[basename] > 1:
                    tags.append("%s_%05d" % (basename, i))
                else:
                    tags.append(basename)

            # Wrapper function
            def do_work(i, item_list):
                """Import and process (tag, filename) pairs with one Processor."""
                processor = Processor(copy.deepcopy(params),
                                      composite_tag="%04d" % i)
                for tag, filename in item_list:
                    datablock = do_import(filename)
                    imagesets = datablock.extract_imagesets()
                    if len(imagesets) == 0 or len(imagesets[0]) == 0:
                        logger.info("Zero length imageset in file: %s" %
                                    filename)
                        # Skip only this file. Returning here would silently
                        # drop the rest of this worker's files and skip
                        # processor.finalize().
                        continue
                    if len(imagesets) > 1:
                        raise Abort(
                            "Found more than one imageset in file: %s" %
                            filename)
                    if len(imagesets[0]) > 1:
                        raise Abort(
                            "Found a multi-image file. Run again with pre_import=True"
                        )

                    try:
                        update_geometry(imagesets[0])
                    except RuntimeError as e:
                        logger.warning(
                            "Error updating geometry on item %s, %s" %
                            (tag, str(e)))
                        continue

                    if self.reference_detector is not None:
                        from dxtbx.model import Detector
                        imagesets[0].set_detector(
                            Detector.from_dict(
                                self.reference_detector.to_dict()))

                    processor.process_datablock(tag, datablock)
                processor.finalize()

            iterable = list(zip(tags, all_paths))

        # Process the data
        if params.mp.method == 'mpi':
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()  # each process in MPI has a unique id, 0-indexed
            size = comm.Get_size()  # size: number of processes running in this job

            # Item i goes to the single rank with (i + rank) % size == 0.
            subset = [
                item for i, item in enumerate(iterable)
                if (i + rank) % size == 0
            ]
            do_work(rank, subset)
        else:
            from dxtbx.command_line.image_average import splitit
            if params.mp.nproc == 1:
                do_work(0, iterable)
            else:
                result = list(
                    easy_mp.multi_core_run(
                        myfunction=do_work,
                        argstuples=list(
                            enumerate(splitit(iterable, params.mp.nproc))),
                        nproc=params.mp.nproc))
                # Element 2 of each result holds the worker's error text
                # (None on success) -- per easy_mp.multi_core_run; confirm.
                error_list = [r[2] for r in result]
                if error_list.count(None) != len(error_list):
                    print(
                        "Some processes failed excecution. Not all images may have processed. Error messages:"
                    )
                    for error in error_list:
                        if error is None:
                            continue
                        print(error)

        # Total Time
        logger.info("")
        logger.info("Total Time Taken = %f seconds" % (time() - st))
def run(args):
    """Collect strong-spot lab coordinates across MPI ranks and save them.

    Rank 0 globs each command-line pattern for experiment files. It splits
    the matches into one chunk per rank and scatters them. For each
    ``*_strong.expt`` file, the matching ``*_strong.refl`` file is read;
    files whose reflection table cannot be opened (``OSError``) are skipped.

    Experiments with more than 200 strong spots are skipped. For the rest,
    spot centroids are converted to lab-frame x/y. Rank 0 then writes
    x, y, d-spacing and azimuth (degrees, 0-360) to ``cake.npy``: its own
    arrays first, then those received from ranks 1..size-1 in order. Each
    rank contributes four consecutive ``np.save`` records.

    Args:
        args: command-line arguments. If they contain ``-h`` or ``--help``,
            rank 0 prints ``help_str`` and every rank returns.
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()  # each process in MPI has a unique id, 0-indexed
    size = comm.Get_size()  # size: number of processes running in this job

    if "-h" in args or "--help" in args:
        if rank == 0:
            print(help_str)
        return

    if rank == 0:
        from dxtbx.command_line.image_average import splitit
        filenames = []
        # NOTE(review): patterns come from sys.argv rather than ``args``; kept
        # because the call site is not visible -- confirm which is intended.
        for arg in sys.argv[1:]:
            filenames.extend(glob.glob(arg))
        have_data = bool(filenames)
    else:
        filenames = None
        have_data = None

    # Every rank must learn whether there is work before the collective
    # scatter below. Otherwise rank 0 exiting alone leaves the other ranks
    # blocked in scatter forever.
    if not comm.bcast(have_data, root=0):
        if rank == 0:
            sys.exit("No data found")
        return

    if rank == 0:
        filenames = splitit(filenames, size)
    filenames = comm.scatter(filenames, root=0)

    x, y = flex.double(), flex.double()
    det = None
    wavelength = None
    for fn in filenames:
        print(fn)
        try:
            refls = flex.reflection_table.from_file(
                fn.split('_strong.expt')[0] + "_strong.refl")
        except OSError:
            continue
        expts = ExperimentList.from_file(fn, check_format=False)
        for expt_id, expt in enumerate(expts):
            subset = refls.select(refls['id'] == expt_id)
            if len(subset) > 200:
                continue
            det = expt.detector
            # Record the wavelength with the detector actually used. The last
            # loaded file may have been empty or had every experiment skipped,
            # so its expts[0] is not a safe source.
            wavelength = expt.beam.get_wavelength()
            for panel_id, panel in enumerate(det):
                r = subset.select(subset['panel'] == panel_id)
                x_, y_, _ = r['xyzobs.px.value'].parts()
                pix = panel.pixel_to_millimeter(flex.vec2_double(x_, y_))
                c = panel.get_lab_coord(pix)
                x.extend(c.parts()[0])
                y.extend(c.parts()[1])

    if det:
        # Every spot is placed at the mean z of the panel origins.
        z = flex.double(len(x),
                        sum([p.get_origin()[2] for p in det]) / len(det))
        coords = flex.vec3_double(x, y, z)
        # Scattering angle measured from (0, 0, -1) -- assumes the beam runs
        # along -z; confirm for the detector geometry in use.
        two_theta = coords.angle((0, 0, -1))
        d = wavelength / 2 / flex.sin(two_theta / 2)
        # Azimuth from +y in the detector plane, folded to 0-360 for x < 0.
        azi = flex.vec3_double(x, y, flex.double(len(x), 0)).angle((0, 1, 0),
                                                                   deg=True)
        azi.set_selected(x < 0, 180 + (180 - azi.select(x < 0)))
    else:
        d = flex.double()
        azi = flex.double()

    if rank == 0:
        import numpy as np

        def save_arrays(handle, *arrays):
            """Append each flex array to ``handle`` as its own np.save record."""
            for array in arrays:
                np.save(handle, array.as_numpy_array())

        with open('cake.npy', 'wb') as f:
            save_arrays(f, x, y, d, azi)
            for i in range(1, size):
                print('waiting for', i)
                save_arrays(f, *comm.recv(source=i))
    else:
        print('rank', rank, 'sending')
        comm.send((x, y, d, azi), dest=0)