예제 #1
0
 def __init__(self, cell, spacegroup, dmin, anomalous):
     """
     Enumerate the reciprocal asymmetric unit and build a lookup table for it.

     Parameters
     ----------
     cell : gemmi.UnitCell
         A gemmi.UnitCell instance
     spacegroup : gemmi.SpaceGroup
         A gemmi.SpaceGroup instance
     dmin : float
         The maximum resolution in Angstroms
     anomalous : bool
         Whether the ASU includes Friedel minus acentric reflections

     Notes
     -----
     Sets ``self.Hall`` to the Miller indices returned by
     ``rs.utils.generate_reciprocal_asu`` and ``self.lookup_table`` to an
     ``rs.DataSet`` holding H, K, L plus an integer ``id`` column that
     numbers the rows of ``Hall``. The table is then passed through
     ``compute_multiplicity()``, ``label_centrics()`` and ``compute_dHKL()``.
     """
     self.cell, self.spacegroup = cell, spacegroup
     self.dmin, self.anomalous = dmin, anomalous

     self.Hall = rs.utils.generate_reciprocal_asu(
         self.cell, self.spacegroup, self.dmin, self.anomalous
     )

     h, k, l = self.Hall.T
     columns = {'H': h, 'K': k, 'L': l, 'id': np.arange(len(h))}
     table = rs.DataSet(columns, cell=cell, spacegroup=spacegroup)
     table = table.compute_multiplicity()
     table = table.label_centrics()
     self.lookup_table = table.compute_dHKL()
예제 #2
0
def test_constructor_dataset(data_fmodel, spacegroup, cell, merged):
    """Test DataSet.__init__() when called with a DataSet

    Explicit ``spacegroup``/``cell``/``merged`` keyword values must override
    the ones carried by the source DataSet; when they are falsy/None the
    source values are kept (``merged`` then defaults to True).
    """
    result = rs.DataSet(
        data_fmodel, spacegroup=spacegroup, cell=cell, merged=merged
    )
    assert_frame_equal(result, data_fmodel)

    expected_merged = True if merged is None else merged
    assert result.merged == expected_merged

    # Ensure provided values take precedence
    if not spacegroup:
        assert result.spacegroup == data_fmodel.spacegroup
    else:
        assert result.spacegroup.xhm() == "P 1"
        assert isinstance(result.spacegroup, gemmi.SpaceGroup)

    if not cell:
        assert result.cell == data_fmodel.cell
    elif isinstance(cell, gemmi.UnitCell):
        assert result.cell == cell
        assert isinstance(result.cell, gemmi.UnitCell)
    else:
        # cell was given as a sequence of parameters
        assert result.cell.a == cell[0]
        assert isinstance(result.cell, gemmi.UnitCell)
예제 #3
0
def test_remove_absences(inplace, hkl_index):
    """Test DataSet.remove_absences()

    Builds every reflection of a P 1 reciprocal ASU, relabels the data as
    space group 19 and checks that exactly the systematic absences are
    removed, both in place and on a copy, with HKL as columns or as index.
    """
    params = (34., 45., 98., 90., 90., 90.)
    cell = gemmi.UnitCell(*params)
    sg_1 = gemmi.SpaceGroup(1)
    sg_19 = gemmi.SpaceGroup(19)
    Hall = rs.utils.generate_reciprocal_asu(cell, sg_1, 5., anomalous=False)
    h, k, l = Hall.T
    absent = rs.utils.is_absent(Hall, sg_19)
    ds = rs.DataSet({
        'H': h,
        'K': k,
        'L': l,
        'I': np.ones(len(h)),
    },
                    spacegroup=sg_19,
                    cell=cell).infer_mtz_dtypes()
    if hkl_index:
        ds.set_index(['H', 'K', 'L'], inplace=True)

    # Build the reference BEFORE calling remove_absences(): with inplace=True
    # that call mutates ds, so a reference derived from ds afterwards would
    # just compare the result against itself.
    ds_true = ds[~ds.label_absences().ABSENT]
    ds_test = ds.remove_absences(inplace=inplace)

    assert len(ds_test) == len(Hall) - absent.sum()
    assert np.array_equal(ds_test.get_hkls(), ds_true.get_hkls())

    if inplace:
        assert ds_test is ds
    else:
        assert ds_test is not ds
예제 #4
0
def test_from_gemmi(data_gemmi):
    """Test DataSet.from_gemmi()

    The classmethod must produce the same frame and metadata as passing the
    gemmi object straight to the DataSet constructor.
    """
    via_classmethod = rs.DataSet.from_gemmi(data_gemmi)
    via_constructor = rs.DataSet(data_gemmi)
    assert_frame_equal(via_classmethod, via_constructor)
    for attr in ("spacegroup", "cell", "_index_dtypes"):
        assert getattr(via_classmethod, attr) == getattr(via_constructor, attr)
예제 #5
0
def dataset_hkl():
    """
    Build DataSet for testing containing only Miller indices

    Each of H, K and L runs over -5, -3, -1, 1, 3, 5, giving a 6 x 6 x 6
    grid (216 reflections, L varying fastest) stored as the DataSet index.
    """
    axis = np.arange(-5, 6, 2)
    grid = np.stack(np.meshgrid(axis, axis, axis, indexing="ij"))
    hkl = grid.reshape(3, -1).T
    columns = {name: hkl[:, col] for col, name in enumerate("HKL")}
    dataset = rs.DataSet(columns)
    dataset.set_index(["H", "K", "L"], inplace=True)
    return dataset
예제 #6
0
def test_constructor_gemmi(data_gemmi, spacegroup, cell):
    """Test DataSet.__init__() when called with a gemmi object

    ``data_gemmi`` is the gemmi fixture (the same object accepted by
    ``DataSet.from_gemmi``), not a DataSet. Explicitly provided
    ``spacegroup``/``cell`` values must take precedence over the ones stored
    in it; otherwise its values are inherited.
    """
    result = rs.DataSet(data_gemmi, spacegroup=spacegroup, cell=cell)

    # Ensure provided values take precedence
    if spacegroup:
        assert result.spacegroup == spacegroup
    else:
        assert result.spacegroup == data_gemmi.spacegroup

    if cell:
        assert result.cell == cell
    else:
        assert result.cell == data_gemmi.cell
예제 #7
0
    def get_predictions(self, model, inputs=None):
        """
        Pair each observed intensity with the model's predicted intensity.

        Parameters
        ----------
        model : VariationalMergingModel
            A merging model from careless
        inputs : tuple (optional)
            Inputs for which to make the predictions. If None, self.inputs is used.

        Returns
        -------
        predictions : tuple
            A tuple of rs.DataSet objects, one per ReciprocalASU contained in
            self.asu_collection. Each is indexed by (H, K, L), is marked as
            unmerged (merged=False), and has the columns Iobs, SigIobs
            (observations from ``inputs``) and Ipred, SigIpred (from
            ``model.prediction_mean_stddev``).
        """
        if inputs is None:
            inputs = self.inputs

        refl_id = BaseModel.get_refl_id(inputs)
        iobs = BaseModel.get_intensities(inputs).flatten()
        sig_iobs = BaseModel.get_uncertainties(inputs).flatten()
        asu_id, H = self.asu_collection.to_asu_id_and_miller_index(refl_id)
        ipred, sigipred = model.prediction_mean_stddev(inputs)

        h, k, l = H.T
        results = ()
        for i, asu in enumerate(self.asu_collection):
            # Boolean mask selecting the observations that belong to this ASU
            idx = (asu_id == i).flatten()
            output = rs.DataSet(
                {
                    'H': h[idx],
                    'K': k[idx],
                    'L': l[idx],
                    'Iobs': iobs[idx],
                    'SigIobs': sig_iobs[idx],
                    'Ipred': ipred[idx],
                    'SigIpred': sigipred[idx],
                },
                cell=asu.cell,
                spacegroup=asu.spacegroup,
                merged=False,
            ).infer_mtz_dtypes().set_index(['H', 'K', 'L'])
            results += (output, )
        return results
예제 #8
0
    def expand_to_p1(self):
        """
        Expand merged reciprocal-ASU data to all symmetry-equivalent
        reflections and relabel the result with spacegroup P 1.

        Returns
        -------
        DataSet
            New DataSet holding every equivalent reflection that falls in
            the P 1 reciprocal ASU; its ``spacegroup`` is P 1.

        Raises
        ------
        ValueError
            If the DataSet is not merged, or if any reflection lies outside
            the reciprocal ASU of the current spacegroup.
        """
        if not self.merged:
            raise ValueError(
                "This function is only applicable for merged DataSets")
        inside = in_asu(self.get_hkls(), spacegroup=self.spacegroup)
        if not inside.all():
            raise ValueError(
                "This function is only  applicable for reflection data in the reciprocal ASU and anomalous data in a two-column (unstacked) format"
            )

        expanded = rs.DataSet(spacegroup=self.spacegroup, cell=self.cell)

        # Interleave each operator with its Friedel-negated counterpart so the
        # list is in ascending ISYM order: op1, -op1, op2, -op2, ...
        symops = []
        for symop in self.spacegroup.operations():
            symops.extend((symop, symop.negated()))

        # Map the data with each ISYM in turn. drop_duplicates keeps the first
        # occurrence, so a reflection reached by several operators keeps the
        # copy from the lowest ISYM.
        for isym in range(1, len(symops) + 1):
            mapped = self.copy()
            mapped["M/ISYM"] = isym
            mapped["M/ISYM"] = mapped["M/ISYM"].astype("M/ISYM")
            expanded = expanded.append(mapped.hkl_to_observed(m_isym="M/ISYM"))
            expanded.drop_duplicates(subset=["H", "K", "L"], inplace=True)

        # Relabel as P1 and keep only reflections in the P1 reciprocal ASU
        expanded.spacegroup = gemmi.SpaceGroup(1)
        keep = in_asu(expanded.get_hkls(), spacegroup=expanded.spacegroup)
        return expanded.loc[keep]
예제 #9
0
def run_careless(parser):
    """
    Run a careless merging job configured by ``parser`` and write its outputs.

    Files written (each name prefixed with ``parser.output_base``):

    * ``_{i}.mtz`` - ``dm.get_results`` output for each ASU (training data)
    * ``_history.csv`` - training history, indexed by ``step``
    * ``_weights`` - model weights
    * ``_data_manager.pickle`` - the pickled DataManager
    * ``_predictions_{i}.mtz`` - per-observation predictions; the ``test``
      column is 0 for training and 1 for held-out observations
    * ``_xval_{i}.mtz`` - half-dataset results, only if
      ``parser.merge_half_datasets`` is set

    Raises
    ------
    ValueError
        If ``parser.type`` is neither ``'poly'`` nor ``'mono'``.
    """
    # We defer all inputs to make sure the parser has priority in modifying tf parameters
    import tensorflow as tf
    import numpy as np
    import reciprocalspaceship as rs
    from careless.io.manager import DataManager
    from careless.io.formatter import MonoFormatter,LaueFormatter
    from careless.models.base import BaseModel
    from careless.models.merging.surrogate_posteriors import TruncatedNormal
    from careless.models.merging.variational import VariationalMergingModel
    from careless.models.scaling.image import HybridImageScaler,ImageScaler
    from careless.models.scaling.nn import MLPScaler

    if parser.type == 'poly':
        df = LaueFormatter.from_parser(parser)
    elif parser.type == 'mono':
        df = MonoFormatter.from_parser(parser)
    else:
        # Without this, `df` would be unbound and fail below with a NameError
        raise ValueError(
            f"Unknown parser.type {parser.type!r}; expected 'poly' or 'mono'"
        )

    inputs,rac = df.format_files(parser.reflection_files)
    dm = DataManager(inputs, rac, parser=parser)

    if parser.test_fraction is not None:
        train,test = dm.split_data_by_refl(parser.test_fraction)
    else:
        train,test = dm.inputs,None

    model = dm.build_model()

    history = model.train_model(
        tuple(map(tf.convert_to_tensor, train)),
        parser.iterations,
        message="Training",
    )

    for i,ds in enumerate(dm.get_results(model.surrogate_posterior, inputs=train)):
        filename = parser.output_base + f'_{i}.mtz'
        ds.write_mtz(filename)

    filename = parser.output_base + '_history.csv'
    rs.DataSet(history).to_csv(filename, index_label='step')

    model.save_weights(parser.output_base + '_weights')
    import pickle
    with open(parser.output_base + "_data_manager.pickle", "wb") as out:
        pickle.dump(dm, out)

    if test is not None:
        for file_id, (ds_train, ds_test) in enumerate(zip(
                dm.get_predictions(model, train),
                dm.get_predictions(model, test),
                )):
            ds_train['test'] = rs.DataSeries(0, index=ds_train.index, dtype='I')
            ds_test['test']  = rs.DataSeries(1, index=ds_test.index, dtype='I')

            filename = parser.output_base + f'_predictions_{file_id}.mtz'
            ds_train.append(ds_test).write_mtz(filename)
    else:
        for file_id, ds_train in enumerate(dm.get_predictions(model, train)):
            ds_train['test'] = rs.DataSeries(0, index=ds_train.index, dtype='I')

            filename = parser.output_base + f'_predictions_{file_id}.mtz'
            ds_train.write_mtz(filename)

    if parser.merge_half_datasets:
        # Reuse the trained scaling model, frozen, for every half-dataset merge
        scaling_model = model.scaling_model
        scaling_model.trainable = False
        xval_data = [None] * len(dm.asu_collection)
        for repeat in range(parser.half_dataset_repeats):
            for half_id, half in enumerate(dm.split_data_by_image()):
                model = dm.build_model(scaling_model=scaling_model)
                history = model.train_model(
                    tuple(map(tf.convert_to_tensor, half)), 
                    parser.iterations,
                    message=f"Merging repeat {repeat+1} half {half_id+1}",
                )

                for file_id,ds in enumerate(dm.get_results(model.surrogate_posterior, inputs=half)):
                    ds['repeat'] = rs.DataSeries(repeat, index=ds.index, dtype='I')
                    ds['half'] = rs.DataSeries(half_id, index=ds.index, dtype='I')
                    if xval_data[file_id] is None:
                        xval_data[file_id] = ds
                    else:
                        xval_data[file_id] = xval_data[file_id].append(ds)

        for file_id, ds in enumerate(xval_data):
            if ds is None:
                # No half-dataset results were produced for this ASU (e.g.
                # half_dataset_repeats == 0); nothing to write.
                continue
            filename = parser.output_base + f'_xval_{file_id}.mtz'
            ds.write_mtz(filename)

    if parser.embed:
        from IPython import embed
        embed(colors='Linux')
예제 #10
0
# NOTE(review): this is a script fragment. cryst1, refl_file1, expt_file2,
# refl_file2, ExperimentListFactory, reflection_table, expt_file, refl_file,
# precog_rmsds, trange, gemmi, np and rs are all defined or imported outside
# this view.
unit_cell1 = cryst1.get_unit_cell()
refls1 = reflection_table.from_file(refl_file1)
elist2 = ExperimentListFactory.from_json_file(expt_file2)
cryst2 = elist2.crystals()[0]
unit_cell2 = cryst2.get_unit_cell()
refls2 = reflection_table.from_file(refl_file2)

# Remove reflections not used in refinement
# NOTE(review): only refls2 is filtered here; confirm refls1 should keep
# reflections that were not used in refinement.
refls2 = refls2.select(refls2.get_flags(refls2.flags.used_in_refinement))

print('generating DIALS dataframes')
# One DataSet per reflection table: observed pixel X/Y, wavelength and
# imageset id (as BATCH), tagged with the crystal's cell and spacegroup.
dials_df1 = rs.DataSet(
    {
        'X': refls1['xyzobs.px.value'].parts()[0].as_numpy_array(),
        'Y': refls1['xyzobs.px.value'].parts()[1].as_numpy_array(),
        'Wavelength': refls1['Wavelength'].as_numpy_array(),
        'BATCH': refls1['imageset_id'].as_numpy_array(),
    },
    cell=gemmi.UnitCell(*unit_cell1.parameters()),
    spacegroup=gemmi.SpaceGroup(cryst1.get_space_group().type(
    ).universal_hermann_mauguin_symbol())).infer_mtz_dtypes()
dials_df2 = rs.DataSet(
    {
        'X': refls2['xyzobs.px.value'].parts()[0].as_numpy_array(),
        'Y': refls2['xyzobs.px.value'].parts()[1].as_numpy_array(),
        'Wavelength': refls2['Wavelength'].as_numpy_array(),
        'BATCH': refls2['imageset_id'].as_numpy_array(),
    },
    cell=gemmi.UnitCell(*unit_cell2.parameters()),
    spacegroup=gemmi.SpaceGroup(cryst2.get_space_group().type(
    ).universal_hermann_mauguin_symbol())).infer_mtz_dtypes()
# NOTE(review): from here on this appears to be a second, separate script
# pasted in. reflection_table is imported again below even though it was
# already used above.
from dials.array_family.flex import reflection_table
elist = ExperimentListFactory.from_json_file(expt_file)
cryst = elist.crystals()[0]
unit_cell = cryst.get_unit_cell()
spacegroup = gemmi.SpaceGroup(
    cryst.get_space_group().type().universal_hermann_mauguin_symbol())
cell = gemmi.UnitCell(*unit_cell.parameters())
refls = reflection_table.from_file(refl_file)
# Keep only reflections flagged as used in refinement
refls = refls.select(refls.get_flags(refls.flags.used_in_refinement))

print('generating DIALS dataframe')
# Observed (X, Y) and calculated (Xcal, Ycal) pixel positions per reflection
dials_df = rs.DataSet(
    {
        'X': refls['xyzobs.px.value'].parts()[0].as_numpy_array(),
        'Y': refls['xyzobs.px.value'].parts()[1].as_numpy_array(),
        'Xcal': refls['xyzcal.px'].parts()[0].as_numpy_array(),
        'Ycal': refls['xyzcal.px'].parts()[1].as_numpy_array(),
        'BATCH': refls['imageset_id'].as_numpy_array(),
    },
    cell=cell,
    spacegroup=spacegroup).infer_mtz_dtypes()

print('initializing metrics')
# One slot per frame, sized from precog_rmsds
dials_rmsds = np.zeros(len(precog_rmsds))
rmsd_diff = np.zeros(len(precog_rmsds))

# Iterate by frame and match HKLs, seeing what percentage are correct
# NOTE(review): the loop body is truncated in this fragment.
for i in trange(len(precog_rmsds)):
    # Reflections whose BATCH (imageset id) equals the frame number i
    im_dia = dials_df[dials_df['BATCH'] == i]

    # Get XY positions for refls
    xy = im_dia[['X', 'Y']].to_numpy(float)
예제 #12
0
def compute_redundancy(hobs,
                       cell,
                       spacegroup,
                       full_asu=True,
                       anomalous=False,
                       dmin=None):
    """
    Compute the multiplicity of all valid reflections in the reciprocal ASU.

    Systematically absent reflections are discarded. The remaining
    reflections are cut at ``dmin`` and mapped into the reciprocal ASU, and
    the number of observations of each unique Miller index is counted.

    Parameters
    ----------
    hobs : np.array(int)
        An n by 3 array of observed miller indices which are not necessarily
        in the reciprocal asymmetric unit
    cell : gemmi.UnitCell
        A gemmi UnitCell object.
    spacegroup : gemmi.SpaceGroup
        A gemmi SpaceGroup object.
    full_asu : bool (optional)
        Include all the reflections in the calculation irrespective of whether they were
        observed. Unobserved reflections are reported with a count of zero.
    anomalous : bool (optional)
        Whether or not the data are anomalous. If True, observations mapped
        with an even ISYM (Friedel minus) are negated so that Friedel mates
        are counted separately.
    dmin : float (optional)
        Resolution cutoff; reflections with dHKL < dmin are dropped.
        If no dmin is supplied, the maximum resolution reflection will be used.

    Returns
    -------
    hunique : np.array (int32)
        An m by 3 array of unique miller indices in the reciprocal ASU,
        sorted by (H, K, L). With ``full_asu=True`` this also contains the
        unobserved reflections of the ASU out to dmin.
    counts : np.array (int32)
        A length m array of counts for each miller index in hunique.

    Raises
    ------
    ValueError
        If ``dmin`` is None and no non-absent reflections remain in ``hobs``,
        so no resolution limit can be inferred.
    """
    hobs = hobs[~rs.utils.is_absent(hobs, spacegroup)]
    dhkl = rs.utils.compute_dHKL(hobs, cell)
    if dmin is None:
        if len(hobs) == 0:
            raise ValueError(
                "Cannot infer dmin: hobs contains no non-absent reflections; "
                "pass dmin explicitly."
            )
        dmin = dhkl.min()
    hobs = hobs[dhkl >= dmin]
    # Floor dmin to a fixed number of decimals before generating the full
    # ASU. This lowers dmin by at most 1e-5, presumably so that floating-point
    # error does not drop the boundary reflection. TODO confirm.
    decimals = 5.  #Round after this many decimals
    dmin = np.floor(dmin * 10**decimals) * 10**-decimals
    hobs, isym = rs.utils.hkl_to_asu(hobs, spacegroup)
    if anomalous:
        # Even ISYM marks Friedel-minus observations; negate them so they
        # count towards their own (negative) Miller index.
        fminus = isym % 2 == 0
        hobs[fminus] = -hobs[fminus]

    mult = rs.DataSet(
        {
            'H': hobs[:, 0],
            'K': hobs[:, 1],
            'L': hobs[:, 2],
            'Count': np.ones(len(hobs)),
        },
        cell=cell,
        spacegroup=spacegroup).groupby(['H', 'K', 'L']).sum()

    if full_asu:
        hall = rs.utils.generate_reciprocal_asu(cell, spacegroup, dmin,
                                                anomalous)

        ASU = rs.DataSet(
            {
                'H': hall[:, 0],
                'K': hall[:, 1],
                'L': hall[:, 2],
                'Count': np.zeros(len(hall)),
            },
            cell=cell,
            spacegroup=spacegroup).set_index(['H', 'K', 'L'])
        # Add only the ASU reflections that were never observed (count 0)
        ASU = ASU.loc[ASU.index.difference(mult.index)]
        mult = mult.append(ASU)

    mult = mult.sort_index()
    return mult.get_hkls(), mult['Count'].to_numpy(np.int32)
예제 #13
0
# NOTE(review): script fragment. elist, precog_df, reflection_table,
# refl_file, gemmi, np and rs are defined or imported outside this view.
cryst = elist.crystals()[0]
unit_cell = cryst.get_unit_cell()
# Tag the existing precog_df with the DIALS crystal's spacegroup and cell
precog_df.spacegroup = gemmi.SpaceGroup(
    cryst.get_space_group().type().universal_hermann_mauguin_symbol())
precog_df.cell = gemmi.UnitCell(*unit_cell.parameters())
refls = reflection_table.from_file(refl_file)

print('generating DIALS dataframe')
# Observed pixel X/Y, Miller index components (via as_vec3_double),
# wavelength and imageset id (as BATCH), sharing precog_df's cell/spacegroup.
dials_df = rs.DataSet(
    {
        'X': refls['xyzobs.px.value'].parts()[0].as_numpy_array(),
        'Y': refls['xyzobs.px.value'].parts()[1].as_numpy_array(),
        'H':
        refls['miller_index'].as_vec3_double().parts()[0].as_numpy_array(),
        'K':
        refls['miller_index'].as_vec3_double().parts()[1].as_numpy_array(),
        'L':
        refls['miller_index'].as_vec3_double().parts()[2].as_numpy_array(),
        'Wavelength': refls['Wavelength'].as_numpy_array(),
        'BATCH': refls['imageset_id'].as_numpy_array(),
    },
    cell=precog_df.cell,
    spacegroup=precog_df.spacegroup).infer_mtz_dtypes()

print('initializing metrics')
# Per-experiment accumulators, one slot per entry of elist
percent_correct = np.zeros(len(elist))
percent_outliers = np.zeros(len(elist))
percent_misindexed = np.zeros(len(elist))
nspots = np.zeros(len(elist))
nmatch = np.zeros(len(elist))
예제 #14
0
def test_constructor_empty():
    """Test DataSet.__init__()

    A DataSet built with no arguments has no rows and no symmetry metadata.
    """
    empty = rs.DataSet()
    assert len(empty) == 0
    for attr in ("spacegroup", "cell"):
        assert getattr(empty, attr) is None
예제 #15
0
# NOTE(review): script fragment. args, expt_filename, refl_filename,
# ExperimentListFactory, ExperimentList, flex, rs, np, tqdm and deepcopy are
# defined or imported outside this view.
new_expt_filename = args.output + '.expt'
new_refl_filename = args.output + '.refl'

# Get experiments
expts = ExperimentListFactory.from_json_file(expt_filename)
new_expts = ExperimentList()

# Initialize flex tables for refl files
refl_input = flex.reflection_table().from_file(refl_filename)
refl_output = refl_input.copy()
refl_output["id"] = flex.int([-1] * len(refl_output))

# Initialize data frame (dtypes are left as inferred; infer_mtz_dtypes() is
# intentionally not applied)
dials_df = rs.DataSet({
    'Wavelength': refl_input['Wavelength'],
    'ID': refl_input['id'],
    'new_ID': [-1] * len(refl_input)
})

# Generate beams per reflection
print(f'Number of rows: {len(dials_df)}')
for _, refl in tqdm(dials_df.iterrows()):
    # iterrows() yields (index, row) where the row is a Series of scalars, so
    # its fields are read directly. Indexing them again with [i] would fail.
    # A row mixing int and float columns is upcast to float, hence int().
    expt = expts[int(refl['ID'])]
    # New beam per reflection
    # NOTE(review): new_expt aliases the experiment stored in `expts` (no
    # copy), so replacing its beam also mutates `expts`. Confirm whether a
    # copy of the experiment was intended.
    new_expt = expt
    new_expt.beam = deepcopy(expt.beam)
    new_expt.beam.set_wavelength(refl['Wavelength'])
    # Unit s0 direction scaled by 1 / wavelength
    s0 = (expt.beam.get_s0() /
          np.linalg.norm(expt.beam.get_s0())) / new_expt.beam.get_wavelength()
    new_expt.beam.set_s0(s0)
예제 #16
0
    def get_results(self,
                    surrogate_posterior,
                    inputs=None,
                    output_parameters=True):
        """
        Extract results from a surrogate_posterior.

        Parameters
        ----------
        surrogate_posterior : tfd.Distribution
            A tensorflow_probability distribution or similar object with `mean` and `stddev` methods
        inputs : tuple (optional)
            Optionally use a different object from self.inputs to compute the redundancy of reflections.
        output_parameters : bool (optional)
            If True, output the parameters of the surrogate distribution in addition to the
            moments.

        Returns
        -------
        results : tuple
            A tuple of rs.DataSet objects, one per ReciprocalASU contained in
            self.asu_collection. Each is indexed by (H, K, L), is marked as
            merged, and has the columns F, SigF and N (the number of
            observations of the reflection in ``inputs``). When
            ``output_parameters`` is True, it also has one column per
            distribution parameter. Reflections with N == 0 are dropped.
            Anomalous ASUs are unstacked, with F(+), SigF(+), F(-), SigF(-),
            N(+) and N(-) placed first.
        """
        if inputs is None:
            inputs = self.inputs

        def to_numpy(x):
            # tf.convert_to_tensor accepts tensors as well as Python/numpy
            # values, giving a uniform path to a numpy array.
            return tf.convert_to_tensor(x).numpy()

        F = surrogate_posterior.mean().numpy()
        SigF = surrogate_posterior.stddev().numpy()

        params = None
        if output_parameters:
            # Multiplying by a ones-vector broadcasts scalar parameters to one
            # value per reflection so they can be masked like F below.
            ones = np.ones(len(F), dtype='float32')
            params = {
                name: to_numpy(surrogate_posterior.parameters[name]).flatten() * ones
                for name in sorted(surrogate_posterior.parameter_properties())
            }

        asu_id, H = self.asu_collection.to_asu_id_and_miller_index(
            np.arange(len(F)))
        h, k, l = H.T
        refl_id = BaseModel.get_refl_id(inputs)
        # Number of observations of each reflection in `inputs`
        N = np.bincount(refl_id.flatten(), minlength=len(F)).astype('float32')

        # PHENIX will expect the sf / error keys in a particular order.
        anom_keys = ['F(+)', 'SigF(+)', 'F(-)', 'SigF(-)', 'N(+)', 'N(-)']

        results = ()
        for i, asu in enumerate(self.asu_collection):
            # Boolean mask selecting the reflections that belong to this ASU
            idx = (asu_id == i).flatten()
            output = rs.DataSet(
                {
                    'H': h[idx],
                    'K': k[idx],
                    'L': l[idx],
                    'F': F[idx],
                    'SigF': SigF[idx],
                    'N': N[idx],
                },
                cell=asu.cell,
                spacegroup=asu.spacegroup,
                merged=True,
            ).infer_mtz_dtypes().set_index(['H', 'K', 'L'])
            if params is not None:
                for name in sorted(params):
                    output[name] = rs.DataSeries(params[name][idx],
                                                 index=output.index,
                                                 dtype='R')

            # Remove unobserved refls
            output = output[output.N > 0]

            # Reformat anomalous data
            if asu.anomalous:
                output = output.unstack_anomalous()
                reorder = anom_keys + [
                    key for key in output if key not in anom_keys
                ]
                output = output[reorder]

            results += (output, )
        return results