Пример #1
0
    def log_joint_fn(*args, **kwargs):
        """Evaluate the joint log-probability of the model at the given values.

        The model is executed under an interceptor that pins every random
        variable to a value and records the summed log-density of that value.

        Args:
          *args: Positional inputs forwarded to the model. They may instead be
            supplied by name through `kwargs`.
          **kwargs: Values keyed by random-variable name: a random variable
            constructed with `name=k` is given `kwargs[k]` as its `value`.
            `_get_function_inputs(model, kwargs)` also selects which entries are
            passed to `model` itself (presumably those matching its signature).

        Returns:
          The Python `sum` of `tf.reduce_sum(rv.distribution.log_prob(rv.value))`
          over every random variable the model constructs, i.e. a scalar
          tf.Tensor (or the int 0 when no random variable is constructed).

        Raises:
          KeyError: If a random variable is constructed without a `name`.
          LookupError: If a random variable has no value in `kwargs` and none
            was passed to its constructor.
        """
        per_variable = []

        def interceptor(rv_constructor, *rv_args, **rv_kwargs):
            """Pin a random variable's `value` and record its summed log-prob."""
            name = rv_kwargs.get("name")
            if name is None:
                raise KeyError(
                    "Random variable constructor {} has no name in its "
                    "arguments.".format(rv_constructor.__name__))

            # A value supplied through `kwargs` takes precedence; otherwise keep
            # whatever `value` the constructor already received, whether set by
            # the user or forwarded from a lower-level interceptor.
            pinned = kwargs.get(name, rv_kwargs.get("value"))
            if pinned is None:
                raise LookupError(
                    "Keyword argument specifying value for {} is "
                    "missing.".format(name))
            rv_kwargs["value"] = pinned

            variable = rv_constructor(*rv_args, **rv_kwargs)
            per_variable.append(tf.reduce_sum(
                input_tensor=variable.distribution.log_prob(variable.value)))
            return variable

        forwarded = _get_function_inputs(model, kwargs)
        with interception(interceptor):
            model(*args, **forwarded)
        return sum(per_variable)
    def log_joint_fn(*args, **kwargs):
        """Log-probability of inputs according to a joint probability distribution.

        Runs `model` under an interceptor that sets each random variable's
        `value` and accumulates the summed log-probability of that value.

        Args:
          *args: Positional arguments. They are the model's original inputs and
            can alternatively be specified as part of `kwargs`.
          **kwargs: Keyword arguments, where for each key-value pair `k` and `v`,
            `v` is passed as a `value` to the random variable(s) whose keyword
            argument `name` during construction is equal to `k`.

        Returns:
          Scalar tf.Tensor, which represents the model's log-probability summed
          over all random variables and their dimensions (the int 0 if the
          model constructs no random variable).

        Raises:
          KeyError: If a random variable is constructed without a `name`.
          LookupError: If a random variable has no value in `kwargs` and no
            `value` was passed to its constructor.
          TypeError: If calling `model` raises a `TypeError`, e.g. because the
            supplied arguments do not match its signature. Note that a
            `TypeError` raised anywhere inside the model is reported this way.
        """
        log_probs = []

        def interceptor(rv_constructor, *rv_args, **rv_kwargs):
            """Overrides a random variable's `value` and accumulates its log-prob."""
            # Set value to keyword argument indexed by `name` (an input tensor).
            rv_name = rv_kwargs.get('name')
            if rv_name is None:
                raise KeyError('Random variable constructor {} has no name '
                               'in its arguments.'.format(
                                   rv_constructor.__name__))

            # Prefer a value from `kwargs`; otherwise fall back to a `value`
            # already given to the constructor (set by the user or forwarded
            # from a lower-level interceptor) instead of failing.
            value = kwargs.get(rv_name, rv_kwargs.get('value'))
            if value is None:
                raise LookupError(
                    'Keyword argument specifying value for {} is '
                    'missing.'.format(rv_name))
            rv_kwargs['value'] = value

            rv = rv_constructor(*rv_args, **rv_kwargs)
            log_probs.append(tf.reduce_sum(
                input_tensor=rv.distribution.log_prob(rv.value)))
            return rv

        model_kwargs = _get_function_inputs(model, **kwargs)
        with interception(interceptor):
            try:
                model(*args, **model_kwargs)
            except TypeError as err:
                # Re-raise as TypeError rather than a bare Exception so callers
                # can catch it specifically (it is still an Exception subclass);
                # on Python 3 `err` stays attached as the implicit __context__.
                raise TypeError(
                    'Wrong number of arguments in log_joint function definition. {}'
                    .format(err))
        return sum(log_probs)
def get_trace(model, *args, **kwargs):
    """Run `model` once and record the value of every random variable it builds.

    Args:
      model: Callable that constructs random variables when called.
      *args: Positional arguments forwarded to `model`.
      **kwargs: Keyword arguments forwarded to `model`.

    Returns:
      Dict mapping each random variable's `name` keyword argument to its
      `value`. If several random variables share a name, the last one
      constructed wins.

    Raises:
      KeyError: If a random variable is constructed without a `name` keyword
        argument.
    """
    trace_result = {}

    def trace(rv_constructor, *rv_args, **rv_kwargs):
        """Construct the random variable and record its value under its name."""
        # Validate before constructing so a missing name fails fast with a
        # descriptive message instead of a bare KeyError('name').
        if 'name' not in rv_kwargs:
            raise KeyError('Random variable constructor {} has no name '
                           'in its arguments.'.format(rv_constructor.__name__))
        name = rv_kwargs['name']
        # Wrapping with `interceptable` presumably lets any enclosing
        # interceptor still act on this construction — confirm against the
        # `interceptable` implementation.
        rv = interceptable(rv_constructor)(*rv_args, **rv_kwargs)
        trace_result[name] = rv.value
        return rv

    with interception(trace):
        model(*args, **kwargs)

    return trace_result
 def variational_model(*args, **kwargs):
     """Call `model` with random-variable construction intercepted by `mean_field`.

     Args:
       *args: Positional arguments forwarded to `model`.
       **kwargs: Keyword arguments forwarded to `model`.

     Returns:
       Whatever `model` returns.
     """
     with interception(mean_field):
         return model(*args, **kwargs)