Exemplo n.º 1
0
def test_getters(laue_inputs, mono_inputs):
    """Smoke-test every BaseModel static getter on laue and mono inputs."""
    # Getters that must work for both experiment geometries.
    common_getters = (
        BaseModel.get_image_id,
        BaseModel.get_intensities,
        BaseModel.get_metadata,
        BaseModel.get_refl_id,
        BaseModel.get_uncertainties,
    )
    # Harmonic id and wavelength are only defined for Laue inputs.
    laue_only_getters = (
        BaseModel.get_harmonic_id,
        BaseModel.get_wavelength,
    )
    for example in (laue_inputs, mono_inputs):
        getters = common_getters
        if BaseModel.is_laue(example):
            getters = laue_only_getters + common_getters
        for getter in getters:
            getter(example)
Exemplo n.º 2
0
def test_laue(likelihood_model, prior_model, scaling_model, laue_inputs,
              mc_samples):
    """
    End-to-end smoke test of VariationalMergingModel on Laue inputs.

    Builds the likelihood, prior, scaler, and surrogate posterior from the
    parametrized fixture classes, runs a forward pass, and asserts the
    predicted intensities are all finite. Finally checks the model compiles
    with the Adam optimizer.
    """
    nrefls = np.max(BaseModel.get_refl_id(laue_inputs)) + 1
    n_images = np.max(BaseModel.get_image_id(laue_inputs)) + 1

    # Degrees of freedom for the Student's t-distributed variants.
    dof = 4.
    if likelihood_model == StudentTLikelihood:
        likelihood = likelihood_model(dof)
    else:
        likelihood = likelihood_model()

    if prior_model == WilsonPrior:
        prior = prior_model(
            np.random.choice([True, False], nrefls),
            np.ones(nrefls).astype('float32'),
        )
    elif prior_model == StudentTReferencePrior:
        prior = prior_model(
            np.ones(nrefls).astype('float32'),
            np.ones(nrefls).astype('float32'), dof)
    else:
        prior = prior_model(
            np.ones(nrefls).astype('float32'),
            np.ones(nrefls).astype('float32'),
        )

    mlp_scaler = MLPScaler(2, 3)
    if scaling_model == HybridImageScaler:
        image_scaler = ImageScaler(n_images)
        scaler = HybridImageScaler(mlp_scaler, image_scaler)
    elif scaling_model == MLPScaler:
        scaler = mlp_scaler
    else:
        # Fail loudly here instead of hitting a confusing NameError on
        # `scaler` further down when a new scaling_model is parametrized.
        raise ValueError(f"Unrecognized scaling_model: {scaling_model}")

    # Surrogate posterior over merged intensities, constrained positive
    # via truncation; stddev kept positive through a Softplus bijector.
    surrogate_posterior = tfd.TruncatedNormal(
        tf.Variable(prior.mean()),
        tfp.util.TransformedVariable(
            prior.stddev() / 10.,
            tfb.Softplus(),
        ),
        low=1e-5,
        high=1e10,
    )

    merger = VariationalMergingModel(surrogate_posterior, prior, likelihood,
                                     scaler, mc_samples)
    ipred = merger(laue_inputs)

    # All predicted intensities must be finite (no NaN / inf).
    isfinite = np.all(np.isfinite(ipred.numpy()))
    assert isfinite

    merger = VariationalMergingModel(surrogate_posterior, prior, likelihood,
                                     scaler)
    merger.compile('Adam')
Exemplo n.º 3
0
    def get_predictions(self, model, inputs=None):
        """
        Extract per-ASU prediction tables from a merging model.

        Parameters
        ----------
        model : VariationalMergingModel
            A merging model from careless
        inputs : tuple (optional)
            Inputs for which to make the predictions if None, self.inputs is used.

        Returns
        -------
        predictions : tuple
            A tuple of rs.DataSet objects containing the predictions for each
            ReciprocalASU contained in self.asu_collection
        """
        inputs = self.inputs if inputs is None else inputs

        refl_id = BaseModel.get_refl_id(inputs)
        iobs = BaseModel.get_intensities(inputs).flatten()
        sig_iobs = BaseModel.get_uncertainties(inputs).flatten()
        asu_id, H = self.asu_collection.to_asu_id_and_miller_index(refl_id)
        ipred, sigipred = model.prediction_mean_stddev(inputs)

        h, k, l = H.T
        datasets = []
        for asu_number, asu in enumerate(self.asu_collection):
            # Boolean selector for the reflections belonging to this ASU.
            sel = (asu_id == asu_number).flatten()
            ds = rs.DataSet(
                {
                    'H': h[sel],
                    'K': k[sel],
                    'L': l[sel],
                    'Iobs': iobs[sel],
                    'SigIobs': sig_iobs[sel],
                    'Ipred': ipred[sel],
                    'SigIpred': sigipred[sel],
                },
                cell=asu.cell,
                spacegroup=asu.spacegroup,
                merged=False,
            )
            ds = ds.infer_mtz_dtypes().set_index(['H', 'K', 'L'])
            datasets.append(ds)
        return tuple(datasets)
Exemplo n.º 4
0
    def get_results(self,
                    surrogate_posterior,
                    inputs=None,
                    output_parameters=True):
        """
        Extract results from a surrogate_posterior.

        Parameters
        ----------
        surrogate_posterior : tfd.Distribution
            A tensorflow_probability distribution or similar object with `mean` and `stddev` methods
        inputs : tuple (optional)
            Optionally use a different object from self.inputs to compute the redundancy of reflections.
        output_parameters : bool (optional)
            If True, output the parameters of the surrogate distribution in addition to the
            moments.

        Returns
        -------
        results : tuple
            A tuple of rs.DataSet objects containing the results corresponding to each
            ReciprocalASU contained in self.asu_collection
        """
        if inputs is None:
            inputs = self.inputs
        F = surrogate_posterior.mean().numpy()
        SigF = surrogate_posterior.stddev().numpy()

        params = None
        if output_parameters:
            def _numpify(x):
                # Handles tf.Variable / TransformedVariable as well as tensors.
                return tf.convert_to_tensor(x).numpy()

            params = {}
            # NOTE: loop variable renamed from `k` to avoid shadowing the
            # Miller index `k` unpacked below.
            for name in sorted(surrogate_posterior.parameter_properties()):
                value = surrogate_posterior.parameters[name]
                # Broadcast scalar parameters to one value per reflection.
                params[name] = _numpify(value).flatten() * np.ones(
                    len(F), dtype='float32')

        asu_id, H = self.asu_collection.to_asu_id_and_miller_index(
            np.arange(len(F)))
        h, k, l = H.T
        refl_id = BaseModel.get_refl_id(inputs)
        # Redundancy: number of observations contributing to each reflection.
        N = np.bincount(refl_id.flatten(), minlength=len(F)).astype('float32')
        results = ()
        for i, asu in enumerate(self.asu_collection):
            idx = (asu_id == i).flatten()
            output = rs.DataSet(
                {
                    'H': h[idx],
                    'K': k[idx],
                    'L': l[idx],
                    'F': F[idx],
                    'SigF': SigF[idx],
                    'N': N[idx],
                },
                cell=asu.cell,
                spacegroup=asu.spacegroup,
                merged=True,
            ).infer_mtz_dtypes().set_index(['H', 'K', 'L'])
            if params is not None:
                for key in sorted(params):
                    output[key] = rs.DataSeries(params[key][idx],
                                                index=output.index,
                                                dtype='R')

            # Remove unobserved refls
            output = output[output.N > 0]

            # Reformat anomalous data
            if asu.anomalous:
                output = output.unstack_anomalous()
                # PHENIX will expect the sf / error keys in a particular order.
                anom_keys = [
                    'F(+)', 'SigF(+)', 'F(-)', 'SigF(-)', 'N(+)', 'N(-)'
                ]
                reorder = anom_keys + [
                    key for key in output if key not in anom_keys
                ]
                output = output[reorder]

            results += (output, )
        return results