def __init__(self, cell, spacegroup, dmin, anomalous):
    """
    Store the crystal metadata and build a per-reflection lookup table.

    Parameters
    ----------
    cell : gemmi.UnitCell
        A gemmi.UnitCell instance
    spacegroup : gemmi.SpaceGroup
        A gemmi.SpaceGroup instance
    dmin : float
        The maximum resolution in Angstroms
    anomalous : bool
        Whether the ASU includes Friedel minus acentric reflections
    """
    self.cell = cell
    self.spacegroup = spacegroup
    self.dmin = dmin
    self.anomalous = anomalous

    # Miller indices returned by generate_reciprocal_asu for this
    # cell / spacegroup / resolution cutoff, one reflection per row.
    self.Hall = rs.utils.generate_reciprocal_asu(
        self.cell,
        self.spacegroup,
        self.dmin,
        self.anomalous,
    )

    # 'id' is simply the row position of each reflection in self.Hall.
    miller_h, miller_k, miller_l = self.Hall.T
    table = rs.DataSet(
        {
            'H': miller_h,
            'K': miller_k,
            'L': miller_l,
            'id': np.arange(len(miller_h)),
        },
        cell=cell,
        spacegroup=spacegroup,
    )
    table = table.compute_multiplicity()
    table = table.label_centrics()
    table = table.compute_dHKL()
    self.lookup_table = table
def test_constructor_dataset(data_fmodel, spacegroup, cell, merged):
    """Test DataSet.__init__() when called with a DataSet"""
    result = rs.DataSet(
        data_fmodel,
        spacegroup=spacegroup,
        cell=cell,
        merged=merged,
    )
    assert_frame_equal(result, data_fmodel)

    # When no merged flag is supplied the result is expected to report True
    expected_merged = True if merged is None else merged
    assert result.merged == expected_merged

    # Ensure provided values take precedence
    if not spacegroup:
        assert result.spacegroup == data_fmodel.spacegroup
    else:
        assert result.spacegroup.xhm() == "P 1"
        assert isinstance(result.spacegroup, gemmi.SpaceGroup)

    if not cell:
        assert result.cell == data_fmodel.cell
    else:
        # cell may be a gemmi.UnitCell or an indexable of cell parameters
        if isinstance(cell, gemmi.UnitCell):
            assert result.cell == cell
        else:
            assert result.cell.a == cell[0]
        assert isinstance(result.cell, gemmi.UnitCell)
def test_remove_absences(inplace, hkl_index):
    """Test DataSet.remove_absences()"""
    cell = gemmi.UnitCell(34., 45., 98., 90., 90., 90.)
    sg_generate = gemmi.SpaceGroup(1)
    sg_dataset = gemmi.SpaceGroup(19)

    # Reflections are generated under spacegroup 1 and then stored in a
    # DataSet that carries spacegroup 19, so some of them are flagged absent.
    hkls = rs.utils.generate_reciprocal_asu(cell, sg_generate, 5., anomalous=False)
    n_absent = rs.utils.is_absent(hkls, sg_dataset).sum()

    columns = {
        'H': hkls[:, 0],
        'K': hkls[:, 1],
        'L': hkls[:, 2],
        'I': np.ones(len(hkls)),
    }
    ds = rs.DataSet(columns, spacegroup=sg_dataset, cell=cell).infer_mtz_dtypes()
    if hkl_index:
        ds.set_index(['H', 'K', 'L'], inplace=True)

    ds_test = ds.remove_absences(inplace=inplace)
    ds_true = ds[~ds.label_absences().ABSENT]

    assert len(ds_test) == len(hkls) - n_absent
    assert np.array_equal(ds_test.get_hkls(), ds_true.get_hkls())

    # inplace=True must hand back the very same object, otherwise a new one
    if inplace:
        assert ds_test is ds
    else:
        assert ds_test is not ds
def test_from_gemmi(data_gemmi):
    """Test that DataSet.from_gemmi() agrees with the DataSet constructor"""
    result = rs.DataSet.from_gemmi(data_gemmi)
    expected = rs.DataSet(data_gemmi)

    assert_frame_equal(result, expected)

    # Metadata carried alongside the frame must match as well
    for attribute in ("spacegroup", "cell", "_index_dtypes"):
        assert getattr(result, attribute) == getattr(expected, attribute)
def dataset_hkl():
    """
    Build DataSet for testing containing only Miller indices
    """
    hmin, hmax = -5, 5

    # Same range and step of 2 along every axis
    axis = slice(hmin, hmax + 1, 2)
    hs, ks, ls = np.mgrid[axis, axis, axis].reshape((3, -1))

    dataset = rs.DataSet({"H": hs, "K": ks, "L": ls})
    dataset.set_index(["H", "K", "L"], inplace=True)
    return dataset
def test_constructor_gemmi(data_gemmi, spacegroup, cell):
    """Test DataSet.__init__() when called with the data_gemmi fixture"""
    result = rs.DataSet(data_gemmi, spacegroup=spacegroup, cell=cell)

    # Ensure provided values take precedence; otherwise the values carried
    # by the source object are expected.
    expected_spacegroup = spacegroup if spacegroup else data_gemmi.spacegroup
    expected_cell = cell if cell else data_gemmi.cell

    assert result.spacegroup == expected_spacegroup
    assert result.cell == expected_cell
def get_predictions(self, model, inputs=None):
    """
    Collect observed intensities alongside the model's predictions.

    Parameters
    ----------
    model : VariationalMergingModel
        A merging model from careless. Its `prediction_mean_stddev`
        method is called on the inputs.
    inputs : tuple (optional)
        Inputs for which to make the predictions if None, self.inputs is used.

    Returns
    -------
    predictions : tuple
        A tuple of rs.DataSet objects containing the predictions for each
        ReciprocalASU contained in self.asu_collection. Each one is indexed
        by H, K, L, flagged as unmerged, and holds the columns Iobs,
        SigIobs, Ipred and SigIpred.
    """
    if inputs is None:
        inputs = self.inputs

    refl_id = BaseModel.get_refl_id(inputs)
    observed = BaseModel.get_intensities(inputs).flatten()
    observed_sigma = BaseModel.get_uncertainties(inputs).flatten()
    asu_id, miller = self.asu_collection.to_asu_id_and_miller_index(refl_id)
    predicted, predicted_sigma = model.prediction_mean_stddev(inputs)
    miller_h, miller_k, miller_l = miller.T

    def predictions_for(asu_index, asu):
        # Rows belonging to this ASU only
        mask = (asu_id == asu_index).flatten()
        frame = rs.DataSet(
            {
                'H': miller_h[mask],
                'K': miller_k[mask],
                'L': miller_l[mask],
                'Iobs': observed[mask],
                'SigIobs': observed_sigma[mask],
                'Ipred': predicted[mask],
                'SigIpred': predicted_sigma[mask],
            },
            cell=asu.cell,
            spacegroup=asu.spacegroup,
            merged=False,
        )
        return frame.infer_mtz_dtypes().set_index(['H', 'K', 'L'])

    return tuple(
        predictions_for(asu_index, asu)
        for asu_index, asu in enumerate(self.asu_collection)
    )
def expand_to_p1(self):
    """
    Generates all symmetrically equivalent reflections. The spacegroup
    symmetry is set to P1.

    Returns
    -------
    DataSet

    Raises
    ------
    ValueError
        If the DataSet is not merged, or if any Miller index lies outside
        the reciprocal ASU of the current spacegroup.
    """
    if not self.merged:
        raise ValueError(
            "This function is only applicable for merged DataSets")

    inside_asu = in_asu(self.get_hkls(), spacegroup=self.spacegroup)
    if not inside_asu.all():
        raise ValueError(
            "This function is only applicable for reflection data in the reciprocal ASU and anomalous data in a two-column (unstacked) format"
        )

    expanded = rs.DataSet(spacegroup=self.spacegroup, cell=self.cell)

    # Get all symops, in ascending order by ISYM: each group operation is
    # immediately followed by its negated() counterpart.
    symops = []
    for groupop in self.spacegroup.operations():
        symops.append(groupop)
        symops.append(groupop.negated())

    # Apply each symop (selected through its 1-based ISYM number) and drop
    # duplicates with higher ISYM; drop_duplicates keeps the first
    # occurrence, i.e. the row appended earliest.
    for isym in range(1, len(symops) + 1):
        mapped = self.copy()
        mapped["M/ISYM"] = isym
        mapped["M/ISYM"] = mapped["M/ISYM"].astype("M/ISYM")
        expanded = expanded.append(mapped.hkl_to_observed(m_isym="M/ISYM"))
        expanded.drop_duplicates(subset=["H", "K", "L"], inplace=True)

    # Restrict to p1 ASU
    expanded.spacegroup = gemmi.SpaceGroup(1)
    keep = in_asu(expanded.get_hkls(), spacegroup=expanded.spacegroup)
    return expanded.loc[keep]
def run_careless(parser):
    """
    Run a careless merging job as described by the parsed options.

    The steps, in order:

    1. Format the reflection files with the formatter selected by
       ``parser.type`` ('poly' or 'mono') and wrap them in a DataManager.
    2. Optionally hold out ``parser.test_fraction`` of the reflections.
    3. Train the model for ``parser.iterations`` steps.
    4. Write the merged results, the training history, the model weights,
       a pickle of the DataManager and the per-reflection predictions,
       all named from ``parser.output_base``.
    5. If ``parser.merge_half_datasets`` is set, retrain on data halves
       with the scaling model frozen and write the results.
    6. If ``parser.embed`` is set, drop into an IPython shell.

    Parameters
    ----------
    parser : object
        Parsed options. The attributes read here are type,
        reflection_files, test_fraction, iterations, output_base,
        merge_half_datasets, half_dataset_repeats and embed.

    Raises
    ------
    ValueError
        If ``parser.type`` is neither 'poly' nor 'mono'.
    """
    # We defer all imports to make sure the parser has priority in modifying tf parameters
    import pickle

    import tensorflow as tf
    import numpy as np
    import reciprocalspaceship as rs
    from careless.io.manager import DataManager
    from careless.io.formatter import MonoFormatter,LaueFormatter
    # NOTE(review): the imports below are not referenced in this function.
    # They are kept because they may be relied on from the optional
    # IPython shell at the end -- confirm before removing.
    from careless.models.base import BaseModel
    from careless.models.merging.surrogate_posteriors import TruncatedNormal
    from careless.models.merging.variational import VariationalMergingModel
    from careless.models.scaling.image import HybridImageScaler,ImageScaler
    from careless.models.scaling.nn import MLPScaler

    if parser.type == 'poly':
        df = LaueFormatter.from_parser(parser)
    elif parser.type == 'mono':
        df = MonoFormatter.from_parser(parser)
    else:
        # Fail here with a clear message instead of an UnboundLocalError
        # on `df` a few lines further down.
        raise ValueError(
            f"Unsupported parser.type {parser.type!r}; expected 'poly' or 'mono'"
        )

    inputs,rac = df.format_files(parser.reflection_files)
    dm = DataManager(inputs, rac, parser=parser)

    if parser.test_fraction is not None:
        train,test = dm.split_data_by_refl(parser.test_fraction)
    else:
        train,test = dm.inputs,None

    model = dm.build_model()
    history = model.train_model(
        tuple(map(tf.convert_to_tensor, train)),
        parser.iterations,
        message="Training",
    )

    # One merged MTZ per entry returned by get_results
    for i,ds in enumerate(dm.get_results(model.surrogate_posterior, inputs=train)):
        filename = parser.output_base + f'_{i}.mtz'
        ds.write_mtz(filename)

    # to_csv returns None when given a path, so its result is deliberately
    # not assigned back to `history`.
    filename = parser.output_base + '_history.csv'
    rs.DataSet(history).to_csv(filename, index_label='step')

    model.save_weights(parser.output_base + '_weights')
    with open(parser.output_base + "_data_manager.pickle", "wb") as out:
        pickle.dump(dm, out)

    # Predictions carry a 'test' column: 0 for training rows, 1 for held-out rows
    if test is not None:
        for file_id, (ds_train, ds_test) in enumerate(zip(
                dm.get_predictions(model, train),
                dm.get_predictions(model, test),
            )):
            ds_train['test'] = rs.DataSeries(0, index=ds_train.index, dtype='I')
            ds_test['test'] = rs.DataSeries(1, index=ds_test.index, dtype='I')
            filename = parser.output_base + f'_predictions_{file_id}.mtz'
            ds_train.append(ds_test).write_mtz(filename)
    else:
        for file_id, ds_train in enumerate(dm.get_predictions(model, train)):
            ds_train['test'] = rs.DataSeries(0, index=ds_train.index, dtype='I')
            filename = parser.output_base + f'_predictions_{file_id}.mtz'
            ds_train.write_mtz(filename)

    if parser.merge_half_datasets:
        # Reuse the scaling model from the full-data fit and freeze it
        scaling_model = model.scaling_model
        scaling_model.trainable = False
        xval_data = [None] * len(dm.asu_collection)
        for repeat in range(parser.half_dataset_repeats):
            for half_id, half in enumerate(dm.split_data_by_image()):
                model = dm.build_model(scaling_model=scaling_model)
                history = model.train_model(
                    tuple(map(tf.convert_to_tensor, half)),
                    parser.iterations,
                    message=f"Merging repeat {repeat+1} half {half_id+1}",
                )
                for file_id,ds in enumerate(dm.get_results(model.surrogate_posterior, inputs=half)):
                    ds['repeat'] = rs.DataSeries(repeat, index=ds.index, dtype='I')
                    ds['half'] = rs.DataSeries(half_id, index=ds.index, dtype='I')
                    if xval_data[file_id] is None:
                        xval_data[file_id] = ds
                    else:
                        xval_data[file_id] = xval_data[file_id].append(ds)

        for file_id, ds in enumerate(xval_data):
            if ds is None:
                # Nothing was accumulated for this entry (for example when
                # half_dataset_repeats is 0), so there is nothing to write.
                continue
            filename = parser.output_base + f'_xval_{file_id}.mtz'
            ds.write_mtz(filename)

    if parser.embed:
        from IPython import embed
        embed(colors='Linux')
unit_cell1 = cryst1.get_unit_cell() refls1 = reflection_table.from_file(refl_file1) elist2 = ExperimentListFactory.from_json_file(expt_file2) cryst2 = elist2.crystals()[0] unit_cell2 = cryst2.get_unit_cell() refls2 = reflection_table.from_file(refl_file2) # Remove reflections not used in refinement refls2 = refls2.select(refls2.get_flags(refls2.flags.used_in_refinement)) print('generating DIALS dataframes') dials_df1 = rs.DataSet( { 'X': refls1['xyzobs.px.value'].parts()[0].as_numpy_array(), 'Y': refls1['xyzobs.px.value'].parts()[1].as_numpy_array(), 'Wavelength': refls1['Wavelength'].as_numpy_array(), 'BATCH': refls1['imageset_id'].as_numpy_array(), }, cell=gemmi.UnitCell(*unit_cell1.parameters()), spacegroup=gemmi.SpaceGroup(cryst1.get_space_group().type( ).universal_hermann_mauguin_symbol())).infer_mtz_dtypes() dials_df2 = rs.DataSet( { 'X': refls2['xyzobs.px.value'].parts()[0].as_numpy_array(), 'Y': refls2['xyzobs.px.value'].parts()[1].as_numpy_array(), 'Wavelength': refls2['Wavelength'].as_numpy_array(), 'BATCH': refls2['imageset_id'].as_numpy_array(), }, cell=gemmi.UnitCell(*unit_cell2.parameters()), spacegroup=gemmi.SpaceGroup(cryst2.get_space_group().type( ).universal_hermann_mauguin_symbol())).infer_mtz_dtypes()
from dials.array_family.flex import reflection_table elist = ExperimentListFactory.from_json_file(expt_file) cryst = elist.crystals()[0] unit_cell = cryst.get_unit_cell() spacegroup = gemmi.SpaceGroup( cryst.get_space_group().type().universal_hermann_mauguin_symbol()) cell = gemmi.UnitCell(*unit_cell.parameters()) refls = reflection_table.from_file(refl_file) refls = refls.select(refls.get_flags(refls.flags.used_in_refinement)) print('generating DIALS dataframe') dials_df = rs.DataSet( { 'X': refls['xyzobs.px.value'].parts()[0].as_numpy_array(), 'Y': refls['xyzobs.px.value'].parts()[1].as_numpy_array(), 'Xcal': refls['xyzcal.px'].parts()[0].as_numpy_array(), 'Ycal': refls['xyzcal.px'].parts()[1].as_numpy_array(), 'BATCH': refls['imageset_id'].as_numpy_array(), }, cell=cell, spacegroup=spacegroup).infer_mtz_dtypes() print('initializing metrics') dials_rmsds = np.zeros(len(precog_rmsds)) rmsd_diff = np.zeros(len(precog_rmsds)) # Iterate by frame and match HKLs, seeing what percentage are correct for i in trange(len(precog_rmsds)): im_dia = dials_df[dials_df['BATCH'] == i] # Get XY positions for refls xy = im_dia[['X', 'Y']].to_numpy(float)
def compute_redundancy(hobs, cell, spacegroup, full_asu=True, anomalous=False, dmin=None):
    """
    Compute the multiplicity of all valid reflections in the reciprocal ASU.

    Parameters
    ----------
    hobs : np.array(int)
        An n by 3 array of observed miller indices which are not
        necessarily in the reciprocal asymmetric unit
    cell : gemmi.UnitCell
        A gemmi UnitCell object.
    spacegroup : gemmi.SpaceGroup
        A gemmi SpaceGroup object.
    full_asu : bool (optional)
        Include all the reflections in the calculation irrespective of
        whether they were observed.
    anomalous : bool (optional)
        Whether or not the data are anomalous.
    dmin : float (optional)
        If no dmin is supplied, the maximum resolution reflection will be used.

    Returns
    -------
    hunique : np.array (int32)
        An n by 3 array of unique miller indices from hobs. With
        full_asu=True this also holds the generated ASU reflections that
        were not observed.
    counts : np.array (int32)
        A length n array of counts for each miller index in hunique.
    """
    def count_table(hkl, counts):
        # Shared constructor for the observed and the full-ASU tables
        return rs.DataSet(
            {
                'H': hkl[:, 0],
                'K': hkl[:, 1],
                'L': hkl[:, 2],
                'Count': counts,
            },
            cell=cell,
            spacegroup=spacegroup,
        )

    # Discard absent reflections, then apply the resolution cutoff
    observed = hobs[~rs.utils.is_absent(hobs, spacegroup)]
    resolution = rs.utils.compute_dHKL(observed, cell)
    if dmin is None:
        dmin = resolution.min()
    observed = observed[resolution >= dmin]

    # Round dmin down at the fifth decimal before it is used to generate
    # the ASU (presumably to keep edge reflections despite float error).
    decimals = 5.
    dmin = np.floor(dmin * 10**decimals) * 10**-decimals

    observed, isym = rs.utils.hkl_to_asu(observed, spacegroup)
    if anomalous:
        # Reflections mapped with an even ISYM have their indices negated
        even_isym = isym % 2 == 0
        observed[even_isym] = -observed[even_isym]

    # One row per observation; summing per HKL yields the multiplicity
    mult = count_table(observed, np.ones(len(observed)))
    mult = mult.groupby(['H', 'K', 'L']).sum()

    if full_asu:
        everything = rs.utils.generate_reciprocal_asu(cell, spacegroup, dmin, anomalous)
        unobserved = count_table(everything, np.zeros(len(everything)))
        unobserved = unobserved.set_index(['H', 'K', 'L'])
        # Keep only reflections that were never observed (count of zero)
        unobserved = unobserved.loc[unobserved.index.difference(mult.index)]
        mult = mult.append(unobserved)

    mult = mult.sort_index()
    return mult.get_hkls(), mult['Count'].to_numpy(np.int32)
cryst = elist.crystals()[0] unit_cell = cryst.get_unit_cell() precog_df.spacegroup = gemmi.SpaceGroup( cryst.get_space_group().type().universal_hermann_mauguin_symbol()) precog_df.cell = gemmi.UnitCell(*unit_cell.parameters()) refls = reflection_table.from_file(refl_file) print('generating DIALS dataframe') dials_df = rs.DataSet( { 'X': refls['xyzobs.px.value'].parts()[0].as_numpy_array(), 'Y': refls['xyzobs.px.value'].parts()[1].as_numpy_array(), 'H': refls['miller_index'].as_vec3_double().parts()[0].as_numpy_array(), 'K': refls['miller_index'].as_vec3_double().parts()[1].as_numpy_array(), 'L': refls['miller_index'].as_vec3_double().parts()[2].as_numpy_array(), 'Wavelength': refls['Wavelength'].as_numpy_array(), 'BATCH': refls['imageset_id'].as_numpy_array(), }, cell=precog_df.cell, spacegroup=precog_df.spacegroup).infer_mtz_dtypes() print('initializing metrics') percent_correct = np.zeros(len(elist)) percent_outliers = np.zeros(len(elist)) percent_misindexed = np.zeros(len(elist)) nspots = np.zeros(len(elist)) nmatch = np.zeros(len(elist))
def test_constructor_empty():
    """Test DataSet.__init__() when called without arguments"""
    empty = rs.DataSet()

    assert len(empty) == 0
    # No crystal metadata is expected on an empty DataSet
    for attribute in ("spacegroup", "cell"):
        assert getattr(empty, attribute) is None
new_expt_filename = args.output + '.expt' new_refl_filename = args.output + '.refl' # Get experiments expts = ExperimentListFactory.from_json_file(expt_filename) new_expts = ExperimentList() # Initialize flex tables for refl files refl_input = flex.reflection_table().from_file(refl_filename) refl_output = refl_input.copy() refl_output["id"] = flex.int([-1] * len(refl_output)) # Initialize data frame dials_df = rs.DataSet({ 'Wavelength': refl_input['Wavelength'], 'ID': refl_input['id'], 'new_ID': [-1] * len(refl_input) }) #.infer_mtz_dtypes() # Generate beams per reflection print(f'Number of rows: {len(dials_df)}') for i, refl in tqdm(dials_df.iterrows()): # New beam per reflection expt = expts[refl['ID'][i]] temp = expt.beam.get_s0() new_expt = expt new_expt.beam = deepcopy(expt.beam) new_expt.beam.set_wavelength(refl['Wavelength'][i]) s0 = (expt.beam.get_s0() / np.linalg.norm(expt.beam.get_s0())) / new_expt.beam.get_wavelength() new_expt.beam.set_s0(s0)
def get_results(self, surrogate_posterior, inputs=None, output_parameters=True):
    """
    Extract results from a surrogate_posterior.

    Parameters
    ----------
    surrogate_posterior : tfd.Distribution
        A tensorflow_probability distribution or similar object with
        `mean` and `stddev` methods
    inputs : tuple (optional)
        Optionally use a different object from self.inputs to compute the
        redundancy of reflections.
    output_parameters : bool (optional)
        If True, output the parameters of the surrogate distribution in
        addition to the moments.

    Returns
    -------
    results : tuple
        A tuple of rs.DataSet objects containing the results corresponding
        to each ReciprocalASU contained in self.asu_collection
    """
    if inputs is None:
        inputs = self.inputs

    F = surrogate_posterior.mean().numpy()
    SigF = surrogate_posterior.stddev().numpy()
    n_refl = len(F)

    params = None
    if output_parameters:
        # Multiplying by ones stretches every parameter to one value per reflection
        ones = np.ones(n_refl, dtype='float32')
        params = {}
        for name in sorted(surrogate_posterior.parameter_properties()):
            value = surrogate_posterior.parameters[name]
            params[name] = tf.convert_to_tensor(value).numpy().flatten() * ones

    asu_id, H = self.asu_collection.to_asu_id_and_miller_index(
        np.arange(n_refl))
    h, k, l = H.T

    # Number of observations per reflection id
    refl_id = BaseModel.get_refl_id(inputs)
    N = np.bincount(refl_id.flatten(), minlength=n_refl).astype('float32')

    # PHENIX will expect the sf / error keys in a particular order.
    anom_keys = [
        'F(+)',
        'SigF(+)',
        'F(-)',
        'SigF(-)',
        'N(+)',
        'N(-)',
    ]

    collected = []
    for i, asu in enumerate(self.asu_collection):
        idx = (asu_id == i).flatten()
        output = rs.DataSet(
            {
                'H': h[idx],
                'K': k[idx],
                'L': l[idx],
                'F': F[idx],
                'SigF': SigF[idx],
                'N': N[idx],
            },
            cell=asu.cell,
            spacegroup=asu.spacegroup,
            merged=True,
        ).infer_mtz_dtypes().set_index(['H', 'K', 'L'])

        if params is not None:
            for key in sorted(params):
                output[key] = rs.DataSeries(params[key][idx], index=output.index, dtype='R')

        # Remove unobserved refls
        output = output[output.N > 0]

        # Reformat anomalous data
        if asu.anomalous:
            output = output.unstack_anomalous()
            trailing = [key for key in output if key not in anom_keys]
            output = output[anom_keys + trailing]

        collected.append(output)

    return tuple(collected)