def log_joint_fn(*args, **kwargs):
    """Log-probability of inputs according to a joint probability distribution.

    `model` is a free variable resolved from the enclosing scope (not visible
    here); presumably it is the model this log-joint was built for -- confirm.

    Args:
      *args: Positional arguments. They are the model's original inputs and can
        alternatively be specified as part of `kwargs`.
      **kwargs: Keyword arguments, where for each key-value pair `k` and `v`,
        `v` is passed as a `value` to the random variable(s) whose keyword
        argument `name` during construction is equal to `k`.

    Returns:
      Scalar tf.Tensor, which represents the model's log-probability summed
      over all Edward random variables and their dimensions. If the model
      constructs no random variables, the result is `sum([])`, i.e. the
      Python int 0.

    Raises:
      KeyError: If a random variable constructor in the model is called
        without a `name` keyword argument.
      LookupError: If a random variable has no value in `**kwargs` and no
        `value` keyword argument was passed to its constructor (or the value
        found is None).
    """
    log_probs = []

    def interceptor(rv_constructor, *rv_args, **rv_kwargs):
        """Overrides a random variable's `value` and accumulates its log-prob."""
        # Set value to keyword argument indexed by `name` (an input tensor).
        rv_name = rv_kwargs.get("name")
        if rv_name is None:
            raise KeyError("Random variable constructor {} has no name "
                           "in its arguments.".format(
                               rv_constructor.__name__))

        # If no value is explicitly passed in for an RV, default to the value
        # from the RV constructor. This may have been set explicitly by the user
        # or forwarded from a lower-level interceptor.
        # Note: an explicit `kwargs[rv_name] = None` does NOT fall back; it
        # yields None and triggers the LookupError below.
        previously_specified_value = rv_kwargs.get("value")
        value = kwargs.get(rv_name, previously_specified_value)
        if value is None:
            raise LookupError(
                "Keyword argument specifying value for {} is "
                "missing.".format(rv_name))
        rv_kwargs["value"] = value

        rv = rv_constructor(*rv_args, **rv_kwargs)
        # Reduce over every dimension so each random variable contributes a
        # single scalar term to the joint log-probability.
        log_prob = tf.reduce_sum(
            input_tensor=rv.distribution.log_prob(rv.value))
        log_probs.append(log_prob)
        return rv

    # NOTE(review): presumably `_get_function_inputs` keeps only the entries of
    # `kwargs` that `model` accepts -- confirm. The interceptor above still
    # reads random-variable values from the full `kwargs`.
    model_kwargs = _get_function_inputs(model, kwargs)
    with interception(interceptor):
        model(*args, **model_kwargs)
    log_prob = sum(log_probs)
    return log_prob
def log_joint_fn(*args, **kwargs):
    """Log-probability of inputs according to a joint probability distribution.

    `model` is a free variable resolved from the enclosing scope (not visible
    here); presumably it is the model this log-joint was built for -- confirm.

    Args:
      *args: Positional arguments. They are the model's original inputs and can
        alternatively be specified as part of `kwargs`.
      **kwargs: Keyword arguments, where for each key-value pair `k` and `v`,
        `v` is passed as a `value` to the random variable(s) whose keyword
        argument `name` during construction is equal to `k`.

    Returns:
      Scalar tf.Tensor, which represents the model's log-probability summed
      over all Edward random variables and their dimensions. If the model
      constructs no random variables, the result is `sum([])`, i.e. the
      Python int 0.

    Raises:
      KeyError: If a random variable constructor in the model is called
        without a `name` keyword argument.
      LookupError: If a random variable in the model has no specified value
        (or a None value) in `**kwargs`.
      TypeError: If calling `model` raises a TypeError, e.g. because the
        supplied arguments do not match its signature. The message is
        prefixed with a hint about the log_joint function definition.
    """
    log_probs = []

    def interceptor(rv_constructor, *rv_args, **rv_kwargs):
        """Overrides a random variable's `value` and accumulates its log-prob."""
        # Set value to keyword argument indexed by `name` (an input tensor).
        rv_name = rv_kwargs.get('name')
        if rv_name is None:
            raise KeyError('Random variable constructor {} has no name '
                           'in its arguments.'.format(
                               rv_constructor.__name__))
        # Unlike a constructor-supplied `value`, only `kwargs` is consulted.
        value = kwargs.get(rv_name)
        if value is None:
            raise LookupError(
                'Keyword argument specifying value for {} is '
                'missing.'.format(rv_name))
        rv_kwargs['value'] = value
        rv = rv_constructor(*rv_args, **rv_kwargs)
        # Reduce over every dimension so each random variable contributes a
        # single scalar term to the joint log-probability.
        log_prob = tf.reduce_sum(
            input_tensor=rv.distribution.log_prob(rv.value))
        log_probs.append(log_prob)
        return rv

    # NOTE(review): the helper is called with `**kwargs` here; confirm this
    # matches `_get_function_inputs`'s actual signature.
    model_kwargs = _get_function_inputs(model, **kwargs)
    with interception(interceptor):
        try:
            model(*args, **model_kwargs)
        except TypeError as err:
            # Re-raise as TypeError (not bare Exception) so callers can still
            # catch it by type. NOTE(review): this also catches TypeErrors raised
            # deep inside `model`, not only argument-count mismatches.
            raise TypeError(
                'Wrong number of arguments in log_joint function definition. {}'
                .format(err))
    log_prob = sum(log_probs)
    return log_prob
def get_trace(model, *args, **kwargs):
    """Runs `model` once and records the value of every random variable it builds.

    Args:
      model: Callable that constructs random variables when invoked.
      *args: Positional arguments forwarded unchanged to `model`.
      **kwargs: Keyword arguments forwarded unchanged to `model`.

    Returns:
      Dict mapping each random variable's `name` constructor keyword argument
      to that variable's `.value`. When several random variables share a name,
      the one constructed last wins.

    Raises:
      KeyError: If a random variable constructor is called without a `name`
        keyword argument (raised after the variable has been constructed).
    """
    values_by_name = {}

    def _record_value(rv_constructor, *rv_args, **rv_kwargs):
        """Builds the random variable, then stores its value under its name."""
        # Wrapping with `interceptable` presumably lets any enclosing
        # interceptors still see this construction -- confirm.
        built_rv = interceptable(rv_constructor)(*rv_args, **rv_kwargs)
        rv_key = rv_kwargs['name']
        values_by_name[rv_key] = built_rv.value
        return built_rv

    with interception(_record_value):
        model(*args, **kwargs)
    return values_by_name
def variational_model(*args):
    """Calls `model` with `mean_field` installed as the active interceptor.

    `model` and `mean_field` are free variables resolved from the enclosing
    scope (not visible here).

    Args:
      *args: Positional arguments forwarded unchanged to `model`.

    Returns:
      Whatever `model(*args)` returns.
    """
    with interception(mean_field):
        outputs = model(*args)
    return outputs