def prepare_dxtbx_models(self, setting_specific_ai, sg, isoform=None):
        """Build a one-experiment dxtbx ``ExperimentList`` from an indexing setting.

        Args:
            setting_specific_ai: object providing ``distance()``, ``xbeam()``,
                ``ybeam()``, ``getOrientation()`` (whose ``direct_matrix()``
                supplies the real-space basis) and ``getMosaicity()``.
            sg: space-group symbol passed to ``Crystal`` as
                ``space_group_symbol``.
            isoform: optional object with ``fractionalization_matrix()``; when
                given, the transpose of that matrix replaces the crystal's B.

        Returns:
            An ``ExperimentList`` containing a single ``Experiment`` built from
            the beam, detector and crystal models created here.

        Side effects:
            Prints the beam, detector and crystal models to stdout.
        """
        from dxtbx.model import (BeamFactory, Crystal, DetectorFactory,
                                 Experiment, ExperimentList)

        beam = BeamFactory.simple(wavelength=self.inputai.wavelength)

        # NOTE(review): both image dimensions are taken from inputpd['size1'],
        # i.e. a square detector is assumed -- confirm whether a second size
        # key should be used for the slow dimension.
        detector = DetectorFactory.simple(
            sensor=DetectorFactory.sensor("PAD"),
            distance=setting_specific_ai.distance(),
            beam_centre=[
                setting_specific_ai.xbeam(),
                setting_specific_ai.ybeam()
            ],
            fast_direction="+x",
            slow_direction="+y",
            pixel_size=[self.pixel_size, self.pixel_size],
            image_size=[self.inputpd['size1'], self.inputpd['size1']],
        )

        # Rows of the direct matrix are the real-space vectors a, b, c.
        direct = matrix.sqr(
            setting_specific_ai.getOrientation().direct_matrix())
        crystal = Crystal(
            real_space_a=matrix.row(direct[0:3]),
            real_space_b=matrix.row(direct[3:6]),
            real_space_c=matrix.row(direct[6:9]),
            space_group_symbol=sg,
        )
        crystal.set_mosaicity(setting_specific_ai.getMosaicity())
        if isoform is not None:
            # Override the metric with the isoform's reciprocal-space B matrix.
            newB = matrix.sqr(isoform.fractionalization_matrix()).transpose()
            crystal.set_B(newB)

        experiments = ExperimentList()
        experiments.append(
            Experiment(beam=beam, detector=detector, crystal=crystal))

        # Single-argument print() behaves identically on Python 2 and 3.
        print(beam)
        print(detector)
        print(crystal)
        return experiments
def test_crystal_model():
    """Exercise the dxtbx ``Crystal`` model API end to end.

    Covers construction (from real-space vectors, from an A matrix plus a
    space group, and from a direct-space matrix with ``reciprocal=False``),
    the U/B/A accessors and setters, string formatting, change of basis
    (a round trip through a primitive setting and a change to the minimum
    cell), scan-varying A matrices, and ``set_unit_cell``.

    Relies on module-level names not visible here (``matrix``, ``Crystal``,
    ``approx_equal``, ``math``, ``sgtbx``, ``uctbx``, ``crystal``,
    ``random_rotation``, ``euler_angles``, ``pytest``) -- presumably imported
    from scitbx/cctbx/dxtbx at the top of the file; confirm there.
    """
    # Expected real-space vectors, kept as matrix.col for comparisons below.
    real_space_a = matrix.col((10, 0, 0))
    real_space_b = matrix.col((0, 11, 0))
    real_space_c = matrix.col((0, 0, 12))
    # Orthogonal 10 x 11 x 12 cell in P 1; the assertions below expect U = I.
    model = Crystal(
        real_space_a=(10, 0, 0),
        real_space_b=(0, 11, 0),
        real_space_c=(0, 0, 12),
        space_group_symbol="P 1",
    )
    # This doesn't work as python class uctbx.unit_cell(uctbx_ext.unit_cell)
    # so C++ and python classes are different types
    # assert isinstance(model.get_unit_cell(), uctbx.unit_cell)
    assert model.get_unit_cell().parameters() == (10.0, 11.0, 12.0, 90.0, 90.0,
                                                  90.0)
    assert approx_equal(model.get_A(),
                        (1 / 10, 0, 0, 0, 1 / 11, 0, 0, 0, 1 / 12))
    assert approx_equal(
        matrix.sqr(model.get_A()).inverse(), (10, 0, 0, 0, 11, 0, 0, 0, 12))
    assert approx_equal(model.get_B(), model.get_A())
    assert approx_equal(model.get_U(), (1, 0, 0, 0, 1, 0, 0, 0, 1))
    assert approx_equal(model.get_real_space_vectors(),
                        (real_space_a, real_space_b, real_space_c))
    assert (model.get_crystal_symmetry().unit_cell().parameters() ==
            model.get_unit_cell().parameters())
    assert model.get_crystal_symmetry().space_group() == model.get_space_group(
    )

    # Equality: the same crystal built from vectors, from (A, space group),
    # and from the direct matrix with reciprocal=False all compare equal.
    model2 = Crystal(
        real_space_a=(10, 0, 0),
        real_space_b=(0, 11, 0),
        real_space_c=(0, 0, 12),
        space_group_symbol="P 1",
    )
    assert model == model2

    model2a = Crystal(model.get_A(), model.get_space_group())
    assert model == model2a

    model2b = Crystal(
        matrix.sqr(model.get_A()).inverse().elems,
        model.get_space_group().type().lookup_symbol(),
        reciprocal=False,
    )
    assert model == model2b

    # rotate 45 degrees about x-axis
    R1 = matrix.sqr((
        1,
        0,
        0,
        0,
        math.cos(math.pi / 4),
        -math.sin(math.pi / 4),
        0,
        math.sin(math.pi / 4),
        math.cos(math.pi / 4),
    ))
    # rotate 30 degrees about y-axis
    R2 = matrix.sqr((
        math.cos(math.pi / 6),
        0,
        math.sin(math.pi / 6),
        0,
        1,
        0,
        -math.sin(math.pi / 6),
        0,
        math.cos(math.pi / 6),
    ))
    # rotate 60 degrees about z-axis
    R3 = matrix.sqr((
        math.cos(math.pi / 3),
        -math.sin(math.pi / 3),
        0,
        math.sin(math.pi / 3),
        math.cos(math.pi / 3),
        0,
        0,
        0,
        1,
    ))
    # Setting U to the combined rotation must rotate the real-space vectors
    # while keeping A = U * B.
    R = R1 * R2 * R3
    model.set_U(R)
    # B is unchanged
    assert approx_equal(model.get_B(),
                        (1 / 10, 0, 0, 0, 1 / 11, 0, 0, 0, 1 / 12))
    assert approx_equal(model.get_U(), R)
    assert approx_equal(model.get_A(),
                        matrix.sqr(model.get_U()) * matrix.sqr(model.get_B()))
    a_, b_, c_ = model.get_real_space_vectors()
    assert approx_equal(a_, R * real_space_a)
    assert approx_equal(b_, R * real_space_b)
    assert approx_equal(c_, R * real_space_c)
    # Formatted output; "-0.0000" is normalised to " 0.0000" so the check
    # does not depend on the sign of rounded zeros.
    assert (str(model).replace("-0.0000", " 0.0000") == """\
Crystal:
    Unit cell: (10.000, 11.000, 12.000, 90.000, 90.000, 90.000)
    Space group: P 1
    U matrix:  {{ 0.4330, -0.7500,  0.5000},
                { 0.7891,  0.0474, -0.6124},
                { 0.4356,  0.6597,  0.6124}}
    B matrix:  {{ 0.1000,  0.0000,  0.0000},
                { 0.0000,  0.0909,  0.0000},
                { 0.0000,  0.0000,  0.0833}}
    A = UB:    {{ 0.0433, -0.0682,  0.0417},
                { 0.0789,  0.0043, -0.0510},
                { 0.0436,  0.0600,  0.0510}}
""")
    # Replacing B changes the reported unit cell.
    model.set_B((1 / 12, 0, 0, 0, 1 / 12, 0, 0, 0, 1 / 12))
    assert approx_equal(model.get_unit_cell().parameters(),
                        (12, 12, 12, 90, 90, 90))

    # set_A is decomposed back into (approximately) the same U and B.
    U = matrix.sqr((0.3455, -0.2589, -0.9020, 0.8914, 0.3909, 0.2293, 0.2933,
                    -0.8833, 0.3658))
    B = matrix.sqr((1 / 13, 0, 0, 0, 1 / 13, 0, 0, 0, 1 / 13))
    model.set_A(U * B)
    assert approx_equal(model.get_A(), U * B)
    assert approx_equal(model.get_U(), U, 1e-4)
    assert approx_equal(model.get_B(), B, 1e-5)

    # The space_group keyword takes an sgtbx space group object; a different
    # group makes otherwise identical models compare unequal.
    model3 = Crystal(
        real_space_a=(10, 0, 0),
        real_space_b=(0, 11, 0),
        real_space_c=(0, 0, 12),
        space_group=sgtbx.space_group_info("P 222").group(),
    )
    assert model3.get_space_group().type().hall_symbol() == " P 2 2"
    assert model != model3
    #
    # Cubic 44 A cell in space group number 230: change to the primitive
    # setting, check the transformed vectors/matrices, then change back.
    sgi_ref = sgtbx.space_group_info(number=230)
    model_ref = Crystal(
        real_space_a=(44, 0, 0),
        real_space_b=(0, 44, 0),
        real_space_c=(0, 0, 44),
        space_group=sgi_ref.group(),
    )
    assert approx_equal(model_ref.get_U(), (1, 0, 0, 0, 1, 0, 0, 0, 1))
    assert approx_equal(model_ref.get_B(),
                        (1 / 44, 0, 0, 0, 1 / 44, 0, 0, 0, 1 / 44))
    assert approx_equal(model_ref.get_A(), model_ref.get_B())
    assert approx_equal(model_ref.get_unit_cell().parameters(),
                        (44, 44, 44, 90, 90, 90))
    a_ref, b_ref, c_ref = map(matrix.col, model_ref.get_real_space_vectors())
    cb_op_to_primitive = sgi_ref.change_of_basis_op_to_primitive_setting()
    model_primitive = model_ref.change_basis(cb_op_to_primitive)
    cb_op_to_reference = (model_primitive.get_space_group().info().
                          change_of_basis_op_to_reference_setting())
    a_prim, b_prim, c_prim = map(matrix.col,
                                 model_primitive.get_real_space_vectors())
    assert (cb_op_to_primitive.as_abc() ==
            "-1/2*a+1/2*b+1/2*c,1/2*a-1/2*b+1/2*c,1/2*a+1/2*b-1/2*c")
    assert approx_equal(a_prim, -1 / 2 * a_ref + 1 / 2 * b_ref + 1 / 2 * c_ref)
    assert approx_equal(b_prim, 1 / 2 * a_ref - 1 / 2 * b_ref + 1 / 2 * c_ref)
    assert approx_equal(c_prim, 1 / 2 * a_ref + 1 / 2 * b_ref - 1 / 2 * c_ref)
    assert cb_op_to_reference.as_abc() == "b+c,a+c,a+b"
    assert approx_equal(a_ref, b_prim + c_prim)
    assert approx_equal(b_ref, a_prim + c_prim)
    assert approx_equal(c_ref, a_prim + b_prim)
    assert approx_equal(
        model_primitive.get_U(),
        [
            -0.5773502691896258,
            0.40824829046386285,
            0.7071067811865476,
            0.5773502691896257,
            -0.4082482904638631,
            0.7071067811865476,
            0.5773502691896257,
            0.8164965809277259,
            0.0,
        ],
    )
    assert approx_equal(
        model_primitive.get_B(),
        [
            0.0262431940540739,
            0.0,
            0.0,
            0.00927837023781507,
            0.02783511071344521,
            0.0,
            0.01607060866333063,
            0.01607060866333063,
            0.03214121732666125,
        ],
    )
    assert approx_equal(
        model_primitive.get_A(),
        (0, 1 / 44, 1 / 44, 1 / 44, 0, 1 / 44, 1 / 44, 1 / 44, 0),
    )
    assert approx_equal(
        model_primitive.get_unit_cell().parameters(),
        [
            38.1051177665153,
            38.1051177665153,
            38.1051177665153,
            109.47122063449069,
            109.47122063449069,
            109.47122063449069,
        ],
    )
    assert model_ref != model_primitive
    # Changing back to the reference setting recovers the original model.
    model_ref_recycled = model_primitive.change_basis(cb_op_to_reference)
    assert approx_equal(model_ref.get_U(), model_ref_recycled.get_U())
    assert approx_equal(model_ref.get_B(), model_ref_recycled.get_B())
    assert approx_equal(model_ref.get_A(), model_ref_recycled.get_A())
    assert approx_equal(
        model_ref.get_unit_cell().parameters(),
        model_ref_recycled.get_unit_cell().parameters(),
    )
    assert model_ref == model_ref_recycled

    # Triclinic cell with a random orientation: change to the minimum cell.
    # (``crystal`` here is presumably the cctbx.crystal module -- confirm.)
    uc = uctbx.unit_cell(
        (58.2567, 58.1264, 39.7093, 46.9077, 46.8612, 62.1055))
    sg = sgtbx.space_group_info(symbol="P1").group()
    cs = crystal.symmetry(unit_cell=uc, space_group=sg)
    cb_op_to_minimum = cs.change_of_basis_op_to_minimum_cell()
    # the reciprocal matrix
    B = matrix.sqr(uc.fractionalization_matrix()).transpose()
    U = random_rotation()
    direct_matrix = (U * B).inverse()
    model = Crystal(direct_matrix[:3],
                    direct_matrix[3:6],
                    direct_matrix[6:9],
                    space_group=sg)
    assert uc.is_similar_to(model.get_unit_cell())
    uc_minimum = uc.change_basis(cb_op_to_minimum)
    model_minimum = model.change_basis(cb_op_to_minimum)
    assert uc_minimum.is_similar_to(model_minimum.get_unit_cell())
    assert model_minimum != model
    # After update(), model_minimum compares equal to model.
    model_minimum.update(model)
    assert model_minimum == model  # lgtm

    # Scan-varying A: change_basis must transform every scan point with the
    # same basis-change matrix (A_min == A_orig * M^-1).
    A_static = matrix.sqr(model.get_A())
    A_as_scan_points = [A_static]
    num_scan_points = 11
    for i in range(num_scan_points - 1):
        A_as_scan_points.append(
            A_as_scan_points[-1] *
            matrix.sqr(euler_angles.xyz_matrix(0.1, 0.2, 0.3)))
    model.set_A_at_scan_points(A_as_scan_points)
    model_minimum = model.change_basis(cb_op_to_minimum)
    assert model.num_scan_points == model_minimum.num_scan_points == num_scan_points
    M = matrix.sqr(cb_op_to_minimum.c_inv().r().transpose().as_double())
    M_inv = M.inverse()
    for i in range(num_scan_points):
        A_orig = matrix.sqr(model.get_A_at_scan_point(i))
        A_min = matrix.sqr(model_minimum.get_A_at_scan_point(i))
        assert approx_equal(A_min, A_orig * M_inv)
    # The static unit cell is still the one the model was built with.
    assert model.get_unit_cell().parameters() == pytest.approx(
        (58.2567, 58.1264, 39.7093, 46.9077, 46.8612, 62.1055))
    # set_unit_cell replaces the cell parameters.
    uc = uctbx.unit_cell((10, 11, 12, 91, 92, 93))
    model.set_unit_cell(uc)
    assert model.get_unit_cell().parameters() == pytest.approx(uc.parameters())
# Example #3
# 0
    def load(self):
        """Distribute shots over MPI ranks and refine each shot with diffBragg.

        Rank 0 counts the shots in every file of ``self.fnames`` (the length
        of each file's ``h5_path`` dataset), builds ``(file_index,
        shot_index)`` pairs and splits them into ``size`` chunks with
        ``numpy.array_split``. When ``has_mpi`` is set, the chunks are
        broadcast to all ranks. Each rank then handles its own chunk,
        truncated to ``self.Nload`` when that is not None, and stops after
        ``args.Nmax`` shots. For each shot it:

        * loads the image from the ``.npz`` path stored in ``h5_path``,
        * builds a dxtbx ``Crystal`` (symbol "P43212") from ``Amatrices``,
        * keeps only the bounding boxes flagged in ``keepers<shot>``,
        * reads the ground-truth crystal and spectrum from the companion h5
          file (``h5_path`` with ".npz" removed),
        * runs ``RefineAllMultiPanel`` and prints the refined cell and the
          misorientation relative to the ground truth.

        Depends on module-level names defined outside this method (``rank``,
        ``size``, ``comm``, ``has_mpi``, ``args``, ``h5py_File``,
        ``numpy_load``, ``wavelens``, ``Fhkl_guess``, ``CSPAD`` and the
        diffBragg/dxtbx helpers).
        """
        import os

        # NOTE: for reference, inside each h5 file there is
        #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']

        # get the total number of shots using worker 0
        if rank == 0:
            print("I am root. I am calculating total number of shots")
            h5s = [h5py_File(f, "r") for f in self.fnames]
            try:
                Nshots_per_file = [h["h5_path"].shape[0] for h in h5s]
            finally:
                # close the open h5s..
                for h in h5s:
                    h.close()
            Nshots_tot = sum(Nshots_per_file)
            print("I am root. Total number of shots is %d" % Nshots_tot)

            print("I am root. I will divide shots amongst workers.")
            shot_tuples = [(i_f, i_shot)
                           for i_f, n_shots in enumerate(Nshots_per_file)
                           for i_shot in range(n_shots)]

            from numpy import array_split
            print("I am root. Number of uniques = %d" % len(set(shot_tuples)))
            shots_for_rank = array_split(shot_tuples, size)
        else:
            shots_for_rank = None

        if has_mpi:
            shots_for_rank = comm.bcast(shots_for_rank, root=0)

        my_shots = shots_for_rank[rank]
        if self.Nload is not None:
            my_shots = my_shots[:self.Nload]

        # open the unique filenames for this rank
        # TODO: check max allowed pointers to open hdf5 file
        my_unique_fids = {fidx for fidx, _ in my_shots}
        my_open_files = {
            fidx: h5py_File(self.fnames[fidx], "r")
            for fidx in my_unique_fids
        }
        Ntot = 0
        self.all_bbox_pixels = []
        for img_num, (fname_idx, shot_idx) in enumerate(my_shots):
            if img_num == args.Nmax:
                # Maximum number of images processed: stop here. (A
                # ``continue`` would skip only this one shot and keep going.)
                break
            h = my_open_files[fname_idx]

            # load the dxtbx image data directly:
            npz_path = h["h5_path"][shot_idx]
            # NOTE take me out!
            if args.testmode:
                npz_path = os.path.basename(npz_path)
            img_handle = numpy_load(npz_path)
            img = img_handle["img"]

            if len(img.shape) == 2:  # if single panel>>
                img = np.array([img])

            #D = det_from_dict(img_handle["det"][()])
            B = beam_from_dict(img_handle["beam"][()])

            # get the indexed crystal Amatrix
            Amat = h["Amatrices"][shot_idx]
            amat_elems = list(sqr(Amat).inverse().elems)
            # real space basis vectors:
            a_real = amat_elems[:3]
            b_real = amat_elems[3:6]
            c_real = amat_elems[6:]

            # dxtbx indexed crystal model
            C = Crystal(a_real, b_real, c_real, "P43212")

            # change basis here ? Or maybe just average a/b
            a, b, c, _, _, _ = C.get_unit_cell().parameters()
            a_init = .5 * (a + b)
            c_init = c

            # shoe boxes where we expect spots
            bbox_dset = h["bboxes"]["shot%d" % shot_idx]
            n_bboxes_total = bbox_dset.shape[0]
            # is the shoe box within the resolution ring and does it have significant SNR (see filter_bboxes.py)
            is_a_keeper = h["bboxes"]["keepers%d" % shot_idx][()]

            # tilt plane to the background pixels in the shoe boxes
            tilt_abc_dset = h["tilt_abc"]["shot%d" % shot_idx]
            try:
                panel_ids_dset = h["panel_ids"]["shot%d" % shot_idx]
                has_panels = True
            except KeyError:
                has_panels = False

            # apply the filters (keeper indices computed once, reused for
            # every per-box dataset so they stay aligned):
            keep_idx = [
                i_bb for i_bb in range(n_bboxes_total) if is_a_keeper[i_bb]
            ]
            bboxes = [bbox_dset[i_bb] for i_bb in keep_idx]
            tilt_abc = [tilt_abc_dset[i_bb] for i_bb in keep_idx]
            if has_panels:
                panel_ids = [panel_ids_dset[i_bb] for i_bb in keep_idx]
            else:
                panel_ids = [0] * len(tilt_abc)

            # how many pixels do we have
            tot_pix = [(j2 - j1) * (i2 - i1) for i1, i2, j1, j2 in bboxes]
            Ntot += sum(tot_pix)

            # actually load the pixels...
            #data_boxes = [ img[j1:j2, i1:i2] for i1,i2,j1,j2 in bboxes]

            # Here we will try a per-shot refinement of the unit cell and Umatrix, as well as ncells abc
            # and spot scale etc..

            # load some ground truth data from the simulation dumps (e.g. spectrum)
            h5_fname = h["h5_path"][shot_idx].replace(".npz", "")
            # NOTE remove me
            if args.testmode:
                h5_fname = os.path.basename(h5_fname)
            # Read everything needed from the ground-truth file up front and
            # close it, so no handle is leaked per shot (including on the
            # ``continue`` path after a failed refinement).
            data = h5py_File(h5_fname, "r")
            try:
                data_fname = data.filename
                tru = sqr(data["crystalA"][()]).inverse().elems
                fluxes = data["spectrum"][()]
                es = data["exposure_s"][()]
            finally:
                data.close()

            a_tru = tru[:3]
            b_tru = tru[3:6]
            c_tru = tru[6:]
            C_tru = Crystal(a_tru, b_tru, c_tru, "P43212")
            # NaN (not the previous shot's value) if the comparison fails.
            angular_offset_init = float("nan")
            try:
                angular_offset_init = compare_with_ground_truth(
                    a_tru, b_tru, c_tru, [C], symbol="P43212")[0]
            except Exception as err:
                print(
                    "Rank %d: Boo cant use the comparison w GT function: %s" %
                    (rank, err))

            fluxes *= es  # multiply by the exposure time
            spectrum = zip(wavelens, fluxes)
            # dont simulate when there are no photons!
            spectrum = [(wave, flux) for wave, flux in spectrum
                        if flux > self.flux_min]

            # make a unit cell manager that the refiner will use to track the B-matrix
            aa, _, cc, _, _, _ = C_tru.get_unit_cell().parameters()
            ucell_man = TetragonalManager(a=a_init, c=c_init)
            if args.startwithtruth:
                ucell_man = TetragonalManager(a=aa, c=cc)

            # create the sim_data instance that the refiner will use to run diffBragg
            # create a nanoBragg crystal
            nbcryst = nanoBragg_crystal()
            nbcryst.dxtbx_crystal = C
            if args.startwithtruth:
                nbcryst.dxtbx_crystal = C_tru

            nbcryst.thick_mm = 0.1
            nbcryst.Ncells_abc = 30, 30, 30
            nbcryst.miller_array = Fhkl_guess.as_amplitude_array()
            nbcryst.n_mos_domains = 1
            nbcryst.mos_spread_deg = 0.0

            # create a nanoBragg beam
            nbbeam = nanoBragg_beam()
            nbbeam.size_mm = 0.001
            nbbeam.unit_s0 = B.get_unit_s0()
            nbbeam.spectrum = spectrum

            # sim data instance
            SIM = SimData()
            SIM.detector = CSPAD
            #SIM.detector = D
            SIM.crystal = nbcryst
            SIM.beam = nbbeam
            SIM.panel_id = 0  # default

            spot_scale = 12
            if args.sad:
                spot_scale = 1
            SIM.instantiate_diffBragg(default_F=0, oversample=0)
            SIM.D.spot_scale = spot_scale

            img_in_photons = img / self.gain

            print("Rank %d, Starting refinement!" % rank)
            try:
                RUC = RefineAllMultiPanel(
                    spot_rois=bboxes,
                    abc_init=tilt_abc,
                    img=img_in_photons,  # NOTE this is now a multi panel image
                    SimData_instance=SIM,
                    plot_images=args.plot,
                    plot_residuals=args.residual,
                    ucell_manager=ucell_man)

                RUC.panel_ids = panel_ids
                RUC.multi_panel = True
                RUC.split_evaluation = args.split
                RUC.trad_conv = True
                RUC.refine_detdist = False
                RUC.refine_background_planes = False
                RUC.refine_Umatrix = True
                RUC.refine_Bmatrix = True
                RUC.refine_ncells = True
                RUC.use_curvatures = False  # args.curvatures
                RUC.calc_curvatures = True  #args.curvatures
                RUC.refine_crystal_scale = True
                RUC.refine_gain_fac = False
                RUC.plot_stride = args.stride
                RUC.poisson_only = False
                RUC.trad_conv_eps = 5e-3  # NOTE this is for single panel model
                RUC.max_calls = 300
                RUC.verbose = False
                RUC.use_rot_priors = True
                RUC.use_ucell_priors = True
                if args.verbose:
                    if rank == 0:  # only show refinement stats for rank 0
                        RUC.verbose = True
                RUC.run()
                if RUC.hit_break_to_use_curvatures:
                    RUC.use_curvatures = True
                    RUC.run(setup=False)
            except AssertionError as err:
                print(
                    "Rank %d, filename %s Hit assertion error during refinement: %s"
                    % (rank, data_fname, err))
                # NOTE(review): the diffBragg instance created by
                # SIM.instantiate_diffBragg is not freed on this path (the
                # free_all() call below is skipped) -- confirm whether it
                # should be released here.
                continue

            angle, ax = RUC.get_correction_misset(as_axis_angle_deg=True)
            if args.startwithtruth:
                C = Crystal(a_tru, b_tru, c_tru, "P43212")
            C.rotate_around_origin(ax, angle)
            C.set_B(RUC.get_refined_Bmatrix())
            a_ref, _, c_ref, _, _, _ = C.get_unit_cell().parameters()
            # compute missorientation with ground truth model
            try:
                angular_offset = compare_with_ground_truth(a_tru,
                                                           b_tru,
                                                           c_tru, [C],
                                                           symbol="P43212")[0]
                print(
                    "Rank %d, filename=%s, ang=%f, init_ang=%f, a=%f, init_a=%f, c=%f, init_c=%f"
                    % (rank, data_fname, angular_offset,
                       angular_offset_init, a_ref, a_init, c_ref, c_init))
            except Exception as err:
                print("Rank %d, filename=%s, error %s" %
                      (rank, data_fname, err))

            # free the memory from diffBragg instance
            RUC.S.D.free_all()
            del img  # not sure if needed here..
            del img_in_photons

            if args.testmode:
                exit()

            # peak at the memory usage of this rank
            # (ru_maxrss is in kilobytes on Linux; other platforms may differ)
            mem = getrusage(RUSAGE_SELF).ru_maxrss  # peak mem usage in KB
            mem = mem / 1e6  # convert to GB

            if rank == 0:
                print("RANK 0: %.2g total pixels in %d/%d bboxes (file %d / %d); MemUsg=%2.2g GB"
                      % (Ntot, len(bboxes), n_bboxes_total, img_num + 1, len(my_shots), mem))

            # TODO: accumulate all pixels
            #self.all_bbox_pixels += data_boxes

        for h in my_open_files.values():
            h.close()

        print("Rank %d; all subimages loaded!" % rank)