Exemplo n.º 1
0
from simtbx.nanoBragg import sim_data
from scitbx.matrix import sqr, rec
from cctbx import uctbx
from dxtbx.model import Crystal

# Script section: build a rotated crystal model and a small multi-wavelength
# beam, then instantiate a diffBragg simulator on a simple detector.
# NOTE(review): `Rotation` and `np` are not imported in the visible chunk --
# presumably scipy.spatial.transform.Rotation and numpy; confirm at file top.

# six unit-cell parameters handed to uctbx.unit_cell, and the space-group symbol
ucell = (70, 60, 50, 90.0, 110, 90.0)
symbol = "C121"

# rows of the transposed orthogonalization matrix serve as the real-space basis vectors
a_real, b_real, c_real = sqr(uctbx.unit_cell(ucell).orthogonalization_matrix()).transpose().as_list_of_lists()
C = Crystal(a_real, b_real, c_real, symbol)

# random rotation (drawn with a fixed random_state of 101)
rotation = Rotation.random(num=1, random_state=101)[0]
# pack the quaternion into a 4x1 matrix so it can be converted to angle + axis
Q = rec(rotation.as_quat(), n=(4, 1))
rot_ang, rot_axis = Q.unit_quaternion_as_axis_and_angle()
C.rotate_around_origin(rot_axis, rot_ang)

# start from the default SimData crystal, then swap in the rotated dxtbx crystal
S = sim_data.SimData(use_default_crystal=True)
S.crystal.dxtbx_crystal = C
spectrum = S.beam.spectrum
# first (wavelength, flux) entry of the default beam spectrum
wave, flux = spectrum[0]
Nwave = 5
# Nwave evenly spaced wavelengths spanning +/-0.2% around the default wavelength
waves = np.linspace(wave-wave*0.002, wave+wave*0.002, Nwave)
# split the original flux equally among the Nwave channels
fluxes = np.ones(Nwave) * flux / Nwave

# NOTE(review): presumably ground-truth ("GT") wavelength coefficients; they
# are not used within the visible chunk -- confirm against the rest of the file
lambda0_GT = 0
lambda1_GT = 1

# replace the beam spectrum with the (wavelength, flux) pairs built above
S.beam.spectrum = list(zip(waves, fluxes))
# NOTE(review): arguments look like (distance, pixel size, image shape) -- confirm units
S.detector = sim_data.SimData.simple_detector(180, 0.1, (1024, 1024))
S.instantiate_diffBragg(verbose=0, oversample=0, auto_set_spotscale=True)
Exemplo n.º 2
0
    def load(self):
        """Divide the shots among the ranks and refine every shot owned by this rank.

        Rank 0 counts the shots in each file of ``self.fnames`` (length of the
        ``h5_path`` dataset), builds ``(file index, shot index)`` pairs and
        splits them into ``size`` groups; when ``has_mpi`` is set the groups
        are broadcast from root.  Each rank then walks over its own shots
        (at most ``self.Nload`` of them when that is not None, and stopping
        once ``args.Nmax`` images were visited) and, per shot:

        * loads the image and beam from the ``.npz`` file,
        * builds a crystal model from the stored A-matrix,
        * keeps the bounding boxes flagged in the ``keepers`` dataset,
        * reads the ground-truth crystal / spectrum from the simulation dump,
        * sets up a diffBragg ``SimData`` and runs ``RefineAllMultiPanel``,
        * prints refined versus initial values compared with ground truth.

        Reads ``self.fnames``, ``self.Nload``, ``self.flux_min`` and
        ``self.gain``; resets ``self.all_bbox_pixels`` to an empty list.
        Relies on names defined outside this method (``rank``, ``size``,
        ``comm``, ``has_mpi``, ``args``, ``wavelens``, ``Fhkl_guess``,
        ``CSPAD``, ...).  Returns None.
        """
        import os  # only used for the args.testmode path handling below

        # the same space group is used for every crystal model in this method
        space_group = "P43212"

        # NOTE: for reference, inside each h5 file there is
        #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']

        # get the total number of shots using worker 0
        if rank == 0:
            print("I am root. I am calculating total number of shots")
            # open one file at a time and always close it, so that counting
            # never holds every file open at once nor leaks a handle on error
            Nshots_per_file = []
            for fname in self.fnames:
                h = h5py_File(fname, "r")
                try:
                    Nshots_per_file.append(h["h5_path"].shape[0])
                finally:
                    h.close()
            Nshots_tot = sum(Nshots_per_file)
            print("I am root. Total number of shots is %d" % Nshots_tot)

            print("I am root. I will divide shots amongst workers.")
            shot_tuples = [(i_f, i_shot)
                           for i_f, nshots in enumerate(Nshots_per_file)
                           for i_shot in range(nshots)]

            from numpy import array_split
            print("I am root. Number of uniques = %d" % len(set(shot_tuples)))
            shots_for_rank = array_split(shot_tuples, size)
        else:
            shots_for_rank = None

        if has_mpi:
            shots_for_rank = comm.bcast(shots_for_rank, root=0)

        my_shots = shots_for_rank[rank]
        if self.Nload is not None:
            my_shots = my_shots[:self.Nload]

        # open the unique filenames for this rank
        # TODO: check max allowed pointers to open hdf5 file
        my_unique_fids = {fidx for fidx, _ in my_shots}
        my_open_files = {
            fidx: h5py_File(self.fnames[fidx], "r")
            for fidx in my_unique_fids
        }
        Ntot = 0
        self.all_bbox_pixels = []
        for img_num, (fname_idx, shot_idx) in enumerate(my_shots):
            if img_num == args.Nmax:
                # Already processed the maximum number of images: stop here.
                # (The previous `continue` skipped only this one image and
                # went on processing all the following ones.)
                break
            h = my_open_files[fname_idx]

            # load the dxtbx image data directly:
            stored_path = h["h5_path"][shot_idx]
            npz_path = stored_path
            # NOTE take me out!
            if args.testmode:
                npz_path = os.path.basename(npz_path)
            img_handle = numpy_load(npz_path)
            img = img_handle["img"]

            if len(img.shape) == 2:  # if single panel, add a leading panel axis
                img = np.array([img])

            B = beam_from_dict(img_handle["beam"][()])

            # get the indexed crystal Amatrix
            Amat = h["Amatrices"][shot_idx]
            amat_elems = list(sqr(Amat).inverse().elems)
            # real space basis vectors:
            a_real = amat_elems[:3]
            b_real = amat_elems[3:6]
            c_real = amat_elems[6:]

            # dxtbx indexed crystal model
            C = Crystal(a_real, b_real, c_real, space_group)

            # change basis here ? Or maybe just average a/b
            a, b, c, _, _, _ = C.get_unit_cell().parameters()
            a_init = .5 * (a + b)
            c_init = c

            # shoe boxes where we expect spots
            bbox_dset = h["bboxes"]["shot%d" % shot_idx]
            n_bboxes_total = bbox_dset.shape[0]
            # is the shoe box within the resolution ring and does it have significant SNR (see filter_bboxes.py)
            is_a_keeper = h["bboxes"]["keepers%d" % shot_idx][()]

            # tilt plane to the background pixels in the shoe boxes
            tilt_abc_dset = h["tilt_abc"]["shot%d" % shot_idx]
            try:
                panel_ids_dset = h["panel_ids"]["shot%d" % shot_idx]
                has_panels = True
            except KeyError:
                has_panels = False

            # apply the filters (indices of the boxes flagged as keepers):
            keep_idx = [i_bb for i_bb in range(n_bboxes_total)
                        if is_a_keeper[i_bb]]
            bboxes = [bbox_dset[i_bb] for i_bb in keep_idx]
            tilt_abc = [tilt_abc_dset[i_bb] for i_bb in keep_idx]
            if has_panels:
                panel_ids = [panel_ids_dset[i_bb] for i_bb in keep_idx]
            else:
                panel_ids = [0] * len(tilt_abc)

            # how many pixels do we have
            tot_pix = [(j2 - j1) * (i2 - i1) for i1, i2, j1, j2 in bboxes]
            Ntot += sum(tot_pix)

            # Here we will try a per-shot refinement of the unit cell and Umatrix, as well as ncells abc
            # and spot scale etc..

            # load some ground truth data from the simulation dumps (e.g. spectrum)
            h5_fname = stored_path.replace(".npz", "")
            # NOTE remove me
            if args.testmode:
                h5_fname = os.path.basename(h5_fname)
            data = h5py_File(h5_fname, "r")
            try:
                # read everything needed up front so the file can be closed
                # right away instead of leaking one open handle per shot
                data_fname = data.filename
                tru = sqr(data["crystalA"][()]).inverse().elems
                fluxes = data["spectrum"][()]
                es = data["exposure_s"][()]
            finally:
                data.close()

            a_tru = tru[:3]
            b_tru = tru[3:6]
            c_tru = tru[6:]
            C_tru = Crystal(a_tru, b_tru, c_tru, space_group)
            try:
                angular_offset_init = compare_with_ground_truth(
                    a_tru, b_tru, c_tru, [C], symbol=space_group)[0]
            except Exception as err:
                print(
                    "Rank %d: Boo cant use the comparison w GT function: %s" %
                    (rank, err))
                # never report a stale value from a previous shot (or hit a
                # NameError on the first one) in the summary further down
                angular_offset_init = float("nan")

            fluxes *= es  # multiply by the exposure time
            # dont simulate when there are no photons!
            spectrum = [(wave, flux) for wave, flux in zip(wavelens, fluxes)
                        if flux > self.flux_min]

            # make a unit cell manager that the refiner will use to track the B-matrix
            aa, _, cc, _, _, _ = C_tru.get_unit_cell().parameters()
            if args.startwithtruth:
                ucell_man = TetragonalManager(a=aa, c=cc)
            else:
                ucell_man = TetragonalManager(a=a_init, c=c_init)

            # create the sim_data instance that the refiner will use to run diffBragg
            # create a nanoBragg crystal
            nbcryst = nanoBragg_crystal()
            nbcryst.dxtbx_crystal = C_tru if args.startwithtruth else C
            nbcryst.thick_mm = 0.1
            nbcryst.Ncells_abc = 30, 30, 30
            nbcryst.miller_array = Fhkl_guess.as_amplitude_array()
            nbcryst.n_mos_domains = 1
            nbcryst.mos_spread_deg = 0.0

            # create a nanoBragg beam
            nbbeam = nanoBragg_beam()
            nbbeam.size_mm = 0.001
            nbbeam.unit_s0 = B.get_unit_s0()
            nbbeam.spectrum = spectrum

            # sim data instance
            SIM = SimData()
            SIM.detector = CSPAD
            SIM.crystal = nbcryst
            SIM.beam = nbbeam
            SIM.panel_id = 0  # default

            spot_scale = 12
            if args.sad:
                spot_scale = 1
            SIM.instantiate_diffBragg(default_F=0, oversample=0)
            SIM.D.spot_scale = spot_scale

            img_in_photons = img / self.gain

            print("Rank %d, Starting refinement!" % rank)
            try:
                RUC = RefineAllMultiPanel(
                    spot_rois=bboxes,
                    abc_init=tilt_abc,
                    img=img_in_photons,  # NOTE this is now a multi panel image
                    SimData_instance=SIM,
                    plot_images=args.plot,
                    plot_residuals=args.residual,
                    ucell_manager=ucell_man)

                RUC.panel_ids = panel_ids
                RUC.multi_panel = True
                RUC.split_evaluation = args.split
                RUC.trad_conv = True
                RUC.refine_detdist = False
                RUC.refine_background_planes = False
                RUC.refine_Umatrix = True
                RUC.refine_Bmatrix = True
                RUC.refine_ncells = True
                RUC.use_curvatures = False  # args.curvatures
                RUC.calc_curvatures = True  # args.curvatures
                RUC.refine_crystal_scale = True
                RUC.refine_gain_fac = False
                RUC.plot_stride = args.stride
                RUC.poisson_only = False
                RUC.trad_conv_eps = 5e-3  # NOTE this is for single panel model
                RUC.max_calls = 300
                RUC.verbose = False
                RUC.use_rot_priors = True
                RUC.use_ucell_priors = True
                if args.verbose:
                    if rank == 0:  # only show refinement stats for rank 0
                        RUC.verbose = True
                RUC.run()
                # second pass with curvatures when the first run asked for it
                if RUC.hit_break_to_use_curvatures:
                    RUC.use_curvatures = True
                    RUC.run(setup=False)
            except AssertionError as err:
                print(
                    "Rank %d, filename %s Hit assertion error during refinement: %s"
                    % (rank, data_fname, err))
                continue

            angle, ax = RUC.get_correction_misset(as_axis_angle_deg=True)
            if args.startwithtruth:
                C = Crystal(a_tru, b_tru, c_tru, space_group)
            C.rotate_around_origin(ax, angle)
            C.set_B(RUC.get_refined_Bmatrix())
            a_ref, _, c_ref, _, _, _ = C.get_unit_cell().parameters()
            # compute missorientation with ground truth model
            try:
                angular_offset = compare_with_ground_truth(
                    a_tru, b_tru, c_tru, [C], symbol=space_group)[0]
                print(
                    "Rank %d, filename=%s, ang=%f, init_ang=%f, a=%f, init_a=%f, c=%f, init_c=%f"
                    % (rank, data_fname, angular_offset,
                       angular_offset_init, a_ref, a_init, c_ref, c_init))
            except Exception as err:
                print("Rank %d, filename=%s, error %s" %
                      (rank, data_fname, err))

            # free the memory from diffBragg instance
            RUC.S.D.free_all()
            del img  # not sure if needed here..
            del img_in_photons

            if args.testmode:
                exit()

            # peak at the memory usage of this rank
            mem = getrusage(RUSAGE_SELF).ru_maxrss  # peak mem usage in KB
            mem = mem / 1e6  # convert to GB

            if rank == 0:
                # print as a single parenthesised string so the line is valid
                # both as a Python 2 statement and as a Python 3 function call
                print("RANK 0: %.2g total pixels in %d/%d bboxes (file %d / %d); MemUsg=%2.2g GB"
                      % (Ntot, len(bboxes), n_bboxes_total, img_num + 1, len(my_shots), mem))

            # TODO: accumulate all pixels

        for h in my_open_files.values():
            h.close()

        print("Rank %d; all subimages loaded!" % rank)