Ejemplo n.º 1
0
    def run(self):
        """
        Model every shot listed in ``params.exp_ref_spec_file``, sharing the work across MPI ranks.

        Rank 0 reads the input file, applies ``params.skip`` / ``params.first_n``,
        optionally sanity-tests the lines and loads ``params.best_pickle``; the input
        lines and the best-model table are then broadcast to every rank. Each input
        line holds 2 or 3 whitespace-separated fields: experiment, reflections and,
        optionally, a spectrum file. Line ``i_exp`` is processed by the rank for which
        ``i_exp % COMM.size == COMM.rank``; processing stops at ``params.max_process``.

        For each shot handled by this rank the data are gathered with a
        ``hopper_utils.DataModeler``, optionally dumped to ``params.gathers_dir``,
        a simulator is built, the model is minimized starting from a vector of ones,
        and the result is handed to ``save_up``. When ``params.dump_gathers`` and
        ``params.gathered_output_file`` are both set, rank 0 finally writes one line
        per dumped shot (experiment, absolute gathered-refl path, optional spectrum).

        :raises AssertionError: if the input file does not exist, if a line does not
            have 2 or 3 fields, or if the ``test_gathered_file`` round-trip check fails
        :raises ValueError: if ``dump_gathers`` is set without ``gathers_dir``, or if
            the best pickle does not contain exactly one row for a shot
        """
        MAIN_LOGGER = logging.getLogger("diffBragg.main")
        assert os.path.exists(self.params.exp_ref_spec_file)
        input_lines = None
        best_models = None
        if COMM.rank == 0:
            # Only rank 0 reads from disk; the results are broadcast below.
            # The context manager guarantees the handle is closed once the lines are read.
            with open(self.params.exp_ref_spec_file, "r") as spec_file:
                input_lines = spec_file.readlines()
            if self.params.skip is not None:
                input_lines = input_lines[self.params.skip:]
            if self.params.first_n is not None:
                input_lines = input_lines[:self.params.first_n]
            if self.params.sanity_test_input:
                hopper_utils.sanity_test_input_lines(input_lines)

            if self.params.best_pickle is not None:
                logging.info("reading pickle %s" % self.params.best_pickle)
                best_models = pandas.read_pickle(self.params.best_pickle)

            if self.params.dump_gathers:
                if self.params.gathers_dir is None:
                    raise ValueError("Need to provide a file dir path in order to dump_gathers")
                utils.safe_makedirs(self.params.gathers_dir)
        input_lines = COMM.bcast(input_lines)
        best_models = COMM.bcast(best_models)

        if self.params.ignore_existing:
            # Rank 0 globs the output folder once and shares the set of existing
            # output experiment basenames, so finished shots can be skipped below.
            exp_names_already =None
            if COMM.rank==0:
                exp_names_already = {os.path.basename(f) for f in glob.glob("%s/expers/rank*/*.expt" % self.params.outdir)}
            exp_names_already = COMM.bcast(exp_names_already)

        exp_gatheredRef_spec = []  # optional list of expt, refls, spectra
        for i_exp, line in enumerate(input_lines):
            if i_exp == self.params.max_process:
                break
            # round-robin distribution of shots over ranks
            if i_exp % COMM.size != COMM.rank:
                continue

            logging.info("COMM.rank %d on shot  %d / %d" % (COMM.rank, i_exp + 1, len(input_lines)))
            line_fields = line.strip().split()
            assert len(line_fields) in [2, 3]
            if len(line_fields) == 2:
                exp, ref = line_fields
                spec = None
            else:
                exp, ref, spec = line_fields

            if self.params.ignore_existing:
                # name of the output experiment this shot would produce
                basename = os.path.splitext(os.path.basename(exp))[0]
                opt_exp = "%s_%s_%d.expt" % (self.params.tag, basename, i_exp)
                if opt_exp in exp_names_already:
                    continue

            best = None
            if best_models is not None:
                # look the shot up by exp_name first, then fall back to opt_exp_name
                best = best_models.query("exp_name=='%s'" % exp)
                if len(best) == 0:
                    best = best_models.query("opt_exp_name=='%s'" % exp)
                if len(best) != 1:
                    raise ValueError("Should be 1 entry for exp %s in best pickle %s" % (exp, self.params.best_pickle))
            self.params.simulator.spectrum.filename = spec
            MAIN_LOGGER.info("Modeling %s" % exp)
            Modeler = hopper_utils.DataModeler(self.params)
            if self.params.load_data_from_refls:
                gathered = Modeler.GatherFromReflectionTable(exp, ref)
            else:
                gathered = Modeler.GatherFromExperiment(exp, ref)
            if not gathered:
                logging.warning("No refls in %s; CONTINUE; COMM.rank=%d" % (ref, COMM.rank))
                continue
            if self.params.dump_gathers:
                output_name = os.path.splitext(os.path.basename(exp))[0]
                output_name += "_withData.refl"
                output_name = os.path.join(self.params.gathers_dir, output_name)
                Modeler.dump_gathered_to_refl(output_name, do_xyobs_sanity_check=True)  # NOTE: do this if modeling strong spots only
                if self.params.test_gathered_file:
                    # Round-trip check: re-gather from the file just written and
                    # compare against the in-memory data.
                    all_data = Modeler.all_data.copy()
                    all_roi_id = Modeler.roi_id.copy()
                    all_bg = Modeler.all_background.copy()
                    all_trusted = Modeler.all_trusted.copy()
                    all_pids = np.array(Modeler.pids)
                    all_rois = np.array(Modeler.rois)
                    new_Modeler = hopper_utils.DataModeler(self.params)
                    assert new_Modeler.GatherFromReflectionTable(exp, output_name)
                    assert np.allclose(new_Modeler.all_data, all_data)
                    assert np.allclose(new_Modeler.all_background, all_bg)
                    assert np.allclose(new_Modeler.rois, all_rois)
                    assert np.allclose(new_Modeler.pids, all_pids)
                    assert np.allclose(new_Modeler.all_trusted, all_trusted)
                    assert np.allclose(new_Modeler.roi_id, all_roi_id)

                exp_gatheredRef_spec.append((exp, os.path.abspath(output_name), spec))
                if self.params.only_dump_gathers:
                    continue

            if self.params.refiner.reference_geom is not None:
                # override the shot's detector with the reference geometry
                detector = ExperimentListFactory.from_json_file(self.params.refiner.reference_geom, check_format=False)[0].detector
                Modeler.E.detector = detector
            Modeler.SimulatorFromExperiment(best)
            Modeler.SIM.D.store_ave_wavelength_image = True
            if self.params.refiner.verbose is not None and COMM.rank==0:
                Modeler.SIM.D.verbose = self.params.refiner.verbose
            if self.params.profile:
                Modeler.SIM.record_timings = True
            if self.params.use_float32:
                Modeler.all_data = Modeler.all_data.astype(np.float32)
                Modeler.all_background = Modeler.all_background.astype(np.float32)

            # pick the device index for this rank: random, or rank modulo device count
            if self.params.refiner.randomize_devices:
                dev = np.random.choice(self.params.refiner.num_devices)
                logging.info("Rank %d will use randomly chosen device %d on host %s" % (COMM.rank, dev, socket.gethostname()))
            else:
                dev = COMM.rank % self.params.refiner.num_devices
                logging.info("Rank %d will use device %d on host %s" % (COMM.rank, dev, socket.gethostname()))

            Modeler.SIM.D.device_Id = dev

            # every parameter starts at 1 in the minimizer's variable space
            nparam = len(Modeler.SIM.P)
            x0 = [1] * nparam
            x = Modeler.Minimize(x0)
            if self.params.profile:
                Modeler.SIM.D.show_timings(COMM.rank) #, out)
            save_up(Modeler, x, exp, i_exp, ref)

        if self.params.dump_gathers and self.params.gathered_output_file is not None:
            # NOTE(review): presumably COMM is an mpi4py communicator, whose default
            # reduce op combines the per-rank lists by concatenation on rank 0 -- confirm.
            exp_gatheredRef_spec = COMM.reduce(exp_gatheredRef_spec)
            if COMM.rank == 0:
                # context manager closes (and flushes) the file even if a write fails
                with open(self.params.gathered_output_file, "w") as o:
                    for e, r, s in exp_gatheredRef_spec:
                        if s is not None:
                            o.write("%s %s %s\n" % (e,r,s))
                        else:
                            o.write("%s %s\n" % (e,r))
Ejemplo n.º 2
0
def get_errors(phil_file,
               expt_name,
               refl_name,
               pkl_name,
               outfile_prefix=None,
               verbose=False,
               devid=0):
    """
    Estimate per-reflection intensities and their variances from a refined model.

    The model and its gradient are evaluated once with every parameter at 1, a
    diagonal Hessian is formed over the trusted pixels, and its reciprocal is
    used as the variance of each per-ROI scale factor. Accepted reflections get
    ``intensity.sum.value`` / ``intensity.sum.variance`` assigned, and
    ``xyzobs.px.value`` is overwritten with ``xyzcal.px``.

    :param phil_file: phil source passed to utils.get_extracted_params_from_phil_sources
    :param expt_name: experiment file to gather data from (copied to
        ``outfile_prefix + ".expt"`` when a prefix is given)
    :param refl_name: reflection file to gather data from
    :param pkl_name: pandas pickle used to build the simulator and choose the spectrum
    :param outfile_prefix: if not None, selected reflections are written to
        ``outfile_prefix + ".refl"``
    :param verbose: print each accepted reflection and a final "Done."
    :param devid: device index assigned to the diffBragg instance
    :return: None (also returns early, doing nothing, when no data could be gathered)
    """
    params = utils.get_extracted_params_from_phil_sources(phil_file)
    Mod = hopper_utils.DataModeler(params)
    gathered_ok = Mod.GatherFromExperiment(expt_name, refl_name)
    if not gathered_ok:
        return

    best_df = pandas.read_pickle(pkl_name)
    Mod.SimulatorFromExperiment(best_df)

    # Pick the spectrum: from the imageset, from a spectrum file named in the
    # dataframe, or else a single (wavelength, total flux) channel.
    if params.spectrum_from_imageset:
        imageset_expt = load_expt_from_df(best_df)
        spectrum = hopper_utils.downsamp_spec_from_params(params, imageset_expt)
    elif best_df.spectrum_filename.values[0] is not None:
        spectrum = utils.load_spectra_from_dataframe(best_df)
    else:
        mono_expt = load_expt_from_df(best_df)
        spectrum = [(mono_expt.beam.get_wavelength(), best_df.total_flux.values[0])]
    Mod.SIM.beam.spectrum = spectrum
    Mod.SIM.D.xray_beams = Mod.SIM.beam.xray_beams
    Mod.SIM.D.device_Id = devid
    target = hopper_utils.TargetFunc(Mod.SIM)

    # Refinement flags: one entry per parameter, switched off for fixed ones.
    n_par = len(Mod.SIM.P)
    x_init = np.ones(n_par)
    refine_mask = np.ones(n_par, bool)
    for par in Mod.SIM.P.values():
        if not par.refine:
            refine_mask[par.xpos] = False

    target.vary = refine_mask  # fixed flags
    target.x0 = np.array(x_init, np.float64)

    # Tell diffBragg which derivative managers to enable.
    if Mod.SIM.P["RotXYZ0_xtal0"].refine:
        for rot_id in (hopper_utils.ROTX_ID, hopper_utils.ROTY_ID, hopper_utils.ROTZ_ID):
            Mod.SIM.D.refine(rot_id)
    if Mod.SIM.P["Nabc0"].refine:
        Mod.SIM.D.refine(hopper_utils.NCELLS_ID)
    if Mod.SIM.P["ucell0"].refine:
        n_ucell_var = len(Mod.SIM.ucell_man.variables)
        for ucell_offset in range(n_ucell_var):
            Mod.SIM.D.refine(hopper_utils.UCELL_ID_OFFSET + ucell_offset)
    if Mod.SIM.P["eta_abc0"].refine:
        Mod.SIM.D.refine(hopper_utils.ETA_ID)
    if Mod.SIM.P["detz_shift"].refine:
        Mod.SIM.D.refine(hopper_utils.DETZ_ID)
    if Mod.SIM.D.use_diffuse:
        Mod.SIM.D.refine(hopper_utils.DIFFUSE_ID)

    bragg, jacobian = hopper_utils.model(x_init,
                                         Mod.SIM,
                                         Mod.pan_fast_slow,
                                         compute_grad=True,
                                         dont_rescale_gradient=True)
    model_pix = bragg + Mod.all_background

    resid = Mod.all_data - model_pix  # residuals, named "u" in notes

    sigma_rdout = params.refiner.sigma_r / params.refiner.adu_per_photon
    pix_var = model_pix + sigma_rdout**2
    inv_var = 1 / pix_var
    # Shared sub-expressions, grouped exactly as in the expanded formula so the
    # floating-point results are unchanged.
    two_resid = 2 * resid
    resid_sq_by_var = resid * resid * inv_var
    G = 1 - two_resid - resid_sq_by_var
    coef = inv_var * (inv_var * G - 2 - two_resid * inv_var -
                      resid_sq_by_var * inv_var)

    trusted_coef = coef[Mod.all_trusted]
    trusted_jac = jacobian[:, Mod.all_trusted]
    # if we are only optimizing Fhkl, then the Hess is diagonal matrix
    hess_diag = -.5 * np.sum(trusted_coef * (trusted_jac)**2, axis=1)
    with np.errstate(divide='ignore', invalid='ignore'):
        variance_s = 1 / hess_diag
    # NOTE: if a per-shot scale were optimized along with the Fhkl scales, the
    # Hess would be an arrow matrix (diagonal with elements in the first
    # row/col); that case is not handled here.

    F = Mod.SIM.crystal.miller_array
    amp_by_hkl = dict(zip(F.indices(), F.data()))
    flex_varI = flex.double(len(Mod.refls), 0)
    flex_I = flex.double(len(Mod.refls), 0)
    sel = flex.bool(len(Mod.refls), False)

    Mod.set_slices("all_refls_idx")
    for refl_idx in map(int, Mod.all_refls_idx_unique):
        slices = Mod.all_refls_idx_slices[refl_idx]
        assert len(slices) == 1
        pix_slice = slices[0]
        roi_id = int(Mod.roi_id[pix_slice][0])
        scale_par = Mod.SIM.P["scale_roi%d" % roi_id]
        # TODO : double check scale evaluated from x=1
        scale = scale_par.get_val(1)
        var_s = variance_s[scale_par.xpos]
        hkl = Mod.hi_asu_perpix[pix_slice][0]
        if hkl not in amp_by_hkl:
            continue
        I_hkl = amp_by_hkl[hkl]**2
        var_I = I_hkl**2 * var_s
        # skip variances outside the accepted range (inf from a zero Hessian
        # entry lands here too)
        if var_I <= 1e-6 or var_I > 1e10:
            continue
        I = scale * I_hkl
        h, k, l = hkl
        if verbose:
            print("hkl=%d,%d,%d . I=%f +- %f" % (h, k, l, I, var_I))

        sel[refl_idx] = True
        flex_I[refl_idx] = I
        flex_varI[refl_idx] = var_I

        # the scale stored on the reflection must agree with the parameter
        assert Mod.refls[refl_idx]["scale_factor"] == scale

    Mod.refls["intensity.sum.value"] = flex_I
    Mod.refls["intensity.sum.variance"] = flex_varI
    Mod.refls["xyzobs.px.value"] = Mod.refls["xyzcal.px"]
    integ_refls = Mod.refls.select(sel)

    hopper_utils.free_SIM_mem(Mod.SIM)
    if outfile_prefix is not None:
        integ_refls.as_file(outfile_prefix + ".refl")
        copyfile(expt_name, outfile_prefix + ".expt")
    if verbose:
        print("Done.")
Ejemplo n.º 3
0
    def load_inputs(self,
                    pandas_table,
                    miller_data=None,
                    refls_key='predictions'):
        """
        Load every shot listed in ``pandas_table`` and build a per-shot data modeler.

        Shots are distributed round-robin: row ``i_exp`` is loaded by the rank for
        which ``i_exp % COMM.size == COMM.rank``. Each loaded modeler is stored in
        ``self.Modelers[i_exp]``. After loading, the refined panel groups, the
        global HKL information and the maximum number of pixels to allocate are
        combined across ranks.

        :param pandas_table: dataframe with one row per shot. Columns read here:
            ``exp_name`` (must be unique), the column named by ``refls_key``,
            ``a, b, c, al, be, ga``, ``total_flux``, and optionally
            ``spectrum_filename`` and ``detz_shift_mm``
        :param miller_data: optional miller array forwarded to ``self._init_simulator``
        :param refls_key: name of the column holding the reflection table path of each shot
        :raises RuntimeError: if the first experiment has no detector and no
            reference geometry was provided
        :raises ValueError: if there are more MPI ranks than shots, or if crystals
            differ in space group symbol while ``force_symbol`` is unset
        :raises IOError: if no data could be gathered for a shot
        """
        COMM.Barrier()
        num_exp = len(pandas_table)
        first_exper_file = pandas_table.exp_name.values[0]
        detector = ExperimentListFactory.from_json_file(
            first_exper_file, check_format=False)[0].detector
        if detector is None and self.params.refiner.reference_geom is None:
            raise RuntimeError(
                "No detector in experiment, must provide a reference geom.")
        # TODO verify all shots have the same detector ?
        if self.params.refiner.reference_geom is not None:
            detector = ExperimentListFactory.from_json_file(
                self.params.refiner.reference_geom,
                check_format=False)[0].detector
            print("Using reference geom from expt %s" %
                  self.params.refiner.reference_geom)

        if COMM.size > num_exp:
            raise ValueError(
                "Requested %d MPI ranks to process %d shots. Reduce number of ranks to %d"
                % (COMM.size, num_exp, num_exp))
        self._init_panel_group_information(detector)

        self.verbose = False
        if COMM.rank == 0:
            self.verbose = self.params.refiner.verbose > 0
            # only rank 0 creates the gather folder; the barrier below makes the
            # other ranks wait for it
            if self.params.refiner.gather_dir is not None and not os.path.exists(
                    self.params.refiner.gather_dir):
                os.makedirs(self.params.refiner.gather_dir)
                LOGGER.info("MADE GATHER DIR %s" %
                            self.params.refiner.gather_dir)
        COMM.barrier()
        shot_idx = 0  # each rank keeps index of the shots local to it
        rank_panel_groups_refined = set()
        exper_names = pandas_table.exp_name
        assert len(exper_names) == len(set(exper_names))
        # TODO assert all exper are single-file, probably way before this point
        LOGGER.info("EVENT: begin loading inputs")
        for i_exp, exper_name in enumerate(exper_names):
            # round-robin distribution of shots over ranks
            if i_exp % COMM.size != COMM.rank:
                continue
            LOGGER.info("EVENT: BEGIN loading experiment list")
            expt_list = ExperimentListFactory.from_json_file(
                exper_name,
                check_format=not self.params.refiner.load_data_from_refl)
            LOGGER.info("EVENT: DONE loading experiment list")
            if len(expt_list) != 1:
                # only a warning: the first experiment is still used below
                print("Input experiments need to have length 1, %s does not" %
                      exper_name)
            expt = expt_list[0]
            expt.detector = detector  # in case of supplied ref geom
            self._check_experiment_integrity(expt)

            exper_dataframe = pandas_table.query("exp_name=='%s'" % exper_name)

            refl_name = exper_dataframe[refls_key].values[0]
            refls = flex.reflection_table.from_file(refl_name)
            # FIXME need to remove (0,0,0) bboxes
            good_sel = flex.bool(
                [h != (0, 0, 0) for h in list(refls["miller_index"])])
            refls = refls.select(good_sel)

            #UcellMan = utils.manager_from_crystal(expt.crystal)
            opt_uc_param = exper_dataframe[["a", "b", "c", "al", "be",
                                            "ga"]].values[0]
            UcellMan = utils.manager_from_params(opt_uc_param)

            # The first shot seen by this rank fixes the space group symbol;
            # later shots must match unless a symbol is forced.
            if self.symbol is None:
                if self.params.refiner.force_symbol is not None:
                    self.symbol = self.params.refiner.force_symbol
                else:
                    self.symbol = expt.crystal.get_space_group().type(
                    ).lookup_symbol()
            else:
                if self.params.refiner.force_symbol is None:
                    if expt.crystal.get_space_group().type().lookup_symbol(
                    ) != self.symbol:
                        raise ValueError(
                            "Crystals should all have the same space group symmetry"
                        )

            if shot_idx == 0:  # each rank initializes a simulator only once
                if self.params.simulator.init_scale != 1:
                    print(
                        "WARNING: For stage_two , it is assumed that total scale is stored in the pandas dataframe"
                    )
                    print(
                        "WARNING: resetting params.simulator.init_scale to 1!")
                    self.params.simulator.init_scale = 1
                self._init_simulator(expt, miller_data)
                if self.params.profile:
                    self.SIM.record_timings = True
                if self.params.refiner.stage_two.Fref_mtzname is not None:
                    self.Fref = utils.open_mtz(
                        self.params.refiner.stage_two.Fref_mtzname,
                        self.params.refiner.stage_two.Fref_mtzcol)

            LOGGER.info("EVENT: LOADING ROI DATA")
            shot_modeler = hopper_utils.DataModeler(self.params)
            if self.params.refiner.load_data_from_refl:
                gathered = shot_modeler.GatherFromReflectionTable(
                    expt, refls, sg_symbol=self.symbol)
            else:
                gathered = shot_modeler.GatherFromExperiment(
                    expt, refls, sg_symbol=self.symbol)
            if not gathered:
                # Raise a real exception with a formatted message; raising a
                # bare tuple is itself a TypeError and hides the message.
                raise IOError("Failed to gather data from experiment %s" %
                              exper_name)

            if self.params.refiner.gather_dir is not None:
                gathered_name = os.path.splitext(
                    os.path.basename(exper_name))[0]
                gathered_name += "_withData.refl"
                gathered_name = os.path.join(self.params.refiner.gather_dir,
                                             gathered_name)
                shot_modeler.dump_gathered_to_refl(
                    gathered_name, do_xyobs_sanity_check=False)  #True)
                LOGGER.info("SAVED ROI DATA TO %s" % gathered_name)
                if self.params.refiner.test_gathered_file:
                    # Round-trip check: re-gather from the file just written
                    # and compare against the in-memory data.
                    all_data = shot_modeler.all_data.copy()
                    all_roi_id = shot_modeler.roi_id.copy()
                    all_bg = shot_modeler.all_background.copy()
                    all_trusted = shot_modeler.all_trusted.copy()
                    all_pids = np.array(shot_modeler.pids)
                    all_rois = np.array(shot_modeler.rois)
                    new_Modeler = hopper_utils.DataModeler(self.params)
                    assert new_Modeler.GatherFromReflectionTable(
                        exper_name, gathered_name, sg_symbol=self.symbol)
                    assert np.allclose(new_Modeler.all_data, all_data)
                    assert np.allclose(new_Modeler.all_background, all_bg)
                    assert np.allclose(new_Modeler.rois, all_rois)
                    assert np.allclose(new_Modeler.pids, all_pids)
                    assert np.allclose(new_Modeler.all_trusted, all_trusted)
                    assert np.allclose(new_Modeler.roi_id, all_roi_id)
                    LOGGER.info("Gathered file approved!")

            # HKL lists are keyed by the rank-local shot index
            self.Hi[shot_idx] = shot_modeler.Hi
            self.Hi_asu[shot_idx] = shot_modeler.Hi_asu

            LOGGER.info("EVENT: DONE LOADING ROI")
            shot_modeler.ucell_man = UcellMan
            self.SIM.num_ucell_param = len(
                shot_modeler.ucell_man.variables)  # for convenience

            # Spectrum source, in order of preference: the imageset, a spectrum
            # file named in the dataframe, or a single (wavelength, flux) channel.
            if not self.params.refiner.load_data_from_refl and self.params.spectrum_from_imageset:
                shot_spectra = hopper_utils.downsamp_spec(
                    self.SIM, self.params, expt, return_and_dont_set=True)

            elif "spectrum_filename" in list(
                    exper_dataframe
            ) and exper_dataframe.spectrum_filename.values[0] is not None:
                shot_spectra = utils.load_spectra_from_dataframe(
                    exper_dataframe)

            else:
                total_flux = exper_dataframe.total_flux.values[0]
                if total_flux is None:
                    total_flux = self.params.simulator.total_flux
                shot_spectra = [(expt.beam.get_wavelength(), total_flux)]

            shot_modeler.spectra = shot_spectra
            if self.params.refiner.gather_dir is not None and not self.params.refiner.load_data_from_refl:
                # save the spectrum next to the gathered reflections
                spec_wave, spec_weights = map(np.array, zip(*shot_spectra))
                spec_filename = os.path.splitext(
                    os.path.basename(exper_name))[0]
                spec_filename = os.path.join(self.params.refiner.gather_dir,
                                             spec_filename + ".lam")
                utils.save_spectra_file(spec_filename, spec_wave, spec_weights)
                LOGGER.info("saved spectra filename %s" % spec_filename)

            LOGGER.info("Will simulate %d energy channels" % len(shot_spectra))

            if "detz_shift_mm" in list(exper_dataframe):
                # column is in mm (per its name); scaled by 1e-3 here
                shot_modeler.originZ_init = exper_dataframe.detz_shift_mm.values[
                    0] * 1e-3
            else:
                shot_modeler.originZ_init = 0
            shot_modeler.exper_name = exper_name

            shot_panel_groups_refined = self.determine_refined_panel_groups(
                shot_modeler.pids)
            rank_panel_groups_refined = rank_panel_groups_refined.union(
                set(shot_panel_groups_refined))

            shot_idx += 1
            if COMM.rank == 0:
                self._mem_usage()
                print("Finished loading image %d / %d" %
                      (i_exp + 1, len(exper_names)),
                      flush=True)

            shot_modeler.PAR = PAR_from_params(self.params,
                                               expt,
                                               best=exper_dataframe)
            # modelers are keyed by the global shot index
            self.Modelers[i_exp] = shot_modeler

        LOGGER.info("DONE LOADING DATA; ENTER BARRIER")
        COMM.Barrier()
        LOGGER.info("DONE LOADING DATA; EXIT BARRIER")
        self.shot_roi_darkRMS = None

        # TODO warn that per_spot_scale refinement not intended for ensemble mode
        # union of the refined panel groups over all ranks, shared with everyone
        all_refined_groups = COMM.gather(rank_panel_groups_refined)
        panel_groups_refined = None
        if COMM.rank == 0:
            panel_groups_refined = set()
            for set_of_panels in all_refined_groups:
                panel_groups_refined = panel_groups_refined.union(
                    set_of_panels)
        self.panel_groups_refined = list(COMM.bcast(panel_groups_refined))

        LOGGER.info("EVENT: Gathering global HKL information")
        self._gather_Hi_information()
        LOGGER.info("EVENT: FINISHED gather global HKL information")
        if self.params.roi.cache_dir_only:
            print(
                "Done creating cache directory and cache_dir_only=True, so goodbye."
            )
            sys.exit()

        # in case of GPU
        LOGGER.info("BEGIN DETERMINE MAX PIX")
        self.NPIX_TO_ALLOC = self._determine_per_rank_max_num_pix()
        # TODO in case of randomize devices, shouldnt this be total max across all ranks?
        # every rank ends up with the maximum over all ranks
        n = COMM.gather(self.NPIX_TO_ALLOC)
        if COMM.rank == 0:
            n = max(n)
        self.NPIX_TO_ALLOC = COMM.bcast(n)
        LOGGER.info("DONE DETERMINE MAX PIX")

        self.DEVICE_ID = COMM.rank % self.params.refiner.num_devices
        self._mem_usage()
    def load_inputs(self,
                    pandas_table,
                    miller_data=None,
                    work_distribution=None,
                    refls_key='predictions'):
        """

        :param pandas_table: contains path to the experiments (pandas column exp_name) to be loaded
            the pandas table is expected to have been written by diffBragg.hopper or
            diffBragg.hopper_process . See method save_to_pandas in simtbx/command_line/hopper.py
            For example, if the outputdir of diffBragg.hopper was set to `all_shots`, then
            there should be a golder all_shots/pandas created which contains all of the per-shot pandas
            dataframes. They should be concatenated as follows, forming a suitable argument for this method
            >> import glob,pandas
            >> fnames = glob.glob("all_shots/pandas/rank*/*pkl")
            >> df = pandas.concat([ pandas.read_pickle(f) for f in fnames])
            >> df.reset_index(inplace=True, drop=True)
            >> df.to_pickle("all_shots.pkl")
            >> # Then later, as part of an MPI application, the following will load all data:
            >> RefineLauncher_instance.load_inputs(df, refls_key="stage1_refls")

        :param miller_data: Optional miller array for the structure factor component of the model
        :param refls_key: key specifying the reflection tables in the pandas table
            Modeled pixels will lie in shoeboxes centered on each x,y,z in xyzobs.px.value
        :return:
        """
        COMM.Barrier()
        num_exp = len(pandas_table)
        first_exper_file = pandas_table.exp_name.values[0]
        detector = ExperimentListFactory.from_json_file(
            first_exper_file, check_format=False)[0].detector
        if detector is None and self.params.refiner.reference_geom is None:
            raise RuntimeError(
                "No detector in experiment, must provide a reference geom.")
        # TODO verify all shots have the same detector ?
        if self.params.refiner.reference_geom is not None:
            detector = ExperimentListFactory.from_json_file(
                self.params.refiner.reference_geom,
                check_format=False)[0].detector
            print("Using reference geom from expt %s" %
                  self.params.refiner.reference_geom)

        if COMM.size > num_exp:
            raise ValueError(
                "Requested %d MPI ranks to process %d shots. Reduce number of ranks to %d"
                % (COMM.size, num_exp, num_exp))
        self._init_panel_group_information(detector)

        self.verbose = False
        if COMM.rank == 0:
            self.verbose = self.params.refiner.verbose > 0
            if self.params.refiner.gather_dir is not None and not os.path.exists(
                    self.params.refiner.gather_dir):
                os.makedirs(self.params.refiner.gather_dir)
                LOGGER.info("MADE GATHER DIR %s" %
                            self.params.refiner.gather_dir)
        COMM.barrier()
        shot_idx = 0  # each rank keeps index of the shots local to it
        rank_panel_groups_refined = set()
        exper_names = pandas_table.exp_name
        assert len(exper_names) == len(set(exper_names))
        # TODO assert all exper are single-file, probably way before this point
        if work_distribution is None:
            worklist = range(COMM.rank, nshots, COMM.size)
        else:
            worklist = work_distribution[COMM.rank]
        LOGGER.info("EVENT: begin loading inputs")
        for i_exp in worklist:
            exper_name = exper_names[i_exp]
            LOGGER.info("EVENT: BEGIN loading experiment list")
            expt_list = ExperimentListFactory.from_json_file(
                exper_name, check_format=self.params.refiner.check_expt_format)
            LOGGER.info("EVENT: DONE loading experiment list")
            if len(expt_list) != 1:
                print("Input experiments need to have length 1, %s does not" %
                      exper_name)
            expt = expt_list[0]
            expt.detector = detector  # in case of supplied ref geom
            self._check_experiment_integrity(expt)

            exper_dataframe = pandas_table.query("exp_name=='%s'" % exper_name)

            refl_name = exper_dataframe[refls_key].values[0]
            refls = flex.reflection_table.from_file(refl_name)
            # FIXME need to remove (0,0,0) bboxes

            try:
                good_sel = flex.bool(
                    [h != (0, 0, 0) for h in list(refls["miller_index"])])
                refls = refls.select(good_sel)
            except KeyError:
                pass

            #UcellMan = utils.manager_from_crystal(expt.crystal)
            opt_uc_param = exper_dataframe[["a", "b", "c", "al", "be",
                                            "ga"]].values[0]
            UcellMan = utils.manager_from_params(opt_uc_param)

            if self.symbol is None:
                if self.params.refiner.force_symbol is not None:
                    self.symbol = self.params.refiner.force_symbol
                else:
                    self.symbol = expt.crystal.get_space_group().type(
                    ).lookup_symbol()
                LOGGER.info("Set space group symbol: %s" % self.symbol)
            else:
                if self.params.refiner.force_symbol is None:
                    if expt.crystal.get_space_group().type().lookup_symbol(
                    ) != self.symbol:
                        raise ValueError(
                            "Crystals should all have the same space group symmetry"
                        )

            if shot_idx == 0:  # each rank initializes a simulator only once
                if self.params.simulator.init_scale != 1:
                    print(
                        "WARNING: For stage_two , it is assumed that total scale is stored in the pandas dataframe"
                    )
                    print(
                        "WARNING: resetting params.simulator.init_scale to 1!")
                    self.params.simulator.init_scale = 1
                self._init_simulator(expt, miller_data)
                if self.params.profile:
                    self.SIM.record_timings = True
                if self.params.refiner.stage_two.Fref_mtzname is not None:
                    self.Fref = utils.open_mtz(
                        self.params.refiner.stage_two.Fref_mtzname,
                        self.params.refiner.stage_two.Fref_mtzcol)

            if "miller_index" in list(refls.keys()):
                is_allowed = flex.bool(len(refls), True)
                allowed_hkls = set(self.SIM.crystal.miller_array.indices())
                for i_ref in range(len(refls)):
                    if refls[i_ref]['miller_index'] not in allowed_hkls:
                        is_allowed[i_ref] = False
                refls = refls.select(is_allowed)

            LOGGER.info("EVENT: LOADING ROI DATA")
            shot_modeler = hopper_utils.DataModeler(self.params)
            if self.params.refiner.load_data_from_refl:
                gathered = shot_modeler.GatherFromReflectionTable(
                    expt, refls, sg_symbol=self.symbol)
            else:
                gathered = shot_modeler.GatherFromExperiment(
                    expt, refls, sg_symbol=self.symbol)
            if not gathered:
                raise IOError("Failed to gather data from experiment %s",
                              exper_name)

            if self.params.refiner.gather_dir is not None:
                gathered_name = os.path.splitext(
                    os.path.basename(exper_name))[0]
                gathered_name += "_withData.refl"
                gathered_name = os.path.join(self.params.refiner.gather_dir,
                                             gathered_name)
                shot_modeler.dump_gathered_to_refl(
                    gathered_name, do_xyobs_sanity_check=False)  #True)
                LOGGER.info("SAVED ROI DATA TO %s" % gathered_name)
                if self.params.refiner.test_gathered_file:
                    all_data = shot_modeler.all_data.copy()
                    all_roi_id = shot_modeler.roi_id.copy()
                    all_bg = shot_modeler.all_background.copy()
                    all_trusted = shot_modeler.all_trusted.copy()
                    all_pids = np.array(shot_modeler.pids)
                    all_rois = np.array(shot_modeler.rois)
                    new_Modeler = hopper_utils.DataModeler(self.params)
                    assert new_Modeler.GatherFromReflectionTable(
                        exper_name, gathered_name, sg_symbol=self.symbol)
                    assert np.allclose(new_Modeler.all_data, all_data)
                    assert np.allclose(new_Modeler.all_background, all_bg)
                    assert np.allclose(new_Modeler.rois, all_rois)
                    assert np.allclose(new_Modeler.pids, all_pids)
                    assert np.allclose(new_Modeler.all_trusted, all_trusted)
                    assert np.allclose(new_Modeler.roi_id, all_roi_id)
                    LOGGER.info("Gathered file approved!")

            self.Hi[shot_idx] = shot_modeler.Hi
            self.Hi_asu[shot_idx] = shot_modeler.Hi_asu

            LOGGER.info("EVENT: DONE LOADING ROI")
            shot_modeler.ucell_man = UcellMan
            self.SIM.num_ucell_param = len(
                shot_modeler.ucell_man.variables)  # for convenience

            loaded_spectra = False
            if self.params.spectrum_from_imageset:
                try:
                    shot_spectra = hopper_utils.downsamp_spec(
                        self.SIM, self.params, expt, return_and_dont_set=True)
                    loaded_spectra = True
                except Exception as err:
                    LOGGER.warning(
                        "spectrum_from_imageset is set to True, however failed to load spectra: %s"
                        % err)
                    loaded_spectra = False

            if not loaded_spectra:
                if "spectrum_filename" in list(
                        exper_dataframe
                ) and exper_dataframe.spectrum_filename.values[0] is not None:
                    shot_spectra = utils.load_spectra_from_dataframe(
                        exper_dataframe)
                    LOGGER.debug("Loaded specta from %s" %
                                 exper_dataframe.spectrum_filename.values[0])
                    shot_modeler.spec_name = exper_dataframe.spectrum_filename.values[
                        0]

                else:
                    total_flux = exper_dataframe.total_flux.values[0]
                    if total_flux is None:
                        total_flux = self.params.simulator.total_flux
                    shot_spectra = [(expt.beam.get_wavelength(), total_flux)]

            shot_modeler.spectra = shot_spectra
            if self.params.refiner.gather_dir is not None and not self.params.refiner.load_data_from_refl:
                spec_wave, spec_weights = map(np.array, zip(*shot_spectra))
                spec_filename = os.path.splitext(
                    os.path.basename(exper_name))[0]
                spec_filename = os.path.join(self.params.refiner.gather_dir,
                                             spec_filename + ".lam")
                utils.save_spectra_file(spec_filename, spec_wave, spec_weights)
                LOGGER.info("saved spectra filename %s" % spec_filename)

            LOGGER.info("Will simulate %d energy channels" % len(shot_spectra))

            if "detz_shift_mm" in list(exper_dataframe):
                shot_modeler.originZ_init = exper_dataframe.detz_shift_mm.values[
                    0] * 1e-3
            else:
                shot_modeler.originZ_init = 0
            shot_modeler.exper_name = exper_name
            shot_modeler.refl_name = refl_name

            shot_panel_groups_refined = self.determine_refined_panel_groups(
                shot_modeler.pids)
            rank_panel_groups_refined = rank_panel_groups_refined.union(
                set(shot_panel_groups_refined))

            shot_idx += 1
            if COMM.rank == 0:
                self._mem_usage()
                print("Finished loading image %d / %d" %
                      (i_exp + 1, len(exper_names)),
                      flush=True)

            shot_modeler.PAR = PAR_from_params(self.params,
                                               expt,
                                               best=exper_dataframe)
            self.Modelers[i_exp] = shot_modeler

        LOGGER.info("DONE LOADING DATA; ENTER BARRIER")
        COMM.Barrier()
        LOGGER.info("DONE LOADING DATA; EXIT BARRIER")
        self.shot_roi_darkRMS = None

        # TODO warn that per_spot_scale refinement not intended for ensemble mode
        all_refined_groups = COMM.gather(rank_panel_groups_refined)
        panel_groups_refined = None
        if COMM.rank == 0:
            panel_groups_refined = set()
            for set_of_panels in all_refined_groups:
                panel_groups_refined = panel_groups_refined.union(
                    set_of_panels)
        self.panel_groups_refined = list(COMM.bcast(panel_groups_refined))

        LOGGER.info("EVENT: Gathering global HKL information")
        try:
            self._gather_Hi_information()
        except TypeError:
            pass
        LOGGER.info("EVENT: FINISHED gather global HKL information")
        if self.params.roi.cache_dir_only:
            print(
                "Done creating cache directory and cache_dir_only=True, so goodbye."
            )
            sys.exit()

        # in case of GPU
        LOGGER.info("BEGIN DETERMINE MAX PIX")
        self.NPIX_TO_ALLOC = self._determine_per_rank_max_num_pix()
        # TODO in case of randomize devices, shouldnt this be total max across all ranks?
        n = COMM.gather(self.NPIX_TO_ALLOC)
        if COMM.rank == 0:
            n = max(n)
        self.NPIX_TO_ALLOC = COMM.bcast(n)
        LOGGER.info("DONE DETERMINE MAX PIX")

        self.DEVICE_ID = COMM.rank % self.params.refiner.num_devices
        self._mem_usage()