def transform_dist_if_necessary(dist, state, *, allow_transformed_and_untransformed):
    """Route ``dist`` through its transformed or untransformed model wrapper.

    Distributions without a transform, or ones already marked as
    autotransformed, pass through unchanged. Observed distributions are also
    returned as-is, but supplying an extra (transformed or untransformed)
    value for them in ``state`` is rejected as ambiguous.

    Parameters
    ----------
    dist
        The distribution under evaluation.
    state
        The sampling state holding transformed/untransformed value dicts;
        may be mutated (the untransformed value can be popped).
    allow_transformed_and_untransformed : bool
        Keyword-only. When False and both value kinds are present, the
        untransformed one is dropped in favor of the transformed one.

    Returns
    -------
    Either ``dist`` unchanged or a generator model wrapping it.

    Raises
    ------
    EvaluationError
        If an observed variable also has a value supplied in the state.
    """
    transform = dist.transform
    if transform is None or dist.model_info.get("autotransformed", False):
        # Nothing to do: no transform attached, or a previous pass already
        # rewrote this distribution.
        return dist
    name = scopes.variable_name(dist.name)
    transformed_name = scopes.transformed_variable_name(transform.name, dist.name)
    has_transformed = transformed_name in state.transformed_values
    has_untransformed = name in state.untransformed_values
    if observed_value_in_evaluation(name, dist, state) is not None:
        # Observed variables are left untouched; extra values for them in
        # the state are contradictory and must raise.
        if has_transformed:
            raise EvaluationError(
                EvaluationError.OBSERVED_VARIABLE_IS_NOT_SUPPRESSED_BUT_ADDITIONAL_TRANSFORMED_VALUE_PASSED.format(
                    name, transformed_name))
        if has_untransformed:
            raise EvaluationError(
                EvaluationError.OBSERVED_VARIABLE_IS_NOT_SUPPRESSED_BUT_ADDITIONAL_VALUE_PASSED.format(
                    name, name))
        return dist
    if not has_transformed:
        return make_untransformed_model(dist, transform, state)
    if has_untransformed and not allow_transformed_and_untransformed:
        # Keep only the transformed value so the state stays unambiguous.
        state.untransformed_values.pop(name)
    return make_transformed_model(dist, transform, state)
def proceed_distribution(
    self,
    dist: distribution.Distribution,
    state: SamplingState,
    # NOTE(review): default is None, so the annotation should arguably be
    # Optional[...] — left as-is to keep this a documentation-only change.
    sample_shape: Union[int, Tuple[int], tf.TensorShape] = None,
) -> Tuple[Any, SamplingState]:
    """Record ``dist`` in the sampling state and resolve its value.

    The value for the variable is chosen in this order:

    1. an observed value — either explicitly present in
       ``state.observed_values`` or attached to the distribution itself;
    2. an untransformed value already stored in the state;
    3. a fresh test sample drawn from the distribution.

    Parameters
    ----------
    dist : distribution.Distribution
        The distribution being evaluated; must not be anonymous.
    state : SamplingState
        Mutable evaluation state; its value and distribution dicts are
        updated in place.
    sample_shape : Union[int, Tuple[int], tf.TensorShape]
        Forwarded to ``get_test_sample`` only when ``dist.is_root`` is True.

    Returns
    -------
    Tuple[Any, SamplingState]
        The resolved value for the variable and the (mutated) state.

    Raises
    ------
    EvaluationError
        If the distribution is anonymous, if the scoped name was already
        registered (duplicate variable), or if both an observed and an
        untransformed value were supplied for the same variable.
    """
    if dist.is_anonymous:
        raise EvaluationError("Attempting to create an anonymous Distribution")
    scoped_name = scopes.variable_name(dist.name)
    if scoped_name is None:
        raise EvaluationError("Attempting to create an anonymous Distribution")
    # A name may appear in exactly one of the three registries; seeing it
    # again means the same model/function was evaluated twice without a
    # distinguishing name scope.
    if (scoped_name in state.discrete_distributions
            or scoped_name in state.continuous_distributions
            or scoped_name in state.deterministics_values):
        raise EvaluationError(
            "Attempting to create a duplicate variable {!r}, "
            "this may happen if you forget to use `pm.name_scope()` when calling same "
            "model/function twice without providing explicit names. If you see this "
            "error message and the function being called is not wrapped with "
            "`pm.model`, you should better wrap it to provide explicit name for this model"
            .format(scoped_name))
    if scoped_name in state.observed_values or dist.is_observed:
        observed_variable = observed_value_in_evaluation(
            scoped_name, dist, state)
        if observed_variable is None:
            # None indicates we pass None to the state.observed_values dict:
            # either posterior predictive sampling, or a programmatic
            # override turning an observed variable back into a latent one.
            if scoped_name not in state.untransformed_values:
                # Posterior predictive: draw a fresh sample in place of the
                # suppressed observation.
                if dist.is_root:
                    return_value = state.untransformed_values[
                        scoped_name] = dist.get_test_sample(
                            sample_shape=sample_shape)
                else:
                    return_value = state.untransformed_values[
                        scoped_name] = dist.get_test_sample()
            else:
                # Replace the observed variable with the user-supplied one.
                return_value = state.untransformed_values[scoped_name]
            # We also store the name in posterior_predictives just to keep
            # track of the variables used in posterior predictive sampling.
            state.posterior_predictives.add(scoped_name)
            state.observed_values.pop(scoped_name)
        else:
            # A real observation exists; a simultaneous untransformed value
            # would be contradictory.
            if scoped_name in state.untransformed_values:
                raise EvaluationError(
                    EvaluationError.
                    OBSERVED_VARIABLE_IS_NOT_SUPPRESSED_BUT_ADDITIONAL_VALUE_PASSED
                    .format(scoped_name))
            assert_values_compatible_with_distribution(
                scoped_name, observed_variable, dist)
            return_value = state.observed_values[
                scoped_name] = observed_variable
    elif scoped_name in state.untransformed_values:
        # A value was supplied up-front (e.g. from a previous trace).
        return_value = state.untransformed_values[scoped_name]
    else:
        # No value anywhere: fall back to a test sample; sample_shape only
        # applies at the root of the model.
        if dist.is_root:
            return_value = state.untransformed_values[
                scoped_name] = dist.get_test_sample(
                    sample_shape=sample_shape)
        else:
            return_value = state.untransformed_values[
                scoped_name] = dist.get_test_sample()
    # Register the distribution by differentiability so downstream samplers
    # can pick gradient-based vs. discrete methods.
    if dist._grad_support:
        state.continuous_distributions[scoped_name] = dist
    else:
        state.discrete_distributions[scoped_name] = dist
    return return_value, state
def modify_distribution(self, dist, model_info, state):
    """Apply transformations to a distribution.

    If ``dist`` carries a transform (and is neither observed nor already
    autotransformed), return a generator model that evaluates the
    distribution in the transformed space and yields a ``Potential`` with
    the log-determinant-of-Jacobian correction. Otherwise return ``dist``
    unchanged.

    Raises
    ------
    EvaluationError
        If an observed variable also has a transformed or untransformed
        value supplied in the state, or if both value kinds are present for
        the same latent variable.
    """
    dist = super().modify_distribution(dist, model_info, state)
    if not isinstance(dist, abstract.Distribution):
        return dist
    scoped_name = scopes.variable_name(dist.name)
    if dist.transform is None or dist.model_info.get(
            # do nothing else if no transform is set
            "autotransformed",
            False):  # already autotransformed, do nothing else
        return dist
    transform = dist.transform
    transformed_scoped_name = scopes.variable_name(
        # double underscore stands for transform
        "__{}_{}".format(transform.name, dist.name))
    if observed_value_in_evaluation(scoped_name, dist, state) is not None:
        # do not modify a distribution if it is observed
        # same for programmatically observed
        # but not for programmatically set to unobserved (when value is None)
        # but raise if we have a transformed value passed in the dict
        if transformed_scoped_name in state.transformed_values:
            raise EvaluationError(
                EvaluationError.
                OBSERVED_VARIABLE_IS_NOT_SUPPRESSED_BUT_ADDITIONAL_TRANSFORMED_VALUE_PASSED
                .format(scoped_name, transformed_scoped_name))
        if scoped_name in state.untransformed_values:
            raise EvaluationError(
                EvaluationError.
                OBSERVED_VARIABLE_IS_NOT_SUPPRESSED_BUT_ADDITIONAL_VALUE_PASSED
                .format(scoped_name, scoped_name))
        return dist
    if transformed_scoped_name in state.transformed_values:
        # We do not sample in this if branch
        # 0. do not allow ambiguity in state, make sure only one value is
        #    provided to compute logp
        # NOTE(review): the first half of this condition is always True here
        # (we are inside the same membership check), so only the
        # untransformed-value check is effective.
        if (transformed_scoped_name in state.transformed_values
                and scoped_name in state.untransformed_values):
            raise EvaluationError(
                "Found both transformed and untransformed variables in the state: "
                "'{} and '{}', but need exactly one".format(
                    scoped_name, transformed_scoped_name))

        def model():
            # 1. now compute all the variables: in the transformed and
            #    untransformed space
            # NOTE(review): the else branch below is unreachable — the
            # enclosing branch already guarantees the transformed value is
            # present in the state.
            if transformed_scoped_name in state.transformed_values:
                transformed_value = state.transformed_values[
                    transformed_scoped_name]
                untransformed_value = transform.inverse(transformed_value)
            else:
                untransformed_value = state.untransformed_values[
                    scoped_name]
                transformed_value = transform.forward(untransformed_value)
            # these lines below disable sampling and save cached results to
            # the store for the later `yield dist`
            state.untransformed_values[scoped_name] = untransformed_value
            state.transformed_values[
                transformed_scoped_name] = transformed_value
            # once we are done with variables we can yield the value in
            # untransformed space to the user and also increment the
            # potential
            # Important:
            # I have no idea yet, how to make that beautiful.
            # Here we indicate the distribution is already autotransformed,
            # not to get into an infinite loop
            dist.model_info["autotransformed"] = True
            # 2. here decide on logdet computation, this might be effective
            #    with a transformed value, but not with an untransformed one;
            #    this information is stored in the
            #    transform.jacobian_preference class attribute.
            #    We postpone the computation of logdet as it might have some
            #    overhead
            if transform.jacobian_preference == JacobianPreference.Forward:
                potential_fn = functools.partial(
                    transform.forward_log_det_jacobian, untransformed_value)
                coef = -1.0
            else:
                potential_fn = functools.partial(
                    transform.inverse_log_det_jacobian, transformed_value)
                coef = 1.0
            yield distributions.Potential(potential_fn, coef=coef)
            # 3. final return+yield will return untransformed_value
            #    as it is stored in state.values
            # Note: we need yield here to make further checks on name
            # duplicates, etc.
            return (yield dist)
    else:
        # we are going to sample here, but logp should be computed for the
        # transformed space
        def model():
            # 0. as explained above we indicate we already performed
            #    autotransform
            dist.model_info["autotransformed"] = True
            # 1. sample a value, as we've checked there is no state provided;
            #    we need `dist.model_info["autotransformed"] = True` here not
            #    to get in trouble (infinite recursion).
            #    The return value is not yet user facing
            sampled_untransformed_value = yield dist
            sampled_transformed_value = transform.forward(
                sampled_untransformed_value)
            # the untransformed value was already stored via the yield above
            # state.values[scoped_name] = sampled_untransformed_value
            state.transformed_values[
                transformed_scoped_name] = sampled_transformed_value
            # 2. increment the potential
            if transform.jacobian_preference == JacobianPreference.Forward:
                potential_fn = functools.partial(
                    transform.forward_log_det_jacobian,
                    sampled_untransformed_value)
                coef = -1.0
            else:
                potential_fn = functools.partial(
                    transform.inverse_log_det_jacobian,
                    sampled_transformed_value)
                coef = 1.0
            yield distributions.Potential(potential_fn, coef=coef)
            # 3. return value to the user
            return sampled_untransformed_value

    # return the correct generator model instead of a distribution
    return model()
def modify_distribution(self, dist: ModelType, model_info: Mapping[str, Any],
                        state: SamplingState) -> ModelType:
    """Remove the observed distribution values but keep their shapes.

    Modify observed Distribution instances in the following way:

    1) The distribution's shape (batch_shape + event_shape) will be checked
       for consistency with the supplied observed value's shape.
    2) If they are inconsistent, an EvaluationError will be raised.
    3) If they are consistent, the distribution's observed values' shape
       will be broadcasted with the distribution's shape to construct a new
       Distribution instance with no observations.
    4) This distribution will be yielded instead of the original incoming
       dist, and it will be used for posterior predictive sampling.

    Parameters
    ----------
    dist: Union[types.GeneratorType, pymc4.coroutine_model.Model]
        The model to modify.
    model_info: Mapping[str, Any]
        Either ``dist.model_info`` or
        ``pymc4.coroutine_model.Model.default_model_info`` if ``dist`` is
        not a ``pymc4.coroutine_model.Model`` instance.
    state: SamplingState
        The model's evaluation state.

    Returns
    -------
    model: Union[types.GeneratorType, pymc4.coroutine_model.Model]
        The original ``dist`` if it was not an observed ``Distribution`` or
        the ``Distribution`` with the changed ``batch_stack`` and
        observations set to ``None``.

    Raises
    ------
    EvaluationError
        When ``dist`` and its passed observed value don't have a consistent
        shape.
    """
    dist = super().modify_distribution(dist, model_info, state)
    # We only modify the shape of Distribution instances that have observed
    # values
    dist = transform_dist_if_necessary(
        dist, state, allow_transformed_and_untransformed=False)
    if not isinstance(dist, Distribution):
        return dist
    scoped_name = scopes.variable_name(dist.name)
    if scoped_name is None:
        raise EvaluationError("Attempting to create an anonymous Distribution")
    observed_value = observed_value_in_evaluation(scoped_name, dist, state)
    if observed_value is None:
        return dist
    # We set the state's observed value to None to explicitly override
    # any previously given observed and at the same time, have the
    # scope_name added to the posterior_predictives set in
    # self.proceed_distribution
    state.observed_values[scoped_name] = None
    # We first check the TFP distribution's shape and compare it with the
    # observed_value's shape
    assert_values_compatible_with_distribution(scoped_name, observed_value,
                                               dist)
    # Now we get the broadcasted shape between the observed value and the
    # distribution
    observed_shape = get_observed_tensor_shape(observed_value)
    dist_shape = dist.batch_shape + dist.event_shape
    new_dist_shape = tf.broadcast_static_shape(observed_shape, dist_shape)
    # The leading dims that broadcasting added beyond the distribution's own
    # shape become extra batch dimensions.
    extra_batch_stack = new_dist_shape[:len(new_dist_shape) - len(dist_shape)]
    # Now we construct and return the same distribution but setting
    # observed to None and setting a batch_stack that matches the result of
    # broadcasting the observed and distribution shape
    batch_stack = extra_batch_stack + (dist.batch_stack
                                       if dist.batch_stack is not None else ())
    if len(batch_stack) > 0:
        reinterpreted_batch_ndims = dist.reinterpreted_batch_ndims
        if dist.event_stack:
            # The stacked batch dims are reinterpreted as event dims when
            # the original distribution stacked events.
            reinterpreted_batch_ndims += len(extra_batch_stack)
    else:
        # No batch stacking needed: keep the distribution's own settings.
        batch_stack = None
        reinterpreted_batch_ndims = dist.reinterpreted_batch_ndims
    # Single construction site (previously duplicated across both branches).
    new_dist = type(dist)(
        name=dist.name,
        transform=dist.transform,
        observed=None,
        batch_stack=batch_stack,
        conditionally_independent=dist.conditionally_independent,
        event_stack=dist.event_stack,
        reinterpreted_batch_ndims=reinterpreted_batch_ndims,
        **dist.conditions,
    )
    return new_dist