def iid_sample_fn(*args, **kwargs):
  """Draws `n` iid samples from `sample_fn`, vectorized with `pfor`.

  `n`, `sample_fn` and `unflatten` are free variables taken from the enclosing
  scope (not visible in this fragment).

  Any `seed` keyword is removed from `kwargs` and handled here:
    * If `samplers.is_stateful_seed(seed)` holds, one derived seed is drawn from
      `SeedStream(seed, salt='iid_sample')` and passed unchanged to every
      vectorized call of `sample_fn`.
    * Otherwise the seed is split into `n` stateless seeds with
      `samplers.split_seed`, and call `i` receives the `i`-th of them. Outside
      JAX a performance warning is emitted first.

  Returns:
    The `pfor` draws with `unflatten` applied to every leaf
    (`expand_composites=True`).
  """
  with tf.name_scope('iid_sample_fn'):
    seed = kwargs.pop('seed', None)

    if not samplers.is_stateful_seed(seed):
      # Reusing one stateless seed for every draw would just produce `n`
      # identical samples, so give each draw its own seed.
      if not JAX_MODE:
        warnings.warn(
            'Saw Tensor seed {}, implying stateless sampling. Autovectorized '
            'functions that use stateless sampling may be quite slow because '
            'the current implementation falls back to an explicit loop. This '
            'will be fixed in the future. For now, you will likely see '
            'better performance from stateful sampling, which you can invoke '
            'by passing a Python `int` seed.'.format(seed))
      per_draw_seeds = samplers.split_seed(
          seed, n=n, salt='iid_sample_stateless')

      def pfor_loop_body(i):
        with tf.name_scope('iid_sample_fn_stateless_body'):
          return sample_fn(
              *args, seed=tf.gather(per_draw_seeds, i), **kwargs)
    else:
      stateful_kwargs = dict(
          kwargs, seed=SeedStream(seed, salt='iid_sample')())

      def pfor_loop_body(_):
        with tf.name_scope('iid_sample_fn_stateful_body'):
          return sample_fn(*args, **stateful_kwargs)

    draws = parallel_for.pfor(pfor_loop_body, n)
    return tf.nest.map_structure(unflatten, draws, expand_composites=True)
  def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
    """Builds each component distribution and draws or converts its value.

    Walks `self._dist_fn_wrapped` in order, calling each distribution-making
    function on all previously produced values (chain rule of probability).
    Components whose entry in `value` is `None` (or all components, when
    `value` is `None`) are sampled. Supplied values are converted with
    `tf.convert_to_tensor`, using the component dtype as `dtype_hint`.

    Args:
      sample_shape: Shape passed to `sample`. A component whose entry in
        `self._dist_fn_args` is non-empty is sampled with shape `()` instead,
        unless `self._always_use_specified_sample_shape` is set.
      seed: Optional seed. It is split into one seed per component. When it is
        stateful (per `samplers.is_stateful_seed`), a `SeedStream` is also
        built; that stream supplies a fallback seed if sampling with the split
        seed raises a recognized `TypeError`.
      value: Optional sequence with one entry per component. `None` entries
        are sampled.

    Returns:
      ds: list of component distributions.
      xs: list of sampled or converted values, one per component.

    Raises:
      ValueError: if `value` does not have exactly one entry per distribution.
    """
    # This function additionally depends on:
    #   self._dist_fn_wrapped
    #   self._dist_fn_args
    #   self._always_use_specified_sample_shape
    num_dists = len(self._dist_fn_wrapped)
    if seed is not None and samplers.is_stateful_seed(seed):
      seed_stream = SeedStream(seed, salt='JointDistributionSequential')
    else:
      seed_stream = None
    if seed is not None:
      seeds = samplers.split_seed(seed, n=num_dists,
                                  salt='JointDistributionSequential')
    else:
      seeds = [None] * num_dists
    ds = []
    xs = [None] * num_dists if value is None else list(value)
    if len(xs) != num_dists:
      raise ValueError('Number of `xs`s must match number of '
                       'distributions.')
    for i, (dist_fn, args) in enumerate(zip(self._dist_fn_wrapped,
                                            self._dist_fn_args)):
      ds.append(dist_fn(*xs[:i]))  # Chain rule of probability.

      # Ensure reproducibility even when xs are (partially) set: the stream
      # advances once per component whether or not that component is sampled.
      stateful_seed = None if seed_stream is None else seed_stream()

      if xs[i] is None:
        # TODO(b/129364796): We should ignore args prefixed with `_`; this
        # would mean we more often identify when to use `sample_shape=()`
        # rather than `sample_shape=sample_shape`.
        component_sample_shape = (
            () if args and not self._always_use_specified_sample_shape
            else sample_shape)
        try:  # TODO(b/147874898): Eliminate the stateful fallback 20 Dec 2020.
          xs[i] = ds[-1].sample(component_sample_shape, seed=seeds[i])
        except TypeError as e:
          # Only fall back for the seed-type errors recognized below, and only
          # when a stateful seed is available to fall back to.
          if ('Expected int for argument' not in str(e) and
              TENSOR_SEED_MSG_PREFIX not in str(e)) or stateful_seed is None:
            raise

          if not getattr(self, '_resolving_names', False):  # avoid recursion
            self._resolving_names = True
            try:
              resolved_names = self._flat_resolve_names()
            finally:
              # Always clear the guard. If name resolution raises and the flag
              # stays set, every later fallback silently skips this warning.
              self._resolving_names = False
            msg = (
                'Falling back to stateful sampling for distribution #{i} '
                '(0-based) of type `{dist_cls}` with component name '
                '"{component_name}" and `dist.name` "{dist_name}". Please '
                'update to use `tf.random.stateless_*` RNGs. This fallback may '
                'be removed after 20-Dec-2020. ({exc})')
            warnings.warn(msg.format(
                i=i,
                dist_name=ds[-1].name,
                component_name=resolved_names[i],
                dist_cls=type(ds[-1]),
                exc=str(e)))
          xs[i] = ds[-1].sample(component_sample_shape, seed=stateful_seed)

      else:
        # This signature does not allow kwarg names. Applies
        # `convert_to_tensor` on the next value.
        xs[i] = nest.map_structure_up_to(
            ds[-1].dtype,  # shallow_tree
            lambda x, dtype: tf.convert_to_tensor(x, dtype_hint=dtype),  # func
            xs[i],  # x
            ds[-1].dtype)  # dtype
    # Note: we could also resolve distributions up to the first non-`None` in
    # `self._model_flatten(value)`, however we omit this feature for simplicity,
    # speed, and because it has not yet been requested.
    return ds, xs
Beispiel #3
0
    def _flat_sample_distributions(self,
                                   sample_shape=(),
                                   seed=None,
                                   value=None):
        """Executes `model`, creating both samples and distributions.

        Drives `self._model_coroutine()`. Each yielded distribution is
        unwrapped from `self.Root` if needed and recorded. The matching entry
        of `value` is then converted to a tensor, or a sample is drawn. The
        result is sent back into the coroutine, until it stops.

        Args:
          sample_shape: Shape used when sampling components yielded wrapped in
            `Root`; all other components are sampled with shape `()`.
          seed: Stateful or stateless seed. A stateful seed also builds a
            `SeedStream`, used as a fallback when stateless sampling raises a
            recognized `TypeError`.
          value: Optional sequence of per-component values. Entries that are
            `None`, or missing because `value` is too short, are sampled.

        Returns:
          ds: list of distributions, in the order the coroutine yielded them.
          values_out: list of the corresponding sampled/converted values.

        Raises:
          ValueError: if `self._require_root` is set and the first yielded
            distribution is not wrapped in `Root`.
        """
        ds = []
        values_out = []
        if samplers.is_stateful_seed(seed):
            seed_stream = SeedStream(seed, salt='JointDistributionCoroutine')
            if not self._stateful_to_stateless:
                seed = None
        else:
            seed_stream = None  # We got a stateless seed for seed=.

        # TODO(b/166658748): Make _stateful_to_stateless always True (eliminate it).
        if self._stateful_to_stateless and (seed is not None or not JAX_MODE):
            seed = samplers.sanitize_seed(seed,
                                          salt='JointDistributionCoroutine')

        # Loop-invariant helper for converting supplied values.
        def convert_tree_to_tensor(x, dtype_hint):
            return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

        gen = self._model_coroutine()
        index = 0
        d = next(gen)
        if self._require_root and not isinstance(d, self.Root):
            raise ValueError('First distribution yielded by coroutine must '
                             'be wrapped in `Root`.')
        while True:
            is_root = isinstance(d, self.Root)
            actual_distribution = d.distribution if is_root else d
            ds.append(actual_distribution)
            # Only `Root` components receive the caller's `sample_shape`.
            component_sample_shape = sample_shape if is_root else ()

            # Ensure reproducibility even when xs are (partially) set. Always
            # split, even when this component's value is supplied.
            stateful_sample_seed = (
                None if seed_stream is None else seed_stream())
            if seed is None:
                stateless_sample_seed = None
            else:
                stateless_sample_seed, seed = samplers.split_seed(seed)

            if (value is not None and len(value) > index
                    and value[index] is not None):
                # This signature does not allow kwarg names. Applies
                # `convert_to_tensor` on the next value.
                next_value = nest.map_structure_up_to(
                    ds[-1].dtype,  # shallow_tree
                    convert_tree_to_tensor,  # func
                    value[index],  # x
                    ds[-1].dtype)  # dtype_hint
            else:
                try:
                    next_value = actual_distribution.sample(
                        sample_shape=component_sample_shape,
                        seed=(stateful_sample_seed
                              if stateless_sample_seed is None else
                              stateless_sample_seed))
                except TypeError as e:
                    # Only fall back for the seed-type errors recognized
                    # below, and only when a stateful seed exists.
                    if ('Expected int for argument' not in str(e)
                            and TENSOR_SEED_MSG_PREFIX not in str(e)) or (
                                stateful_sample_seed is None):
                        raise
                    msg = (
                        'Falling back to stateful sampling for distribution #{index} '
                        '(0-based) of type `{dist_cls}` with component name '
                        '{component_name} and `dist.name` "{dist_name}". Please '
                        'update to use `tf.random.stateless_*` RNGs. This fallback may '
                        'be removed after 20-Dec-2020. ({exc})')
                    component_name = (joint_distribution_lib.
                                      get_explicit_name_for_component(
                                          ds[-1]))
                    if component_name is None:
                        component_name = '[None specified]'
                    else:
                        component_name = '"{}"'.format(component_name)
                    warnings.warn(
                        msg.format(index=index,
                                   component_name=component_name,
                                   dist_name=ds[-1].name,
                                   dist_cls=type(ds[-1]),
                                   exc=str(e)))
                    next_value = actual_distribution.sample(
                        sample_shape=component_sample_shape,
                        seed=stateful_sample_seed)

            if self._validate_args:
                with tf.control_dependencies(
                        self._assert_compatible_shape(
                            index, sample_shape, next_value)):
                    values_out.append(
                        tf.nest.map_structure(tf.identity, next_value))
            else:
                values_out.append(next_value)

            index += 1
            # Catch StopIteration only around the coroutine step. A
            # StopIteration escaping `sample()` or the conversions above must
            # propagate instead of silently truncating the results.
            try:
                d = gen.send(next_value)
            except StopIteration:
                break
        return ds, values_out