def _prob_chain_rule_model_flatten(named_makers):
    """Flattens named distribution makers into parallel tuples for JDSeq.

    Args:
      named_makers: collection of named distributions and/or
        distribution-making callables; it is normalized with
        `_convert_to_dict` before use.

    Returns:
      dist_fn: tuple of the original entries of `named_makers`, in the order
        produced by `_best_order`.
      dist_fn_wrapped: tuple of callables, one per entry, each accepting the
        values of the preceding entries positionally.
      dist_fn_args: tuple holding, per entry, `None` (entry is a distribution
        instance) or a tuple of its required argument names.
      dist_fn_name: tuple of entry names, in the same order.
    """
    def _wrap(maker, required):
        # `None` marks an already-built distribution: hand it back untouched.
        if required is None:
            return lambda *_: maker
        # A maker that requires nothing is called with no arguments.
        if not required:
            return lambda *_: maker()

        def _call_with_named_parents(*xs):
            # `dist_fn_name` is looked up when the wrapper is *called*; by
            # then the enclosing function has bound it to the ordered names.
            selected = {
                name: value
                for name, value in zip(dist_fn_name, xs)
                if name in required
            }
            return maker(**selected)

        return _call_with_named_parents

    named_makers = _convert_to_dict(named_makers)

    # Map each name to its required argument names (`None` for instances).
    required_by_name = {}
    for name, maker in named_makers.items():
        if distribution_util.is_distribution_instance(maker):
            required_by_name[name] = None
        else:
            required_by_name[name] = (
                joint_distribution_sequential._get_required_args(maker))  # pylint: disable=protected-access

    ordered = _best_order(required_by_name)
    dist_fn_name, dist_fn_args = zip(*ordered)
    dist_fn_args = tuple(
        None if names is None else tuple(names) for names in dist_fn_args)
    dist_fn_wrapped = tuple(
        _wrap(named_makers[name], required) for (name, required) in ordered)
    dist_fn = tuple(named_makers.get(name) for name in dist_fn_name)
    return dist_fn, dist_fn_wrapped, dist_fn_args, dist_fn_name
示例#2
0
def _prob_chain_rule_flatten(named_makers):
    """Flattens named distribution makers into parallel tuples for JDSeq.

    Args:
      named_makers: a namedtuple-like object (anything exposing `_asdict`) or
        any value accepted by `dict(...)`, mapping names to distributions or
        distribution-making callables.

    Returns:
      dist_fn: tuple of the original entries of `named_makers`, in the order
        produced by `_best_order`.
      dist_fn_wrapped: tuple of callables, one per entry, each accepting the
        values of the preceding entries positionally.
      dist_fn_args: tuple holding, per entry, `None` (entry is a distribution
        instance) or a tuple of its required argument names.
      dist_fn_name: tuple of entry names, in the same order.
    """
    def _wrap(maker, required):
        # `None` marks an already-built distribution: hand it back untouched.
        if required is None:
            return lambda *_: maker
        # A maker that requires nothing is called with no arguments.
        if not required:
            return lambda *_: maker()

        def _call_with_recent_parents(*xs):
            # Pair the required names with the trailing `len(required)`
            # positional values, most recent first; a name of '_' is a
            # placeholder and is never forwarded to the maker.
            newest_first = reversed(xs[-len(required):])
            bound = {
                name: value
                for name, value in zip(required, newest_first)
                if name != '_'
            }
            return maker(**bound)

        return _call_with_recent_parents

    if hasattr(named_makers, '_asdict'):
        named_makers = named_makers._asdict()
    else:
        named_makers = dict(named_makers)

    # Map each name to its required argument names (`None` for instances).
    required_by_name = {}
    for name, maker in named_makers.items():
        if distribution_util.is_distribution_instance(maker):
            required_by_name[name] = None
        else:
            required_by_name[name] = (
                joint_distribution_sequential._get_required_args(maker))  # pylint: disable=protected-access

    ordered = _best_order(required_by_name)
    dist_fn_name, dist_fn_args = zip(*ordered)
    dist_fn_args = tuple(
        None if names is None else tuple(names) for names in dist_fn_args)
    dist_fn_wrapped = tuple(
        _wrap(named_makers[name], required) for (name, required) in ordered)
    dist_fn = tuple(named_makers.get(name) for name in dist_fn_name)
    return dist_fn, dist_fn_wrapped, dist_fn_args, dist_fn_name
def _prob_chain_rule_model_flatten(named_makers):
  """Flattens named distribution makers into parallel tuples for JDSeq.

  Args:
    named_makers: collection of named distributions and/or
      distribution-making callables; it is normalized with `_convert_to_dict`
      before use.

  Returns:
    dist_fn: tuple of the original entries of `named_makers`, in the order
      produced by `_best_order`.
    dist_fn_wrapped: tuple of callables, one per entry, each accepting the
      values of the preceding entries positionally.
    dist_fn_args: tuple holding, per entry, `None` (entry is a distribution
      instance) or a tuple of its required argument names.
    dist_fn_name: tuple of entry names, in the same order.
  """
  def _wrap(maker, required):
    # `None` marks an already-built distribution: hand it back untouched.
    if required is None:
      return lambda *_: maker
    # A maker that requires nothing is called with no arguments.
    if not required:
      return lambda *_: maker()

    def _call_with_named_parents(*xs):
      # `dist_fn_name` is looked up when the wrapper is *called*; by then the
      # enclosing function has bound it to the ordered names.
      selected = {
          name: value
          for name, value in zip(dist_fn_name, xs)
          if name in required
      }
      return maker(**selected)
    return _call_with_named_parents

  named_makers = _convert_to_dict(named_makers)

  # Same mapping type as the input, filled in the input's iteration order.
  required_by_name = type(named_makers)()
  seen_names = []
  for name, maker in named_makers.items():
    if distribution_util.is_distribution_instance(maker):
      required_by_name[name] = None
    else:
      # To ensure an acyclic dependence graph, a maker that takes `**kwargs`
      # is treated as depending on all distributions that were defined above
      # it, but not any defined below it.
      required_by_name[name] = (
          joint_distribution_sequential._get_required_args(  # pylint: disable=protected-access
              maker, previous_args=seen_names))
    seen_names.append(name)

  ordered = _best_order(required_by_name)
  dist_fn_name, dist_fn_args = zip(*ordered)
  dist_fn_args = tuple(
      None if names is None else tuple(names) for names in dist_fn_args)
  dist_fn_wrapped = tuple(
      _wrap(named_makers[name], required) for (name, required) in ordered)
  dist_fn = tuple(named_makers.get(name) for name in dist_fn_name)
  return dist_fn, dist_fn_wrapped, dist_fn_args, dist_fn_name
示例#4
0
def _unify_call_signature(i, dist_fn):
  """Wraps `dist_fn` so the call site can pass every preceding value.

  Args:
    i: position of `dist_fn` in the sequence; only the first `i` positional
      inputs given to the wrapper are considered.
    dist_fn: a distribution instance, or a callable that builds one.

  Returns:
    wrapped_dist_fn: callable accepting any number of positional inputs.
    args: `()` for a distribution instance, otherwise the result of
      `_get_required_args(dist_fn)`.
    is_instance: `True` if `dist_fn` was a distribution instance, else
      `False`.

  Raises:
    TypeError: if `dist_fn` is neither a distribution instance nor callable.
  """
  if distribution_util.is_distribution_instance(dist_fn):
    return (lambda *_: dist_fn), (), True

  if not callable(dist_fn):
    raise TypeError('{} must be either `tfd.Distribution`-like or '
                    '`callable`.'.format(dist_fn))

  required = _get_required_args(dist_fn)
  num_required = len(required)

  @functools.wraps(dist_fn)
  def wrapped_dist_fn(*xs):
    # Walk backwards from input `i - 1` down to input 0 (most recent first),
    # then keep only as many values as `dist_fn` requires.
    newest_first = xs[(i - 1)::-1] if i >= 1 else []
    return dist_fn(*newest_first[:num_required])

  return wrapped_dist_fn, required, False
def _maybe_build_joint_distribution(structure_of_distributions):
    """Turns a (potentially nested) structure of dists into a single dist.

    Args:
      structure_of_distributions: a distribution instance, or a structure
        accepted by `tf.nest.map_structure` whose leaves are distributions.

    Returns:
      The input itself if it is already a distribution instance; otherwise a
      `JointDistributionNamed` (for mappings and namedtuple-like objects) or
      a `JointDistributionSequential` (for everything else).
    """
    # The ABC aliases in the `collections` namespace (`collections.Mapping`)
    # were removed in Python 3.10; the ABC lives in `collections.abc`.
    # Imported locally so the block does not rely on a file-level import.
    from collections import abc as collections_abc

    # Base case: if we already have a Distribution, return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # Otherwise, recursively convert all interior nested structures into JDs.
    # NOTE(review): `tf.nest.map_structure` applies the function to leaves, so
    # interior containers may not actually be converted here -- confirm.
    outer_structure = tf.nest.map_structure(_maybe_build_joint_distribution,
                                            structure_of_distributions)

    # Dict-like inputs (including namedtuples) keep their names.
    if (hasattr(outer_structure, '_asdict')
            or isinstance(outer_structure, collections_abc.Mapping)):
        return joint_distribution_named.JointDistributionNamed(outer_structure)
    return joint_distribution_sequential.JointDistributionSequential(
        outer_structure)
示例#6
0
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

    Args:
      structure_of_distributions: instance of `tfd.Distribution`, or nested
        structure (tuple, list, dict, etc.) in which all leaves are
        `tfd.Distribution` instances.
      validate_args: Python `bool`. Whether the joint distribution should
        validate input with asserts. This imposes a runtime cost. If
        `validate_args` is `False`, and the inputs are invalid, correct
        behavior is not guaranteed. The flag is also forwarded to every
        nested joint distribution built along the way.
        Default value: `False`.

    Returns:
      distribution: instance of `tfd.Distribution` such that
        `distribution.sample()` is equivalent to
        `tf.nest.map_structure(lambda d: d.sample(),
        structure_of_distributions)`. If `structure_of_distributions` was
        indeed a structure (as opposed to a single `Distribution` instance),
        this will be a `JointDistribution` with the corresponding structure.

    Raises:
      TypeError: if any leaves of the input structure are not
        `tfd.Distribution` instances.
    """
    # The ABC aliases in the `collections` namespace (`collections.Mapping`)
    # were removed in Python 3.10; the ABC lives in `collections.abc`.
    # Imported locally so the block does not rely on a file-level import.
    from collections import abc as collections_abc

    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # If this structure contains other structures (ie, has elements at
    # depth > 1), recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        # Forward `validate_args` so nested joint distributions honour the
        # caller's request instead of silently falling back to the default.
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            lambda substructure: independent_joint_distribution_from_structure(
                substructure, validate_args=validate_args),
            structure_of_distributions)

    # Otherwise, build a JD from the current structure. Dict-like inputs
    # (including namedtuples) keep their names.
    if (hasattr(structure_of_distributions, '_asdict')
            or isinstance(structure_of_distributions,
                          collections_abc.Mapping)):
        return joint_distribution_named.JointDistributionNamed(
            structure_of_distributions, validate_args=validate_args)
    return joint_distribution_sequential.JointDistributionSequential(
        structure_of_distributions, validate_args=validate_args)
示例#7
0
def _unify_call_signature(i, dist_fn):
    """Creates `wrapped_dist_fn` which calls `dist_fn` with all prev nodes.

    Args:
      i: Python `int`; the wrapper insists on receiving exactly `i`
        positional inputs.
      dist_fn: a distribution instance, or a Python `callable` which takes
        the most recent preceding values (in reverse order) and produces a
        new distribution instance.

    Returns:
      wrapped_dist_fn: Python `callable` which takes all previous values (in
        non reverse order) and produces a distribution instance.
      args: `None` if `dist_fn` is a distribution instance, `()` if it
        requires no arguments, otherwise the result of
        `_get_required_args(dist_fn)`.

    Raises:
      TypeError: if `dist_fn` is neither a distribution instance nor
        callable.
    """
    if distribution_util.is_distribution_instance(dist_fn):
        return (lambda *_: dist_fn), None

    if not callable(dist_fn):
        raise TypeError('{} must be either `tfd.Distribution`-like or '
                        '`callable`.'.format(dist_fn))

    args = _get_required_args(dist_fn)
    if not args:
        return (lambda *_: dist_fn()), ()

    @functools.wraps(dist_fn)
    def wrapped_dist_fn(*xs):
        """Calls `dist_fn` with reversed and truncated args."""
        num_seen = len(xs)
        num_required = len(args)
        if num_seen != i:
            raise ValueError(
                'Internal Error: Unexpected number of inputs provided to {}-th '
                'distribution maker (dist_fn: {}, expected: {}, saw: {}).'.
                format(i, dist_fn, i, num_seen))
        if num_seen < num_required:
            raise ValueError(
                'Internal Error: Too few inputs provided to {}-th distribution maker '
                '(dist_fn: {}, expected: {}, saw: {}).'.format(
                    i, dist_fn, num_required, num_seen))
        # Keep only the trailing `num_required` inputs, most recent first.
        trailing = xs[-num_required:]
        return dist_fn(*reversed(trailing))

    return wrapped_dist_fn, args
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

    Args:
      structure_of_distributions: instance of `tfd.Distribution`, or nested
        structure (tuple, list, dict, etc.) in which all leaves are
        `tfd.Distribution` instances.
      batch_ndims: Optional integer `Tensor` number of leftmost batch
        dimensions shared across all members of the input structure. If this
        is specified, the returned joint distribution will be an autobatched
        distribution with the given batch rank, and all other dimensions
        absorbed into the event.
      validate_args: Python `bool`. Whether the joint distribution should
        validate input with asserts. This imposes a runtime cost. If
        `validate_args` is `False`, and the inputs are invalid, correct
        behavior is not guaranteed.
        Default value: `False`.

    Returns:
      distribution: instance of `tfd.Distribution` such that
        `distribution.sample()` is equivalent to
        `tf.nest.map_structure(lambda d: d.sample(),
        structure_of_distributions)`. If `structure_of_distributions` was
        indeed a structure (as opposed to a single `Distribution` instance),
        this will be a `JointDistribution` with the corresponding structure.

    Raises:
      TypeError: if any leaves of the input structure are not
        `tfd.Distribution` instances.
    """
    # Leaf case: the input is already a single Distribution.
    if dist_util.is_distribution_instance(structure_of_distributions):
        leaf = structure_of_distributions
        if batch_ndims is not None:
            batch_rank = ps.rank_from_shape(leaf.batch_shape_tensor())
            surplus_ndims = batch_rank - batch_ndims
            # The static value may be None (unknown until runtime); in that
            # case the wrapper is applied as well.
            if tf.get_static_value(surplus_ndims) != 0:
                leaf = independent.Independent(
                    leaf, reinterpreted_batch_ndims=surplus_ndims)
        return leaf

    def _convert_substructure(substructure):
        # Recurse with the caller's settings.
        return independent_joint_distribution_from_structure(
            substructure,
            batch_ndims=batch_ndims,
            validate_args=validate_args)

    # If any element sits deeper than one level, collapse the next level of
    # nested structures into joint distributions first.
    depth_of_each_leaf = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(depth_of_each_leaf)) > 1:
        one_level_down = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=depth_of_each_leaf)
        structure_of_distributions = nest.map_structure_up_to(
            one_level_down,
            _convert_substructure,
            structure_of_distributions)

    # Choose constructors: plain JDs by default, autobatched JDs when a
    # specific batch rank was requested.
    if batch_ndims is None:
        make_named = joint_distribution_named.JointDistributionNamed
        make_sequential = (
            joint_distribution_sequential.JointDistributionSequential)
    else:
        make_named = functools.partial(
            joint_distribution_auto_batched.JointDistributionNamedAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)
        make_sequential = functools.partial(
            joint_distribution_auto_batched.
            JointDistributionSequentialAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)

    # Dict-like inputs (including namedtuples) keep their names; everything
    # else becomes a sequential JD.
    is_named = (
        hasattr(structure_of_distributions, '_asdict')
        or isinstance(structure_of_distributions, collections.abc.Mapping))
    build = make_named if is_named else make_sequential
    return build(structure_of_distributions, validate_args=validate_args)