Пример #1
0
# Simulate the two-color "data" images; the result is indexed per color
# (simsData[0], simsData[1]) below.
simsData = sim_utils.sim_twocolors2(
    crystalAB, detector, iset.get_beam(0), FFdat,
    [parameters.ENERGY_LOW, parameters.ENERGY_HIGH],
    FLUXdat, pids=None, Gauss=False, oversample=8,
    Ncells_abc=(22, 22, 22), mos_dom=1, mos_spread=0.)

# Add the two per-color results together and scale the sum by GAIN.
simsDataSum = GAIN * (np.array(simsData[0]) + np.array(simsData[1]))

# NOTE(review): simsAB is not assigned anywhere in this excerpt (only
# simsData is); presumably it is defined earlier in the full script -- confirm.
refl_simA = spot_utils.refls_from_sims(simsAB[0], detector, beamA)
refl_simB = spot_utils.refls_from_sims(simsAB[1], detector, beamB)

# This only uses the beam to instantiate an imageset / datablock
# but otherwise the return value (refl_data) is independent of the
# beam object passed
refl_data = spot_utils.refls_from_sims(simsDataSum, detector, beamA)
# Compare the reflections found in the summed image against the per-color
# simulated reflections, once per beam, using the same crystal and hkl_tol.
residA = metrics.check_indexable2(
    refl_data, refl_simA, detector, beamA, crystalAB, hkl_tol)
residB = metrics.check_indexable2(
    refl_data, refl_simB, detector, beamB, crystalAB, hkl_tol)

# NOTE(review): sg96 is not used again within this excerpt.
sg96 = sgtbx.space_group(" P 4nw 2abw")
# Load the two structure-factor arrays and build hkl -> value lookups.
FA = utils.open_flex('SA.pkl')
FB = utils.open_flex('SB.pkl')
HA = tuple([hkl for hkl in FA.indices()])
HB = tuple([hkl for hkl in FB.indices()])
# One value_at_index() call per Miller index of each array.
HA_val_map = {hkl: FA.value_at_index(hkl) for hkl in HA}
HB_val_map = {hkl: FB.value_at_index(hkl) for hkl in HB}

d = {"crystalAB": crystalAB,
     "residA": residA,
     "residB": residB,
     "beamA": beamA,
Пример #2
0
def main(rank):
    """Run this worker's two-color simulation / integration trials.

    For each of ``args.num_trials`` trials the function simulates one image
    set with the flat ``FF`` structure factors (``simsAB``) and one
    accumulated "data" image (``simsDataSum``), optionally adds a stored
    background and detector noise, finds reflections, integrates them with
    ``spot_utils.integrate_boxes`` and pickles a ``pandas.DataFrame`` with
    the observed (``lhs``) and modeled (``rhs``) value of every reflection.

    When ``make_background`` is true, only the data image is simulated (with
    ``only_water=True``), written to ``bg_name``, and the process exits via
    ``sys.exit()``.

    The function reads many names that are not defined in its body and must
    exist at module level: ``ngpu``, ``node_id``, ``args``, ``ofile``,
    ``GAIN``, ``cuda``, ``smi_stride``, ``random_rotation``,
    ``use_data_spec``, ``make_background``, ``add_background``,
    ``add_noise``, ``bg_name``, ``use_dials_spotter``, ``thresh``, ``Gauss``,
    ``oversample``, ``Ncells_abc``, ``mos_doms``, ``mos_spread``,
    ``exposure_s``, ``beamsize_mm`` and ``boost``.

    Args:
        rank: integer index of this worker on its node. It selects the GPU
            (``rank % ngpu``) and, with ``node_id``, the global worker id.

    Returns:
        None. Output files are written to ``<args.odir>/job<worker_Id>/``.
    """
    device_Id = rank % ngpu
    worker_Id = node_id * ngpu + rank

    # Imports stay function-local, as in the original layout of this script.
    import os
    import sys
    from copy import deepcopy
    import glob

    from scipy.spatial import distance
    import h5py
    import scipy.ndimage
    from IPython import embed
    import numpy as np
    import pandas
    from scipy.spatial import cKDTree

    from simtbx.nanoBragg import shapetype, nanoBragg
    from libtbx.phil import parse
    from scitbx.matrix import sqr
    import dxtbx
    from dxtbx.model.experiment_list import ExperimentListFactory
    from dxtbx.model.crystal import CrystalFactory
    from dials.algorithms.indexing.compare_orientation_matrices \
        import rotation_matrix_differences
    from dials.array_family import flex
    from dials.command_line.find_spots import phil_scope as find_spots_phil_scope

    from cxid9114.refine import metrics
    from cxid9114 import utils
    from cxid9114.geom import geom_utils
    from cxid9114.spots import integrate, spot_utils
    from cxid9114 import parameters
    from cxid9114.sim import sim_utils
    from cctbx import miller, sgtbx
    from cxid9114.bigsim import sim_spectra
    from cxid9114.refine.jitter_refine import make_param_list

    # Parameters for the DIALS spot finder (only used if use_dials_spotter).
    spot_par = find_spots_phil_scope.fetch(source=parse("")).extract()
    spot_par.spotfinder.threshold.dispersion.global_threshold = 40
    spot_par.spotfinder.threshold.dispersion.gain = GAIN
    spot_par.spotfinder.threshold.dispersion.kernel_size = [2, 2]
    spot_par.spotfinder.threshold.dispersion.sigma_strong = 1
    spot_par.spotfinder.threshold.dispersion.sigma_background = 6
    spot_par.spotfinder.filter.min_spot_size = 3
    spot_par.spotfinder.force_2d = True

    # Every worker writes into its own sub-directory of args.odir.
    odir = args.odir
    odirj = os.path.join(odir, "job%d" % worker_Id)
    if not os.path.exists(odirj):
        os.makedirs(odirj)

    hkl_tol = .15
    run = 61
    shot_idx = 0
    ENERGIES = [parameters.ENERGY_LOW, parameters.ENERGY_HIGH]  # colors of the beams
    FF = [10000, None]

    cryst_descr = {'__id__': 'crystal',
                   'real_space_a': (79, 0, 0),
                   'real_space_b': (0, 79, 0),
                   'real_space_c': (0, 0, 38),
                   'space_group_hall_symbol': '-P 4 2'}

    sfall_main = sim_spectra.load_spectra("../bigsim/test_sfall.h5")
    FFdat = [sfall_main[19], sfall_main[110]]  # ground truth values

    FLUX = [1e11, 1e11]  # fluxes of the beams

    chanA_flux = np.random.uniform(1e11, 1e12)
    chanB_flux = np.random.uniform(1e11, 1e12)
    FLUXdat = [chanA_flux, chanB_flux]

    waveA = parameters.ENERGY_CONV / ENERGIES[0]
    waveB = parameters.ENERGY_CONV / ENERGIES[1]

    from cxid9114.bigsim.bigsim_geom import DET, BEAM

    detector = DET

    # Trial-independent lookups, built once instead of once per trial:
    # one hkl -> value map per color, the space group used to enumerate
    # symmetry equivalents, and the normalisation K.
    Hmaps = [dict(zip(F.indices(), F.data())) for F in FFdat]
    sg96 = sgtbx.space_group(" P 4nw 2abw")
    K = FF[0] ** 2 * FLUX[0]

    def get_val_at_hkl(hkl, val_map):
        """Look up hkl in val_map through its sg96 symmetry equivalents.

        Returns ``(equivalent_hkl, value)`` for the first equivalent that is
        a key of ``val_map``, or ``((None, None, None), -1)`` if none is.
        """
        for equiv in miller.sym_equiv_indices(sg96, hkl).indices():
            hkl2 = equiv.h()
            if hkl2 in val_map:  # fast lookup
                return hkl2, val_map[hkl2]
        return (None, None, None), -1

    print("Rank %d Begin" % worker_Id)
    for i_data in range(args.num_trials):
        pklname = "%s_rank%d_data%d.pkl" % (ofile, worker_Id, i_data)
        pklname = os.path.join(odirj, pklname)

        print("<><><><><><><")
        print("Job %d:  trial  %d / %d" % (worker_Id, i_data + 1, args.num_trials))
        print("<><><><><><><")

        if (worker_Id == 0 and i_data % smi_stride == 0 and cuda):
            print("GPU status")
            os.system("nvidia-smi")

            print("\n\n")
            print("CPU memory usage")
            # NOTE(review): the user name is hard-coded in this command.
            mem_usg = """ps -U dermen --no-headers -o rss | awk '{ sum+=$1} END {print int(sum/1024) "MB consumed by CPU user"}'"""
            os.system(mem_usg)

        beamA = deepcopy(BEAM)
        beamB = deepcopy(BEAM)
        beamA.set_wavelength(waveA)
        beamB.set_wavelength(waveB)

        # SCALE is drawn *before* the generator is re-seeded below.
        SCALE = np.random.uniform(0.1, 10)

        # NOTE(review): the generator is re-seeded with the same args.seed on
        # every trial, so the numpy draws that follow repeat from trial to
        # trial -- confirm this is intended.
        np.random.seed(args.seed)
        crystalAB = CrystalFactory.from_dict(cryst_descr)
        randnums = np.random.random(3)
        Rrand = random_rotation(1, randnums)
        crystalAB.set_U(Rrand.ravel())

        params_lst = make_param_list(
            crystalAB, DET, BEAM,
            1, rot=0.08, cell=.0000001, eq=(1, 1, 0),
            min_Ncell=23, max_Ncell=24,
            min_mos_spread=0.02,
            max_mos_spread=0.08)
        Ctruth = params_lst[0]['crystal']

        print(Ctruth.get_unit_cell().parameters())
        print(crystalAB.get_unit_cell().parameters())

        # The rotation value is parsed out of the text report returned by
        # rotation_matrix_differences (third token of the second-to-last line).
        init_comp = rotation_matrix_differences((Ctruth, crystalAB))
        init_rot = float(init_comp.split("\n")[-2].split()[2])

        # Both settings of use_data_spec currently simulate the same two
        # color spectrum; only the message differs.
        if use_data_spec:
            print("NOT IMPLEMENTED, Using a phony 2col spectrum to simulate the data")
        else:
            print("Using a phony two color spectrum to simulate the data")
        data_fluxes = FLUXdat
        data_energies = [parameters.ENERGY_LOW, parameters.ENERGY_HIGH]
        data_ff = FFdat

        print("Truth crystal Misorientation deviation: %f deg" % init_rot)
        if args.truth_cryst:
            print("Using truth crystal")
            dataCryst = Ctruth
        else:
            print("Not using truth crystal")
            dataCryst = crystalAB

        if not make_background:
            print("SIMULATING Flat-Fhkl IMAGES")
            simsAB = sim_utils.sim_twocolors2(
                crystalAB, detector, BEAM, FF,
                [parameters.ENERGY_LOW, parameters.ENERGY_HIGH],
                FLUX, pids=None, Gauss=Gauss, cuda=cuda, oversample=oversample,
                Ncells_abc=Ncells_abc, mos_dom=mos_doms, mos_spread=mos_spread,
                exposure_s=exposure_s, beamsize_mm=beamsize_mm, device_Id=device_Id,
                boost=boost)

        if make_background:
            print("MAKING BACKGROUND")
            # Both datasets are read fully into memory, so the file can be
            # closed as soon as the block ends.
            with h5py.File("../bigsim/simMe_data_run62.h5", "r") as spec_file:
                ave_spec = np.mean(spec_file["hist_spec"][()], axis=0)
                data_energies = spec_file["energy_bins"][()][[19, 110]]
            data_fluxes = [ave_spec[19], ave_spec[110]]
            data_ff = [1, 1]
            only_water = True
        else:
            only_water = False

        print("SIULATING DATA IMAGE")
        print(data_fluxes)
        simsDataSum = sim_utils.sim_twocolors2(
            dataCryst, detector, BEAM, data_ff,
            data_energies,
            data_fluxes, pids=None, Gauss=Gauss, cuda=cuda, oversample=oversample,
            Ncells_abc=Ncells_abc, accumulate=True, mos_dom=mos_doms,
            mos_spread=mos_spread, boost=boost,
            exposure_s=exposure_s, beamsize_mm=beamsize_mm,
            only_water=only_water, device_Id=device_Id)

        simsDataSum = SCALE * np.array(simsDataSum)

        if make_background:
            # Close (and thereby flush) the file before leaving the process.
            with h5py.File(bg_name, "w") as bg_out:
                bg_out.create_dataset("bigsim_d9114", data=simsDataSum[0])
            print("Background made! Saved to file %s" % bg_name)
            sys.exit()

        if add_background:
            print("ADDING BG")
            with h5py.File(bg_name, "r") as bg_file:
                background = bg_file['bigsim_d9114'][()]
            # Scale the stored background by the ratio of this trial's total
            # flux to the sum of the two reference values below.
            bg_ref_flux = np.sum([39152412349.12075, 32315440627.406036])
            bg_scale = np.sum(data_fluxes) / bg_ref_flux
            print("%.3e backgorund scale" % bg_scale)
            print("BG shape %s" % (background.shape,))
            simsDataSum[0] += background * bg_scale

        if add_noise:
            print("ADDING NOISE")
            for pidx in range(1):
                SIM = nanoBragg(detector=DET, beam=BEAM, panel_id=pidx)
                SIM.exposure_s = exposure_s
                SIM.beamsize_mm = beamsize_mm
                SIM.flux = np.sum(data_fluxes)
                SIM.detector_psf_kernel_radius_pixels = 5
                SIM.detector_psf_type = shapetype.Unknown  # for CSPAD
                SIM.detector_psf_fwhm_mm = 0
                SIM.quantum_gain = GAIN
                SIM.raw_pixels = flex.double(simsDataSum[pidx].ravel())
                SIM.add_noise()
                simsDataSum[pidx] = SIM.raw_pixels.as_numpy_array()\
                    .reshape(simsDataSum[0].shape)
                SIM.free_all()
                del SIM

        if args.write_img:
            print("SAVING DATAFILE")
            h5name = "%s_rank%d_data%d.h5" % (ofile, worker_Id, i_data)
            h5name = os.path.join(odirj, h5name)
            with h5py.File(h5name, "w") as fout:
                fout.create_dataset("bigsim_d9114", data=simsDataSum[0])
                fout.create_dataset("crystalAB", data=crystalAB.get_A())
                fout.create_dataset("dataCryst", data=dataCryst.get_A())

        if args.write_sim_img:
            print("SAVING DATAFILE")
            # NOTE(review): i_sim is used both as a %d value and as an index
            # into simsAB, so iterating simsAB presumably yields integer keys
            # (e.g. a dict keyed by color) -- confirm against sim_twocolors2.
            for i_sim in simsAB:
                sim_h5name = "%s_rank%d_sim%d_%d.h5" % (ofile, worker_Id, i_data, i_sim)
                sim_h5name = os.path.join(odirj, sim_h5name)

                with h5py.File(sim_h5name, "w") as fout:
                    fout.create_dataset("bigsim_d9114",
                                        data=simsAB[i_sim][0])
                    fout.create_dataset("crystalAB", data=crystalAB.get_A())
                    fout.create_dataset("dataCryst", data=dataCryst.get_A())

        print("RELFS FROM SIMS")
        refl_simA = spot_utils.refls_from_sims(simsAB[0], detector, beamA, thresh=thresh)
        refl_simB = spot_utils.refls_from_sims(simsAB[1], detector, beamB, thresh=thresh)

        if use_dials_spotter:
            print("DIALS SPOTTING")
            El = utils.explist_from_numpyarrays(simsDataSum, DET, beamA)
            refl_data = flex.reflection_table.from_observations(El, spot_par)
            print("Found %d refls using DIALS spot finder" % len(refl_data))
        else:
            refl_data = spot_utils.refls_from_sims(simsDataSum, detector, beamA,
                                                   thresh=thresh)
            print("Found %d refls using threshold" % len(refl_data))

        if len(refl_data) == 0:
            print("Rank %d: No reflections found! " % (worker_Id))
            continue

        # The residuals are not used further in this function; the calls are
        # kept in case check_indexable2 has effects beyond its return value.
        residA = metrics.check_indexable2(
            refl_data, refl_simA, detector, beamA, crystalAB, hkl_tol)
        residB = metrics.check_indexable2(
            refl_data, refl_simB, detector, beamB, crystalAB, hkl_tol)

        if use_data_spec:
            print("Setting LA, LB as sums over flux regions A,B")
            # np.sum works for plain lists as well as arrays; the former
            # list.sum() call raised AttributeError for the list used here.
            LA = np.sum(data_fluxes[:75])
            LB = np.sum(data_fluxes[75:])
        else:
            print("Setting LA LB as data_fluxes")
            LA = data_fluxes[0]
            LB = data_fluxes[1]

        L_at_color = [LA, LB]

        out = spot_utils.integrate_boxes(
            refl_data, simsDataSum[0], [refl_simA, refl_simB], DET,
            [beamA, beamB], crystalAB, delta_q=args.deltaq, gain=GAIN)

        Nh = len(out[0])
        if Nh == 0:
            # Nothing was integrated: skip the trial instead of failing on
            # the h, k, l unpacking further down.
            print("Rank %d: No integrated reflections! " % (worker_Id))
            continue

        rhs = []
        lhs = []
        all_H2 = []
        all_PA = []
        all_PB = []
        all_FA = []
        all_FB = []
        for i in range(Nh):
            HKL = out[0][i]
            yobs = out[1][i]
            Pvals = out[2][i]
            ycalc = 0

            # Accumulate the per-color model contributions.
            for i_P, P in enumerate(Pvals):
                L = L_at_color[i_P]
                H2, F = get_val_at_hkl(HKL, Hmaps[i_P])
                if i_P == 0:
                    all_FA.append(F)
                else:
                    all_FB.append(F)

                ycalc += SCALE * L * P * abs(F) ** 2 / K
            all_PA.append(Pvals[0])
            all_PB.append(Pvals[1])
            # H2 is the equivalent index found for the last color processed.
            all_H2.append(H2)
            rhs.append(ycalc)
            lhs.append(yobs)

        df = pandas.DataFrame({"rhs": rhs, "lhs": lhs,
                               "PA": all_PA, "PB": all_PB, "FA": all_FA,
                               "FB": all_FB})

        df["run"] = run
        df["shot_idx"] = shot_idx
        df['gain'] = SCALE

        df['LA'] = LA
        df['LB'] = LB
        df['K'] = K
        df['init_rot'] = init_rot

        h, k, l = zip(*all_H2)
        df['h2'] = h
        df['k2'] = k
        df['l2'] = l

        df.to_pickle(pklname)

        if args.plot:
            print("PLOT")
            import pylab as plt
            plt.plot(df.lhs, df.rhs, '.')
            plt.show()
        print("DonDonee")
def main(rank):

    device_Id = rank % ngpu

    worker_Id = node_id * ngpu + rank

    import os
    import sys
    from copy import deepcopy
    import glob
    from itertools import izip

    from scipy.spatial import distance
    import h5py
    import scipy.ndimage
    from IPython import embed
    import numpy as np
    import pandas
    from scipy.spatial import cKDTree

    from simtbx.nanoBragg import shapetype, nanoBragg
    from libtbx.phil import parse
    from scitbx.matrix import sqr
    import dxtbx
    from dxtbx.model.experiment_list import ExperimentListFactory
    from dxtbx.model.crystal import CrystalFactory
    from dials.algorithms.indexing.compare_orientation_matrices \
            import rotation_matrix_differences
    from dials.array_family import flex
    from dials.command_line.find_spots import phil_scope as find_spots_phil_scope

    from cxid9114.refine import metrics
    from cxid9114 import utils
    from cxid9114.geom import geom_utils
    from cxid9114.spots import integrate, spot_utils
    from cxid9114 import parameters
    from cxid9114.sim import sim_utils
    from cctbx import miller, sgtbx
    from cxid9114 import utils
    from cxid9114.bigsim import sim_spectra
    from cxid9114.refine.jitter_refine import make_param_list

    spot_par = find_spots_phil_scope.fetch(source=parse("")).extract()
    spot_par.spotfinder.threshold.dispersion.global_threshold = 40
    spot_par.spotfinder.threshold.dispersion.gain = 28
    spot_par.spotfinder.threshold.dispersion.kernel_size = [2, 2]
    spot_par.spotfinder.threshold.dispersion.sigma_strong = 1
    spot_par.spotfinder.threshold.dispersion.sigma_background = 6
    spot_par.spotfinder.filter.min_spot_size = 3
    spot_par.spotfinder.force_2d = True

    odir = args.odir
    odirj = os.path.join(odir, "job%d" % worker_Id)
    #all_pkl_files = [s for sl in \
    #    [ files for _,_, files in  os.walk(odir)]\
    #        for s in sl if s.endswith("pkl")]

    #print "Found %d pkl files already in %s!" \
    #    % (len(all_pkl_files), odir)

    if not os.path.exists(odirj):
        os.makedirs(odirj)

    hkl_tol = .15
    run = 61
    shot_idx = 0
    ENERGIES = [parameters.ENERGY_LOW,
                parameters.ENERGY_HIGH]  # colors of the beams
    FF = [10000, None]

    cryst_descr = {
        '__id__': 'crystal',
        'real_space_a': (79, 0, 0),
        'real_space_b': (0, 79, 0),
        'real_space_c': (0, 0, 38),
        'space_group_hall_symbol': '-P 4 2'
    }
    crystalAB = CrystalFactory.from_dict(cryst_descr)

    sfall_main = sim_spectra.load_spectra("../bigsim/test_sfall.h5")
    FFdat = [sfall_main[19], sfall_main[110]]

    FLUX = [1e11, 1e11]  # fluxes of the beams

    chanA_flux = 1e11
    chanB_flux = 1e11
    FLUXdat = [chanA_flux, chanB_flux]
    GAIN = 1

    waveA = parameters.ENERGY_CONV / ENERGIES[0]
    waveB = parameters.ENERGY_CONV / ENERGIES[1]

    from cxid9114.bigsim.bigsim_geom import DET, BEAM

    detector = DET

    print("Rank %d Begin" % worker_Id)
    for i_data in range(args.num_trials):
        pklname = "%s_rank%d_data%d.pkl" % (ofile, worker_Id, i_data)
        pklname = os.path.join(odirj, pklname)

        print("<><><><><><><")
        print("Job %d:  trial  %d / %d" %
              (worker_Id, i_data + 1, args.num_trials))
        print("<><><><><><><")

        if (worker_Id == 0 and i_data % smi_stride == 0 and cuda):
            print("GPU status")
            os.system("nvidia-smi")

            print("\n\n")
            print("CPU memory usage")
            mem_usg = """ps -U dermen --no-headers -o rss | awk '{ sum+=$1} END {print int(sum/1024) "MB consumed by CPU user"}'"""
            os.system(mem_usg)

        beamA = deepcopy(BEAM)
        beamB = deepcopy(BEAM)
        beamA.set_wavelength(waveA)
        beamB.set_wavelength(waveB)

        np.random.seed(args.seed)
        crystalAB = CrystalFactory.from_dict(cryst_descr)
        randnums = np.random.random(3)
        Rrand = random_rotation(1, randnums)
        crystalAB.set_U(Rrand.ravel())

        #pert = np.random.uniform(0.0001/2/np.pi, 0.0003 / 2. /np.pi)
        #print("PERT %f" % pert)
        #Rsmall = random_rotation(0.00001, randnums ) #pert)

        params_lst = make_param_list(crystalAB,
                                     DET,
                                     BEAM,
                                     1,
                                     rot=0.08,
                                     cell=.0000001,
                                     eq=(1, 1, 0),
                                     min_Ncell=23,
                                     max_Ncell=24,
                                     min_mos_spread=0.02,
                                     max_mos_spread=0.08)
        Ctruth = params_lst[0]['crystal']

        print Ctruth.get_unit_cell().parameters()
        print crystalAB.get_unit_cell().parameters()

        init_comp = rotation_matrix_differences((Ctruth, crystalAB))
        init_rot = float(init_comp.split("\n")[-2].split()[2])

        if use_data_spec:
            print "NOT IMPLEMENTED, Using a phony 2col spectrum to simulate the data"
            data_fluxes = FLUXdat
            data_energies = [parameters.ENERGY_LOW, parameters.ENERGY_HIGH]
            data_ff = FFdat
        else:
            print "Using a phony two color spectrum to simulate the data"
            data_fluxes = FLUXdat
            data_energies = [parameters.ENERGY_LOW, parameters.ENERGY_HIGH]
            data_ff = FFdat

        print("Truth crystal Misorientation deviation: %f deg" % init_rot)
        if args.truth_cryst:
            print "Using truth crystal"
            dataCryst = Ctruth
        else:
            print "Not using truth crystal"
            dataCryst = crystalAB

        if not make_background:
            print "SIMULATING Flat-Fhkl IMAGES"
            simsAB = sim_utils.sim_twocolors2(
                crystalAB,
                detector,
                BEAM,
                FF, [parameters.ENERGY_LOW, parameters.ENERGY_HIGH],
                FLUX,
                pids=None,
                Gauss=Gauss,
                cuda=cuda,
                oversample=oversample,
                Ncells_abc=Ncells_abc,
                mos_dom=mos_doms,
                mos_spread=mos_spread,
                exposure_s=exposure_s,
                beamsize_mm=beamsize_mm,
                device_Id=device_Id,
                boost=boost)

        if make_background:
            print("MAKING BACKGROUND")
            spec_file = h5py.File("../bigsim/simMe_data_run62.h5", "r")
            ave_spec = np.mean(spec_file["hist_spec"][()], axis=0)
            data_fluxes = [ave_spec[19], ave_spec[110]]
            data_energies = spec_file["energy_bins"][()][[19, 110]]
            data_ff = [1, 1]  #*len(data_energies)
            only_water = True
        else:
            only_water = False

        print "SIULATING DATA IMAGE"
        print data_fluxes
        simsDataSum = sim_utils.sim_twocolors2(dataCryst,
                                               detector,
                                               BEAM,
                                               data_ff,
                                               data_energies,
                                               data_fluxes,
                                               pids=None,
                                               Gauss=Gauss,
                                               cuda=cuda,
                                               oversample=oversample,
                                               Ncells_abc=Ncells_abc,
                                               accumulate=True,
                                               mos_dom=mos_doms,
                                               mos_spread=mos_spread,
                                               boost=boost,
                                               exposure_s=exposure_s,
                                               beamsize_mm=beamsize_mm,
                                               only_water=only_water,
                                               device_Id=device_Id)

        simsDataSum = np.array(simsDataSum)

        if make_background:
            bg_out = h5py.File(bg_name, "w")
            bg_out.create_dataset("bigsim_d9114", data=simsDataSum[0])
            print "Background made! Saved to file %s" % bg_name
            sys.exit()

        if add_background:
            print("ADDING BG")
            background = h5py.File(bg_name, "r")['bigsim_d9114'][()]
            bg_scale = np.sum([39152412349.12075, 32315440627.406036])
            bg_scale = np.sum(data_fluxes) / bg_scale
            print "%.3e backgorund scale" % bg_scale
            print "BG shape", background.shape
            simsDataSum[0] += background * bg_scale

        if add_noise:
            print("ADDING NOISE")
            for pidx in range(1):
                SIM = nanoBragg(detector=DET, beam=BEAM, panel_id=pidx)
                SIM.exposure_s = exposure_s
                SIM.beamsize_mm = beamsize_mm
                SIM.flux = np.sum(data_fluxes)
                SIM.detector_psf_kernel_radius_pixels = 5
                SIM.detector_psf_type = shapetype.Unknown  # for CSPAD
                SIM.detector_psf_fwhm_mm = 0
                SIM.quantum_gain = 28
                SIM.raw_pixels = flex.double(simsDataSum[pidx].ravel())
                SIM.add_noise()
                simsDataSum[pidx] = SIM.raw_pixels.as_numpy_array()\
                    .reshape(simsDataSum[0].shape)
                SIM.free_all()
                del SIM

        if args.write_img:
            print "SAVING DATAFILE"
            h5name = "%s_rank%d_data%d.h5" % (ofile, worker_Id, i_data)
            h5name = os.path.join(odirj, h5name)
            fout = h5py.File(h5name, "w")
            fout.create_dataset("bigsim_d9114", data=simsDataSum[0])
            fout.create_dataset("crystalAB", data=crystalAB.get_A())
            fout.create_dataset("dataCryst", data=dataCryst.get_A())
            fout.close()

        if args.write_sim_img:
            print "SAVING DATAFILE"
            for i_sim in simsAB:
                sim_h5name = "%s_rank%d_sim%d_%d.h5" % (ofile, worker_Id,
                                                        i_data, i_sim)
                sim_h5name = os.path.join(odirj, sim_h5name)
                from IPython import embed
                embed()

                fout = h5py.File(sim_h5name, "w")
                fout.create_dataset("bigsim_d9114", data=simsAB[i_sim][0])
                fout.create_dataset("crystalAB", data=crystalAB.get_A())
                fout.create_dataset("dataCryst", data=dataCryst.get_A())
                fout.close()

        print "RELFS FROM SIMS"
        refl_simA = spot_utils.refls_from_sims(simsAB[0],
                                               detector,
                                               beamA,
                                               thresh=thresh)
        refl_simB = spot_utils.refls_from_sims(simsAB[1],
                                               detector,
                                               beamB,
                                               thresh=thresh)

        if use_dials_spotter:
            print("DIALS SPOTTING")
            El = utils.explist_from_numpyarrays(simsDataSum, DET, beamA)
            refl_data = flex.reflection_table.from_observations(El, spot_par)
            print("Found %d refls using DIALS spot finder" % len(refl_data))
        else:
            refl_data = spot_utils.refls_from_sims(simsDataSum, detector, beamA,\
                            thresh=thresh)

            print("Found %d refls using threshold" % len(refl_data))

        if len(refl_data) == 0:
            print "Rank %d: No reflections found! " % (worker_Id)
            continue

        residA = metrics.check_indexable2(refl_data, refl_simA, detector,
                                          beamA, crystalAB, hkl_tol)
        residB = metrics.check_indexable2(refl_data, refl_simB, detector,
                                          beamB, crystalAB, hkl_tol)

        sg96 = sgtbx.space_group(" P 4nw 2abw")
        FA = sfall_main[19]  # utils.open_flex('SA.pkl')  # ground truth values
        FB = sfall_main[110]  #utils.open_flex('SB.pkl')  # ground truth values
        HA = tuple([hkl for hkl in FA.indices()])
        HB = tuple([hkl for hkl in FB.indices()])

        HA_val_map = {h: data for h, data in izip(FA.indices(), FA.data())}
        HB_val_map = {h: data for h, data in izip(FB.indices(), FB.data())}

        def get_val_at_hkl(hkl, val_map):
            poss_equivs = [
                i.h() for i in miller.sym_equiv_indices(sg96, hkl).indices()
            ]
            in_map = False
            for hkl2 in poss_equivs:
                if hkl2 in val_map:  # fast lookup
                    in_map = True
                    break
            if in_map:
                return hkl2, val_map[hkl2]
            else:
                return (None, None, None), -1

        filt = 1  #True #`False #True
        if filt:
            _, all_HiA, _ = spot_utils.refls_to_hkl(refl_simA,
                                                    detector,
                                                    beamA,
                                                    crystal=crystalAB,
                                                    returnQ=True)
            all_treeA = cKDTree(all_HiA)
            nnA = all_treeA.query_ball_point(all_HiA, r=1e-7)

            _, all_HiB, _ = spot_utils.refls_to_hkl(refl_simB,
                                                    detector,
                                                    beamB,
                                                    crystal=crystalAB,
                                                    returnQ=True)
            all_treeB = cKDTree(all_HiB)
            nnB = all_treeB.query_ball_point(all_HiB, r=1e-7)

            NreflA = len(refl_simA)
            NreflB = len(refl_simB)

            drop_meA = []
            for i, vals in enumerate(nnA):
                if i in drop_meA:
                    continue
                if len(vals) > 1:
                    pids = [refl_simA[v]['panel'] for v in vals]
                    if len(set(pids)) == 1:
                        refl_vals = refl_simA.select(
                            flex.bool(
                                [i_v in vals for i_v in np.arange(NreflA)]))
                        x, y, z = spot_utils.xyz_from_refl(refl_vals)
                        allI = [r['intensity.sum.value'] for r in refl_vals]
                        allI = sum(allI)
                        xm = np.mean(x)
                        ym = np.mean(y)
                        zm = np.mean(z)
                        drop_meA.extend(vals[1:])
                        x1b, x2b, y1b, y2b, z1b, z2b = zip(
                            *[r['bbox'] for r in refl_vals])
                        keep_me = vals[0]
                        # indexing order is important to modify as reference
                        refl_simA['intensity.sum.value'][keep_me] = allI
                        refl_simA['xyzobs.px.value'][keep_me] = (xm, ym, zm)
                        refl_simA['bbox'][keep_me] = (min(x1b), max(x2b),\
                                        min(y1b), max(y2b), min(z1b), max(z2b))
                    else:
                        drop_meA.append(vals)
                    print vals

            if drop_meA:
                keep_meA = np.array([i not in drop_meA for i in range(NreflA)])
                refl_simA = refl_simA.select(flex.bool(keep_meA))
                NreflA = len(refl_simA)

            drop_meB = []
            for i, vals in enumerate(nnB):
                if i in drop_meB:
                    continue
                if len(vals) > 1:
                    pids = [refl_simB[v]['panel'] for v in vals]
                    if len(set(pids)) == 1:
                        print vals
                        # merge_spots(vals)
                        refl_vals = refl_simB.select(
                            flex.bool(
                                [i_v in vals for i_v in np.arange(NreflB)]))
                        x, y, z = spot_utils.xyz_from_refl(refl_vals)
                        allI = [r['intensity.sum.value'] for r in refl_vals]
                        allI = sum(allI)
                        xm = np.mean(x)
                        ym = np.mean(y)
                        zm = np.mean(z)
                        drop_meB.extend(vals[1:])
                        x1b, x2b, y1b, y2b, z1b, z2b = zip(
                            *[r['bbox'] for r in refl_vals])
                        keep_me = vals[0]
                        refl_simB['intensity.sum.value'][keep_me] = allI
                        refl_simB['xyzobs.px.value'][keep_me] = (xm, ym, zm)
                        refl_simB['bbox'][keep_me] = (min(x1b), max(x2b), min(y1b),\
                                        max(y2b), min(z1b), max(z2b))
                    else:
                        drop_meB.append(vals)
                    print vals
            if drop_meB:
                keep_meB = [i not in drop_meB for i in range(NreflB)]
                refl_simB = refl_simB.select(flex.bool(keep_meB))
                NreflB = len(refl_simB)

##          remake the trees given the drops
            _, all_HiA = spot_utils.refls_to_hkl(refl_simA,
                                                 detector,
                                                 beamA,
                                                 crystal=crystalAB,
                                                 returnQ=False)
            all_treeA = cKDTree(all_HiA)

            _, all_HiB = spot_utils.refls_to_hkl(refl_simB,
                                                 detector,
                                                 beamB,
                                                 crystal=crystalAB,
                                                 returnQ=False)
            #all_treeB = cKDTree(all_HiB)

            ##          CHECK if same HKL, indexed by both colors
            #           exists on multiple panels, and if so, delete...
            nnAB = all_treeA.query_ball_point(all_HiB, r=1e-7)
            drop_meA = []
            drop_meB = []
            for iB, iA_vals in enumerate(nnAB):
                if len(iA_vals) > 0:
                    assert (len(iA_vals) == 1)
                    iA = iA_vals[0]
                    pidA = refl_simA[iA]['panel']
                    pidB = refl_simB[iB]['panel']
                    if pidA != pidB:
                        drop_meA.append(iA)
                        drop_meB.append(iB)

            if drop_meA:
                keep_meA = [i not in drop_meA for i in range(NreflA)]
                refl_simA = refl_simA.select(flex.bool(keep_meA))
            if drop_meB:
                keep_meB = [i not in drop_meB for i in range(NreflB)]
                refl_simB = refl_simB.select(flex.bool(keep_meB))

# ----  Done with edge case filters#

# reflections per panel
        rpp = spot_utils.refls_by_panelname(refl_data)
        rppA = spot_utils.refls_by_panelname(refl_simA)
        rppB = spot_utils.refls_by_panelname(refl_simB)

        DATA = {
            "D": [],
            "IA": [],
            "IB": [],
            "h2": [],
            "k2": [],
            "l2": [],
            "h": [],
            "k": [],
            "l": [],
            "PA": [],
            "PB": [],
            "FA": [],
            "FB": [],
            "iA": [],
            "iB": [],
            "Nstrong": [],
            "pid": [],
            "delta_pix": [],
            "deltaX": [],
            "deltaY": []
        }
        all_int_me = []

        # now set up boundboxes and integrate
        if tilt_plane_integration:
            mask = np.ones(simsDataSum.shape).astype(np.bool)
            print "Using tilt plane integration!"
        else:
            print "Not using tilt plane integration, just basic spot thresh integration "
        for pid in rpp:
            if tilt_plane_integration:
                Is, Ibk, noise, pix_per = \
                    integrate.integrate3(
                        rpp[pid],
                        mask[pid],
                        simsDataSum[pid],
                        gain=28) #nom_gain)

            R = rpp[pid]
            if pid in rppA:  # are there A-channel reflections on this panel
                inA = True
                RA = rppA[pid]
                xA, yA, _ = spot_utils.xyz_from_refl(RA)
                pointsA = np.array(zip(xA, yA))
                HA, HiA, QA = spot_utils.refls_to_hkl(RA,
                                                      detector,
                                                      beamA,
                                                      crystal=crystalAB,
                                                      returnQ=True)
            else:
                inA = False

            if pid in rppB:  # are there B channel reflections on this channel
                inB = True
                RB = rppB[pid]
                xB, yB, _ = spot_utils.xyz_from_refl(RB)
                pointsB = np.array(zip(xB, yB))
                HB, HiB, QB = spot_utils.refls_to_hkl(RB,
                                                      detector,
                                                      beamB,
                                                      crystal=crystalAB,
                                                      returnQ=True)
            else:
                inB = False

            x, y, _ = spot_utils.xyz_from_refl(R)
            x = np.array(x)
            y = np.array(y)

            panX, panY = detector[pid].get_image_size()

            mergesA = []
            mergesB = []
            if inA and inB:  # are there both A and B channel reflections ? If so, lets find out which ones have same hkl
                # make tree structure for merging the spots
                treeA = cKDTree(pointsA)
                treeB = cKDTree(pointsB)

                QA = geom_utils.res_on_panel(detector[pid], beamA)
                QAmag = np.linalg.norm(QA, axis=2) * 2 * np.pi
                detdist = detector[pid].get_distance()
                pixsize = detector[pid].get_pixel_size()[0]
                merge_me = []
                for p in pointsA:
                    iix, iiy = int(p[0]), int(p[1])
                    q = QAmag[iiy, iix]
                    radA = detdist * np.tan(
                        2 * np.arcsin(q * waveA / 4 / np.pi)) / pixsize
                    radB = detdist * np.tan(
                        2 * np.arcsin(q * waveB / 4 / np.pi)) / pixsize
                    rmax = np.abs(radA - radB)
                    split_spot_pairs = treeB.query_ball_point(x=p, r=rmax + sz)
                    merge_me.append(split_spot_pairs)

                #rmax = geom_utils.twocolor_deltapix(detector[pid], beamA, beamB)
                #merge_me = treeA.query_ball_tree(treeB, r=rmax + sz)

                for iA, iB in enumerate(merge_me):
                    if not iB:
                        continue
                    iB = iB[0]

                    # check that the miller indices are the same
                    if not all([i == j for i, j in zip(HiA[iA], HiB[iB])]):
                        continue
                    x1A, x2A, y1A, y2A, _, _ = RA[iA]['bbox']  # shoebox'].bbox
                    x1B, x2B, y1B, y2B, _, _ = RB[iB]['bbox']  # shoebox'].bbox

                    xlow = max([0, min((x1A, x1B)) - sz])
                    xhigh = min([panX, max((x2A, x2B)) + sz])
                    ylow = max([0, min((y1A, y1B)) - sz])
                    yhigh = min([panY, max((y2A, y2B)) + sz])

                    #if iA==79:
                    #    embed()
                    # integrate me if I am in the bounding box!
                    int_me = np.where((xlow < x) & (x < xhigh) & (ylow < y)
                                      & (y < yhigh))[0]
                    if not int_me.size:
                        continue
                    mergesA.append(iA)
                    mergesB.append(iB)

                    # integrate the spot, this will change depending on data or simulation
                    totalI = 0
                    totalCOM = 0
                    for ref_idx in int_me:
                        if tilt_plane_integration:
                            totalI += Is[ref_idx]
                        else:
                            totalI += rpp[pid][ref_idx]["intensity.sum.value"]
                        totalCOM += np.array(
                            rpp[pid][ref_idx]["xyzobs.px.value"])
                    totalCOM /= len(int_me)

                    PA = RA[iA]['intensity.sum.value']
                    PB = RB[iB]['intensity.sum.value']

                    # get the hkl structure factor, and the sym equiv hkl
                    (h, k, l) = HiA[iA]  # NOTE: same for A and B channels
                    (h2, k2, l2), FA = get_val_at_hkl((h, k, l), HA_val_map)

                    _, FB = get_val_at_hkl(
                        (h, k, l),
                        HB_val_map)  # NOTE: no need to return h2,k2,l2 twice
                    #if FB==-1 or FA==-1:
                    #    continue

                    DATA['h'].append(h)
                    DATA['k'].append(k)
                    DATA['l'].append(l)
                    DATA['h2'].append(h2)
                    DATA['k2'].append(k2)
                    DATA['l2'].append(l2)
                    DATA['D'].append(totalI)
                    DATA['PA'].append(PA)
                    DATA['PB'].append(PB)
                    DATA['FA'].append(FA)
                    DATA['FB'].append(FB)
                    DATA['IA'].append(abs(FA)**2)
                    DATA['IB'].append(abs(FB)**2)

                    DATA['pid'].append(pid)
                    DATA["Nstrong"].append(int_me.size)
                    DATA["iA"].append(iA)
                    DATA["iB"].append(iB)
                    all_int_me.append(int_me)

                    # NOTE: stash the sim-data distance (COM to COM)
                    posA = RA[iA]['xyzobs.px.value']
                    posB = RB[iB]['xyzobs.px.value']
                    simCOM = np.mean([posA, posB], axis=0)
                    DATA["delta_pix"].append(
                        distance.euclidean(totalCOM[:2], simCOM[:2]))
                    DATA["deltaX"].append(totalCOM[0] - simCOM[0])
                    DATA["deltaY"].append(totalCOM[1] - simCOM[1])

            if inA:
                for iA, ref in enumerate(RA):
                    if iA in mergesA:
                        continue
                    x1A, x2A, y1A, y2A, _, _ = RA[iA][
                        'bbox']  # ['shoebox'].bbox
                    xlow = max((0, x1A - sz))
                    xhigh = min((panX, x2A + sz))
                    ylow = max((0, y1A - sz))
                    yhigh = min((panY, y2A + sz))
                    int_me = np.where((xlow < x) & (x < xhigh) & (ylow < y)
                                      & (y < yhigh))[0]
                    if not int_me.size:
                        continue

                    totalI = 0
                    totalCOM = 0
                    for ref_idx in int_me:
                        if tilt_plane_integration:
                            totalI += Is[ref_idx]
                        else:
                            totalI += rpp[pid][ref_idx]["intensity.sum.value"]
                        totalCOM += np.array(
                            rpp[pid][ref_idx]["xyzobs.px.value"])
                    totalCOM /= len(int_me)

                    PA = RA[iA]['intensity.sum.value']
                    PB = 0  # crucial ;)

                    # get the hkl structure factor, and the sym equiv hkl
                    (h, k, l) = HiA[iA]  # NOTE: same for A and B channels
                    (h2, k2, l2), FA = get_val_at_hkl((h, k, l), HA_val_map)
                    _, FB = get_val_at_hkl(
                        (h, k, l),
                        HB_val_map)  # NOTE: no need to return h2,k2,l2 twice
                    #if FA==-1 or FB==-1:
                    #    continue
                    DATA['h'].append(h)
                    DATA['k'].append(k)
                    DATA['l'].append(l)
                    DATA['h2'].append(h2)
                    DATA['k2'].append(k2)
                    DATA['l2'].append(l2)
                    DATA['D'].append(totalI)
                    DATA['PA'].append(PA)
                    DATA['PB'].append(PB)
                    DATA['FA'].append(FA)
                    DATA['FB'].append(FB)
                    DATA['IA'].append(abs(FA)**2)
                    DATA['IB'].append(abs(FB)**2)

                    DATA['pid'].append(pid)
                    DATA["Nstrong"].append(int_me.size)
                    DATA["iA"].append(iA)
                    DATA["iB"].append(np.nan)
                    all_int_me.append(int_me)

                    # NOTE: stash the sim-data distance (COM to COM)
                    simCOM = np.array(RA[iA]['xyzobs.px.value'])
                    DATA["delta_pix"].append(
                        distance.euclidean(totalCOM[:2], simCOM[:2]))
                    DATA["deltaX"].append(totalCOM[0] - simCOM[0])
                    DATA["deltaY"].append(totalCOM[1] - simCOM[1])

            if inB:
                for iB, ref in enumerate(RB):
                    if iB in mergesB:
                        continue
                    x1B, x2B, y1B, y2B, _, _ = RB[iB]['bbox']  # shoebox'].bbox
                    xlow = max((0, x1B - sz))
                    xhigh = min((panX, x2B + sz))
                    ylow = max((0, y1B - sz))
                    yhigh = min((panY, y2B + sz))
                    # subimg = simsDataSum[pid][ylow:yhigh, xlow:xhigh]
                    # bg = 0
                    int_me = np.where((xlow < x) & (x < xhigh) & (ylow < y)
                                      & (y < yhigh))[0]
                    if not int_me.size:
                        continue

                    totalI = 0
                    totalCOM = 0
                    for ref_idx in int_me:
                        if tilt_plane_integration:
                            totalI += Is[ref_idx]
                        else:
                            totalI += rpp[pid][ref_idx]["intensity.sum.value"]
                        totalCOM += np.array(
                            rpp[pid][ref_idx]["xyzobs.px.value"])
                    totalCOM /= len(int_me)

                    PA = 0  # crucial ;)
                    PB = RB[iB]['intensity.sum.value']

                    # get the hkl structure factor, and the sym equiv hkl
                    (h, k, l) = HiB[iB]  # NOTE: same for A and B channels
                    (h2, k2, l2), FB = get_val_at_hkl((h, k, l), HB_val_map)
                    _, FA = get_val_at_hkl(
                        (h, k, l),
                        HA_val_map)  # NOTE: no need to return h2,k2,l2 twice
                    #if FA==-1 or FB==-1:
                    #    continue
                    DATA['h'].append(h)
                    DATA['k'].append(k)
                    DATA['l'].append(l)
                    DATA['h2'].append(h2)
                    DATA['k2'].append(k2)
                    DATA['l2'].append(l2)
                    DATA['D'].append(totalI)
                    DATA['PA'].append(PA)
                    DATA['PB'].append(PB)
                    DATA['FA'].append(FA)
                    DATA['FB'].append(FB)
                    DATA['IA'].append(abs(FA)**2)
                    DATA['IB'].append(abs(FB)**2)

                    DATA['pid'].append(pid)
                    DATA["Nstrong"].append(int_me.size)
                    DATA["iA"].append(np.nan)
                    DATA["iB"].append(iB)
                    all_int_me.append(int_me)
                    # NOTE: stash the sim-data distance (COM to COM)
                    simCOM = np.array(RB[iB]['xyzobs.px.value'])
                    DATA["delta_pix"].append(
                        distance.euclidean(totalCOM[:2], simCOM[:2]))
                    DATA["deltaX"].append(totalCOM[0] - simCOM[0])
                    DATA["deltaY"].append(totalCOM[1] - simCOM[1])

        df = pandas.DataFrame(DATA)
        df["run"] = run
        df["shot_idx"] = shot_idx
        df['gain'] = GAIN

        if use_data_spec:
            print "Setting LA, LB as sums over flux regions A,B"
            df['LA'] = data_fluxes[:75].sum()
            df['LB'] = data_fluxes[75:].sum()
        else:
            print "Setting LA LB as data_fluxes"
            df['LA'] = data_fluxes[0]
            df["LB"] = data_fluxes[1]

        df['K'] = FF[0]**2 * FLUX[0]
        df["rhs"] = df.gain * (df.IA * df.LA * (df.PA / df.K) + df.IB * df.LB *
                               (df.PB / df.K))
        df["lhs"] = df.D
        #df['data_name'] = data_name
        df['init_rot'] = init_rot
        df.to_pickle(pklname)

        print("PLOT")
        if args.plot:
            import pylab as plt
            plt.plot(df.lhs, df.rhs, '.')
            plt.show()
        print("DonDonee")
def main(jid):

    import sys
    import glob
    import os
    from copy import deepcopy

    import numpy as np
    import pandas
    from scipy.spatial import cKDTree, distance
    from IPython import embed
    import h5py

    from cctbx import miller, sgtbx
    from cxid9114 import utils
    from dials.array_family import flex
    import dxtbx
    from cxid9114 import utils
    from cxid9114.geom import geom_utils
    from cxid9114.spots import integrate, spot_utils
    from cxid9114 import parameters
    from cxid9114.sim import sim_utils
    from cxid9114.solvers import setup_inputs
    from cxid9114.refine import metrics

    from LS49.sim.step4_pad import microcrystal

    assert (iglob is not None)
    spec_f = h5py.File("simMe_data_run62.h5", "r")
    spec_data = spec_f["hist_spec"][()]
    sg96 = sgtbx.space_group(" P 4nw 2abw")
    ofile = "%s_liftoff_betelgeuse%d.%d.pdpkl" % (tag, jid + 1, Njobs)
    ofile = os.path.join(odir, ofile)
    print(ofile)

    file_list = glob.glob(iglob)
    Nfiles = len(file_list)

    ENERGIES = [parameters.ENERGY_LOW, parameters.ENERGY_HIGH]
    FF = [1e4, None]
    FLUX = [1e12, 1e12]
    beamsize_mm = 0.001
    Deff_A = 2200
    length_um = 2.2
    detector = utils.open_flex("bigsim_detect.pkl")
    beam = utils.open_flex("bigsim_beam.pkl")

    beamA = deepcopy(beam)
    beamB = deepcopy(beam)
    waveA = parameters.ENERGY_CONV / ENERGIES[0]
    waveB = parameters.ENERGY_CONV / ENERGIES[1]
    beamA.set_wavelength(waveA)
    beamB.set_wavelength(waveB)

    file_list_idx = np.array_split(np.arange(Nfiles), Njobs)

    crystal = microcrystal(Deff_A=Deff_A,
                           length_um=length_um,
                           beam_diameter_um=beamsize_mm * 1000,
                           verbose=False)

    all_dfs = []

    idxstart = file_list_idx[jid][0]
    Nfiles = len(file_list_idx[jid])
    if max_files is not None:
        Nfiles = min(max_files, Nfiles)

    for idx in range(idxstart, idxstart + Nfiles):

        if jid == 0 and idx % smi_stride == 0:
            print("GPU status")
            os.system("nvidia-smi")

            print("\n\n")
            print("CPU memory usage")
            mem_usg = """ps -U dermen --no-headers -o rss | awk '{ sum+=$1} END {print int(sum/1024) "MB consumed by CPU user"}'"""
            os.system(mem_usg)

        data_name = file_list[idx]
        data = utils.open_flex(data_name)
        shot_idx = int(data["img_f"].split("_")[-1].split(".")[0])

        print "Data file %s" % data_name

        shot_idx = int(shot_idx)

        shot_spectrum = spec_data[shot_idx]

        chanA_flux = shot_spectrum[10:25].sum()
        chanB_flux = shot_spectrum[100:115].sum()

        crystalAB = data["crystalAB"]

        print "Doing the basic simulation.."
        simsAB = sim_utils.sim_twocolors2(
            crystalAB,
            detector,
            beam,
            FF, [parameters.ENERGY_LOW, parameters.ENERGY_HIGH],
            FLUX,
            Gauss=True,
            oversample=0,
            Ncells_abc=(25, 25, 25),
            mos_dom=1000,
            mos_spread=0.015,
            cuda=cuda,
            device_Id=jid,
            beamsize_mm=beamsize_mm,
            boost=crystal.domains_per_crystal,
            exposure_s=1)

        print "Done!"

        refl_data = data["refls_strong"]

        #
        print "\n\n\n#######\nProcessing %d reflections read from the data file \n#####\n\n" % len(
            refl_data)

        refl_simA = spot_utils.refls_from_sims(simsAB[0],
                                               detector,
                                               beamA,
                                               thresh=thresh)
        refl_simB = spot_utils.refls_from_sims(simsAB[1],
                                               detector,
                                               beamB,
                                               thresh=thresh)

        residA = metrics.check_indexable2(refl_data, refl_simA, detector,
                                          beamA, crystalAB, hkl_tol)
        residB = metrics.check_indexable2(refl_data, refl_simB, detector,
                                          beamB, crystalAB, hkl_tol)

        print "Initial metrics suggest that:"
        print "\t %d reflections could be indexed by channeL A" % residA[
            'indexed'].sum()
        print "\t %d reflections could be indexed by channeL B" % residB[
            'indexed'].sum()
        print "\t NOw we can check for outliers.. "

        if plot_overlap:
            spot_utils.plot_overlap(refl_simA, refl_simB, refl_data, detector)

        d = {
            "crystalAB": crystalAB,
            "residA": residA,
            "residB": residB,
            "beamA": beamA,
            "beamB": beamB,
            "detector": detector,
            "refls_simA": refl_simA,
            "refls_simB": refl_simB,
            "refls_data": refl_data
        }

        # integrate with tilt plane subtraction
        print("LOADING THE FINE IMAGE")
        loader = dxtbx.load(data["img_f"])
        pan_data = np.array([loader.get_raw_data().as_numpy_array()])
        print(data["img_f"])

        # make a dummy mask
        mask = np.ones_like(pan_data).astype(np.bool)

        # before processing we need to check edge cases
        print "Checking the simulations edge cases, basically to do with the spot detection of simulations... \n\t such a pain.. "

        filt = True
        if filt:
            _, all_HiA, _ = spot_utils.refls_to_hkl(refl_simA,
                                                    detector,
                                                    beamA,
                                                    crystal=crystalAB,
                                                    returnQ=True)
            all_treeA = cKDTree(all_HiA)
            nnA = all_treeA.query_ball_point(all_HiA, r=1e-7)

            _, all_HiB, _ = spot_utils.refls_to_hkl(refl_simB,
                                                    detector,
                                                    beamB,
                                                    crystal=crystalAB,
                                                    returnQ=True)
            all_treeB = cKDTree(all_HiB)
            nnB = all_treeB.query_ball_point(all_HiB, r=1e-7)

            NreflA = len(refl_simA)
            NreflB = len(refl_simB)

            drop_meA = []
            for i, vals in enumerate(nnA):
                if i in drop_meA:
                    continue
                if len(vals) > 1:
                    pids = [refl_simA[v]['panel'] for v in vals]
                    if len(set(pids)) == 1:
                        refl_vals = refl_simA.select(
                            flex.bool(
                                [i_v in vals for i_v in np.arange(NreflA)]))
                        x, y, z = spot_utils.xyz_from_refl(refl_vals)
                        allI = [r['intensity.sum.value'] for r in refl_vals]
                        allI = sum(allI)
                        xm = np.mean(x)
                        ym = np.mean(y)
                        zm = np.mean(z)
                        drop_meA.extend(vals[1:])
                        x1b, x2b, y1b, y2b, z1b, z2b = zip(
                            *[r['bbox'] for r in refl_vals])
                        keep_me = vals[0]
                        # NOTE: index the column first, then the row, so the
                        # assignment modifies the reflection table in place
                        refl_simA['intensity.sum.value'][keep_me] = allI
                        refl_simA['xyzobs.px.value'][keep_me] = (xm, ym, zm)
                        refl_simA['bbox'][keep_me] = (min(x1b), max(x2b),\
                                        min(y1b), max(y2b), min(z1b), max(z2b))
                    else:
                        drop_meA.append(vals)
                    print vals

            if drop_meA:
                keep_meA = np.array([i not in drop_meA for i in range(NreflA)])
                refl_simA = refl_simA.select(flex.bool(keep_meA))
                NreflA = len(refl_simA)

            drop_meB = []
            for i, vals in enumerate(nnB):
                if i in drop_meB:
                    continue
                if len(vals) > 1:
                    pids = [refl_simB[v]['panel'] for v in vals]
                    if len(set(pids)) == 1:
                        print vals
                        # merge_spots(vals)
                        refl_vals = refl_simB.select(
                            flex.bool(
                                [i_v in vals for i_v in np.arange(NreflB)]))
                        x, y, z = spot_utils.xyz_from_refl(refl_vals)
                        allI = [r['intensity.sum.value'] for r in refl_vals]
                        allI = sum(allI)
                        xm = np.mean(x)
                        ym = np.mean(y)
                        zm = np.mean(z)
                        drop_meB.extend(vals[1:])
                        x1b, x2b, y1b, y2b, z1b, z2b = zip(
                            *[r['bbox'] for r in refl_vals])
                        keep_me = vals[0]
                        refl_simB['intensity.sum.value'][keep_me] = allI
                        refl_simB['xyzobs.px.value'][keep_me] = (xm, ym, zm)
                        refl_simB['bbox'][keep_me] = (min(x1b), max(x2b), min(y1b),\
                                        max(y2b), min(z1b), max(z2b))
                    else:
                        drop_meB.append(vals)
                    print vals
            if drop_meB:
                keep_meB = [i not in drop_meB for i in range(NreflB)]
                refl_simB = refl_simB.select(flex.bool(keep_meB))
                NreflB = len(refl_simB)

            ##  remake the trees given the drops
            _, all_HiA = spot_utils.refls_to_hkl(refl_simA,
                                                 detector,
                                                 beamA,
                                                 crystal=crystalAB,
                                                 returnQ=False)
            all_treeA = cKDTree(all_HiA)

            _, all_HiB = spot_utils.refls_to_hkl(refl_simB,
                                                 detector,
                                                 beamB,
                                                 crystal=crystalAB,
                                                 returnQ=False)
            #all_treeB = cKDTree(all_HiB)

            ##  CHECK if same HKL, indexed by both colors
            #   exists on multiple panels, and if so, delete...
            nnAB = all_treeA.query_ball_point(all_HiB, r=1e-7)
            drop_meA = []
            drop_meB = []
            for iB, iA_vals in enumerate(nnAB):
                if len(iA_vals) > 0:
                    assert (len(iA_vals) == 1)
                    iA = iA_vals[0]
                    pidA = refl_simA[iA]['panel']
                    pidB = refl_simB[iB]['panel']
                    if pidA != pidB:
                        drop_meA.append(iA)
                        drop_meB.append(iB)

            if drop_meA:
                keep_meA = [i not in drop_meA for i in range(NreflA)]
                refl_simA = refl_simA.select(flex.bool(keep_meA))
            if drop_meB:
                keep_meB = [i not in drop_meB for i in range(NreflB)]
                refl_simB = refl_simB.select(flex.bool(keep_meB))

        # ----  Done with edge case filters#
        print "<><><><>\nI am doing checking the simulations for edge cases!\n<><><><>"

        # reflections per panel
        rpp = spot_utils.refls_by_panelname(refl_data)
        rppA = spot_utils.refls_by_panelname(refl_simA)
        rppB = spot_utils.refls_by_panelname(refl_simB)

        DATA = {
            "D": [],
            "Dnoise": [],
            "h": [],
            "k": [],
            "l": [],
            "is_pos": [],
            "hAnom": [],
            "kAnom": [],
            "lAnom": [],
            "horig": [],
            "korig": [],
            "lorig": [],
            "PA": [],
            "PB": [],
            "iA": [],
            "iB": [],
            "Nstrong": [],
            "pid": [],
            "delta_pix": []
        }  # NOTE: added in the delta pix
        # for comparing sim and data centers of mass

        all_int_me = []
        sz_fudge = sz = 5  # integration fudge factor to include spots that dont overlap perfectly with predictions
        # double define for convenience cause sz is easier to type than sz_fudge

        #  now set up boundboxes and integrate
        for idx_pid, pid in enumerate(rpp):
            # NOTE: integrate the spots for this panel
            Is, Ibk, noise, pix_per = integrate.integrate3(rpp[pid],
                                                           mask[pid],
                                                           pan_data[pid],
                                                           gain=nom_gain)

            print "Processing peaks on CSPAD panel %d (%d / %d)" % (
                pid, idx_pid, len(rpp))
            R = rpp[pid]
            if pid in rppA:  # are there A-channel reflections on this panel
                inA = True
                RA = rppA[pid]
                xA, yA, _ = spot_utils.xyz_from_refl(RA)
                pointsA = np.array(zip(xA, yA))
                HA, HiA, QA = spot_utils.refls_to_hkl(RA,
                                                      detector,
                                                      beamA,
                                                      crystal=crystalAB,
                                                      returnQ=True)
            else:
                inA = False

            if pid in rppB:  # are there B channel reflections on this channel
                inB = True
                RB = rppB[pid]
                xB, yB, _ = spot_utils.xyz_from_refl(RB)
                pointsB = np.array(zip(xB, yB))
                HB, HiB, QB = spot_utils.refls_to_hkl(RB,
                                                      detector,
                                                      beamB,
                                                      crystal=crystalAB,
                                                      returnQ=True)
            else:
                inB = False

            x, y, _ = spot_utils.xyz_from_refl(R)
            x = np.array(x)
            y = np.array(y)

            panX, panY = detector[pid].get_image_size()

            mergesA = []
            mergesB = []
            if inA and inB:  # are there both A and B channel reflections ? If so, lets find out which ones have same hkl
                # make tree structure for merging the spots
                treeA = cKDTree(pointsA)
                treeB = cKDTree(pointsB)
                # how far apart should the two color spots be ?
                # NOTE: this is the critical step - are the spots within rmax - and if so they are considered indexed..
                rmax = geom_utils.twocolor_deltapix(detector[pid], beamA,
                                                    beamB)
                merge_me = treeA.query_ball_tree(
                    treeB, r=rmax + sz_fudge)  # slap on some fudge
                # if pixel points in treeA are within rmax + sz_fudge of
                # points in treeB, then these points are assumed to be overlapped
                for iA, iB in enumerate(merge_me):
                    if not iB:
                        continue
                    iB = iB[0]

                    # check that the miller indices are the same
                    if not all([i == j for i, j in zip(HiA[iA], HiB[iB])]):
                        continue
                    x1A, x2A, y1A, y2A, _, _ = RA[iA]['bbox']  # shoebox'].bbox
                    x1B, x2B, y1B, y2B, _, _ = RB[iB]['bbox']  # shoebox'].bbox

                    xlow = max([0, min((x1A, x1B)) - sz])
                    xhigh = min([panX, max((x2A, x2B)) + sz])
                    ylow = max([0, min((y1A, y1B)) - sz])
                    yhigh = min([panY, max((y2A, y2B)) + sz])

                    # integrate me if I am in the bounding box!
                    int_me = np.where((xlow < x) & (x < xhigh) & (ylow < y)
                                      & (y < yhigh))[0]
                    if not int_me.size:
                        continue
                    mergesA.append(iA)
                    mergesB.append(iB)

                    # integrate the spot; this will change depending on data or simulation
                    # NOTE: adding in the data-spot center of mass here as well
                    totalCOM = np.zeros(3)  # NOTE: x,y,z
                    totalI = 0
                    totalNoise = 0
                    for ref_idx in int_me:
                        # TODO implement the spot intensity version here
                        # which fits the background plane!
                        totalI += Is[
                            ref_idx]  #rpp[pid][ref_idx]["intensity.sum.value"]
                        totalNoise += noise[ref_idx]**2
                        totalCOM += np.array(
                            rpp[pid][ref_idx]["xyzobs.px.value"])
                    totalCOM /= len(int_me)
                    totalNoise = np.sqrt(totalNoise)

                    PA = RA[iA]['intensity.sum.value']
                    PB = RB[iB]['intensity.sum.value']

                    # NOTE: added the simulated spot(s) center of mass
                    posA = RA[iA]['xyzobs.px.value']
                    posB = RB[iB]['xyzobs.px.value']
                    simCOM = np.mean([posA, posB], axis=0)

                    # get the hkl structure factor, and the sym equiv hkl
                    (horig, korig,
                     lorig) = HiA[iA]  # NOTE: same for A and B channels
                    h, k, l = setup_inputs.single_to_asu((horig, korig, lorig),
                                                         ano=False)
                    hAnom, kAnom, lAnom = setup_inputs.single_to_asu(
                        (horig, korig, lorig), ano=True)
                    if h == hAnom and k == kAnom and l == lAnom:
                        is_pos = True
                    else:
                        is_pos = False
                    DATA['is_pos'].append(is_pos)
                    DATA['horig'].append(horig)
                    DATA['korig'].append(korig)
                    DATA['lorig'].append(lorig)
                    DATA['h'].append(h)
                    DATA['k'].append(k)
                    DATA['l'].append(l)
                    DATA['hAnom'].append(hAnom)
                    DATA['kAnom'].append(kAnom)
                    DATA['lAnom'].append(lAnom)

                    DATA['D'].append(totalI)
                    DATA['Dnoise'].append(totalNoise)
                    DATA['PA'].append(PA)
                    DATA['PB'].append(PB)

                    DATA['pid'].append(pid)
                    DATA["Nstrong"].append(int_me.size)
                    DATA["iA"].append(iA)
                    DATA["iB"].append(iB)
                    all_int_me.append(int_me)

                    # NOTE: stash the sim-data distance (COM to COM)
                    DATA["delta_pix"].append(
                        distance.euclidean(totalCOM[:2], simCOM[:2]))
                    # this spot was both colors, overlapping
                    # find center of mass of all spots inside the integration box
                    # and find its distance to the center of mass of the simulation spots

            if inA:
                for iA, ref in enumerate(RA):
                    if iA in mergesA:
                        # this sim spot was already treated above
                        continue
                    x1A, x2A, y1A, y2A, _, _ = RA[iA][
                        'bbox']  # ['shoebox'].bbox
                    xlow = max((0, x1A - sz))
                    xhigh = min((panX, x2A + sz))
                    ylow = max((0, y1A - sz))
                    yhigh = min((panY, y2A + sz))
                    int_me = np.where((xlow < x) & (x < xhigh) & (ylow < y)
                                      & (y < yhigh))[0]
                    if not int_me.size:
                        continue

                    # NOTE: added in the total COM calc
                    totalCOM = np.zeros(3)
                    totalI = 0
                    totalNoise = 0
                    for ref_idx in int_me:
                        # TODO implement the spot intensity version here
                        # which fits the background plane!
                        totalI += Is[
                            ref_idx]  #rpp[pid][ref_idx]["intensity.sum.value"]
                        totalNoise += noise[ref_idx]**2
                        totalCOM += np.array(
                            rpp[pid][ref_idx]["xyzobs.px.value"])
                    totalCOM /= len(int_me)
                    totalNoise = np.sqrt(totalNoise)
                    PA = RA[iA]['intensity.sum.value']
                    PB = 0  # crucial ;)

                    # NOTE: added the simulated spot center of mass, for spotA
                    simCOM = np.array(RA[iA]['xyzobs.px.value'])

                    # get the hkl structure factor, and the sym equiv hkl
                    (horig, korig, lorig) = HiA[iA]
                    h, k, l = setup_inputs.single_to_asu((horig, korig, lorig),
                                                         ano=False)
                    hAnom, kAnom, lAnom = setup_inputs.single_to_asu(
                        (horig, korig, lorig), ano=True)
                    if h == hAnom and k == kAnom and l == lAnom:
                        is_pos = True
                    else:
                        is_pos = False
                    DATA['is_pos'].append(is_pos)
                    DATA['horig'].append(horig)
                    DATA['korig'].append(korig)
                    DATA['lorig'].append(lorig)
                    DATA['h'].append(h)
                    DATA['k'].append(k)
                    DATA['l'].append(l)
                    DATA['hAnom'].append(hAnom)
                    DATA['kAnom'].append(kAnom)
                    DATA['lAnom'].append(lAnom)

                    DATA['D'].append(totalI)
                    DATA['Dnoise'].append(totalNoise)
                    DATA['PA'].append(PA)
                    DATA['PB'].append(PB)

                    DATA['pid'].append(pid)
                    DATA["Nstrong"].append(int_me.size)
                    DATA["iA"].append(iA)
                    DATA["iB"].append(np.nan)
                    all_int_me.append(int_me)

                    # NOTE: stash the sim-data distance (COM to COM)
                    DATA["delta_pix"].append(
                        distance.euclidean(totalCOM[:2], simCOM[:2]))

            if inB:
                for iB, ref in enumerate(RB):
                    if iB in mergesB:
                        continue
                    x1B, x2B, y1B, y2B, _, _ = RB[iB]['bbox']  # shoebox'].bbox
                    xlow = max((0, x1B - sz))
                    xhigh = min((panX, x2B + sz))
                    ylow = max((0, y1B - sz))
                    yhigh = min((panY, y2B + sz))
                    # subimg = simsDataSum[pid][ylow:yhigh, xlow:xhigh]
                    # bg = 0
                    int_me = np.where((xlow < x) & (x < xhigh) & (ylow < y)
                                      & (y < yhigh))[0]
                    if not int_me.size:
                        continue

                    # NOTE: added in the total COM calc
                    totalCOM = np.zeros(3)
                    totalI = 0
                    totalNoise = 0
                    for ref_idx in int_me:
                        # TODO implement the spot intensity version here
                        # which fits the background plane!
                        totalI += Is[
                            ref_idx]  #rpp[pid][ref_idx]["intensity.sum.value"]
                        totalNoise += noise[ref_idx]**2
                        totalCOM += np.array(
                            rpp[pid][ref_idx]["xyzobs.px.value"])
                    totalCOM /= len(int_me)
                    totalNoise = np.sqrt(totalNoise)

                    PA = 0  # crucial ;)
                    PB = RB[iB]['intensity.sum.value']

                    # NOTE: added the simulated spot center of mass, for spotB only
                    simCOM = np.array(RB[iB]['xyzobs.px.value'])

                    # get the hkl structure factor, and the sym equiv hkl
                    (horig, korig, lorig) = HiB[iB]
                    h, k, l = setup_inputs.single_to_asu((horig, korig, lorig),
                                                         ano=False)
                    hAnom, kAnom, lAnom = setup_inputs.single_to_asu(
                        (horig, korig, lorig), ano=True)
                    if h == hAnom and k == kAnom and l == lAnom:
                        is_pos = True
                    else:
                        is_pos = False

                    DATA['is_pos'].append(is_pos)
                    DATA['horig'].append(horig)
                    DATA['korig'].append(korig)
                    DATA['lorig'].append(lorig)
                    DATA['h'].append(h)
                    DATA['k'].append(k)
                    DATA['l'].append(l)
                    DATA['hAnom'].append(hAnom)
                    DATA['kAnom'].append(kAnom)
                    DATA['lAnom'].append(lAnom)

                    DATA['D'].append(totalI)
                    DATA['Dnoise'].append(totalNoise)
                    DATA['PA'].append(PA)
                    DATA['PB'].append(PB)

                    DATA['pid'].append(pid)
                    DATA["Nstrong"].append(int_me.size)
                    DATA["iA"].append(np.nan)
                    DATA["iB"].append(iB)
                    all_int_me.append(int_me)

                    # NOTE: stash the sim-data distance (COM to COM)
                    DATA["delta_pix"].append(
                        distance.euclidean(totalCOM[:2], simCOM[:2]))

        df = pandas.DataFrame(DATA)
        df["run"] = run
        df["shot_idx"] = shot_idx
        df['LA'] = chanA_flux
        df["LB"] = chanB_flux
        df['K'] = FF[0]**2 * FLUX[0]
        df['nominal_gain'] = nom_gain
        all_dfs.append(df)
        print("Saved %d partial structure factor measurements in file %s" %
              (len(df), ofile))

    DF = pandas.concat(all_dfs)
    DF.to_pickle(ofile)