Example #1
def _maybe_check_valid_map_values(map_values, validate_args):
    """Validate `map_values` if `validate_args`==True."""
    assertions = []

    message = 'Rank of map_values must be 1.'
    if tensorshape_util.rank(map_values.shape) is not None:
        if tensorshape_util.rank(map_values.shape) != 1:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(map_values, 1, message=message))

    message = 'Size of map_values must be greater than 0.'
    if tensorshape_util.num_elements(map_values.shape) is not None:
        if tensorshape_util.num_elements(map_values.shape) == 0:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_greater(tf.size(map_values), 0,
                                       message=message))

    if validate_args:
        assertions.append(
            assert_util.assert_equal(
                tf.math.is_strictly_increasing(map_values),
                True,
                message='map_values is not strictly increasing.'))

    return assertions
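
A minimal usage sketch (hypothetical caller; assumes the same TFP-internal imports as the snippet above). In graph mode, the returned assertions are attached as control dependencies so they execute before the validated value is consumed:

import tensorflow as tf

map_values = tf.constant([0.1, 0.5, 0.9])
assertions = _maybe_check_valid_map_values(map_values, validate_args=True)
with tf.control_dependencies(assertions):
    # The asserts run before `map_values` is used downstream.
    map_values = tf.identity(map_values)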
Example #2
def validate_init_args_statically(distribution, batch_shape):
    """Helper to __init__ which makes or raises assertions."""
    if tensorshape_util.rank(batch_shape.shape) is not None:
        if tensorshape_util.rank(batch_shape.shape) != 1:
            raise ValueError('`batch_shape` must be a vector '
                             '(saw rank: {}).'.format(
                                 tensorshape_util.rank(batch_shape.shape)))

    batch_shape_static = tensorshape_util.constant_value_as_shape(batch_shape)
    batch_size_static = tensorshape_util.num_elements(batch_shape_static)
    dist_batch_size_static = tensorshape_util.num_elements(
        distribution.batch_shape)

    if batch_size_static is not None and dist_batch_size_static is not None:
        if batch_size_static != dist_batch_size_static:
            raise ValueError('`batch_shape` size ({}) must match '
                             '`distribution.batch_shape` size ({}).'.format(
                                 batch_size_static, dist_batch_size_static))

    # Note: entries of -1 in `batch_shape` surface as `None` dims in
    # `batch_shape_static`, so any statically known dim must be positive.
    if tensorshape_util.dims(batch_shape_static) is not None:
        if any(
                tf.compat.dimension_value(dim) is not None
                and tf.compat.dimension_value(dim) < 1
                for dim in batch_shape_static):
            raise ValueError('`batch_shape` elements must be >=-1.')
Example #3
 def _event_shape_tensor(self):
     with tf.control_dependencies(self._assertions):
         event_sizes = [
             tf.reduce_prod(d.event_shape_tensor())  # pylint: disable=g-complex-comprehension
             if tensorshape_util.num_elements(d.event_shape) is None else
             tensorshape_util.num_elements(d.event_shape)
             for d in self._distributions
         ]
         return tf.reduce_sum(event_sizes)[tf.newaxis]
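
The static/dynamic branching above hinges on `tensorshape_util.num_elements` returning `None` for shapes that are not fully known, mirroring `tf.TensorShape.num_elements`. A quick illustration:

import tensorflow as tf

print(tf.TensorShape([2, 3]).num_elements())     # ==> 6
print(tf.TensorShape([2, None]).num_elements())  # ==> None
print(tf.TensorShape(None).num_elements())       # ==> None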
Example #4
  def __init__(self,
               distribution_fn,
               sample0=None,
               num_steps=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Autoregressive"):
    """Construct an `Autoregressive` distribution.

    Args:
      distribution_fn: Python `callable` which constructs a
        `tfd.Distribution`-like instance from a `Tensor` (e.g.,
        `sample0`). The function must respect the "autoregressive property",
        i.e., there exists a permutation of the event coordinates such that
        each coordinate is a diffeomorphic function of only preceding
        coordinates.
      sample0: Initial input to `distribution_fn`; used to
        build the distribution in `__init__` which in turn specifies this
        distribution's properties, e.g., `event_shape`, `batch_shape`, `dtype`.
        If unspecified, then `distribution_fn` should be default constructible.
      num_steps: Number of times `distribution_fn` is composed from samples,
        e.g., `num_steps=2` implies
        `distribution_fn(distribution_fn(sample0).sample(n)).sample()`.
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "Autoregressive".

    Raises:
      ValueError: if `num_steps` and
        `num_elements(distribution_fn(sample0).event_shape)` are both `None`.
      ValueError: if `num_steps < 1`.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._distribution_fn = distribution_fn
      self._sample0 = sample0
      self._distribution0 = (distribution_fn() if sample0 is None
                             else distribution_fn(sample0))
      if num_steps is None:
        num_steps = tensorshape_util.num_elements(
            self._distribution0.event_shape)
        if num_steps is None:
          raise ValueError("distribution_fn must generate a distribution "
                           "with fully known `event_shape`.")
      if num_steps < 1:
        raise ValueError("num_steps ({}) must be at least 1.".format(num_steps))
      self._num_steps = num_steps
    super(Autoregressive, self).__init__(
        dtype=self._distribution0.dtype,
        reparameterization_type=self._distribution0.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=self._distribution0._graph_parents,  # pylint: disable=protected-access
        name=name)
Example #5
    def _sample_n(self, n, seed=None):
        distribution0 = self._get_distribution0()

        if self._num_steps is not None:
            num_steps = tf.convert_to_tensor(self._num_steps)
            num_steps_static = tf.get_static_value(num_steps)
        else:
            num_steps_static = tensorshape_util.num_elements(
                distribution0.event_shape)
            if num_steps_static is None:
                num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

        seed = SeedStream(seed, salt='Autoregressive')()
        samples = distribution0.sample(n, seed=seed)
        if num_steps_static is not None:
            for _ in range(num_steps_static):
                # pylint: disable=not-callable
                samples = self.distribution_fn(samples).sample(seed=seed)
        else:
            samples = tf.foldl(
                # pylint: disable=not-callable
                lambda s, _: self.distribution_fn(s).sample(seed=seed),
                elems=tf.range(0, num_steps),
                initializer=samples)
        return samples
Example #6
 def _event_shape(self):
   event_size = [
       tensorshape_util.num_elements(d.event_shape) for d in self.distributions
   ]
   if any(r is None for r in event_size):
     return tf.TensorShape([None])
   return tf.TensorShape([sum(event_size)])
Example #7
 def _num_steps_deprecated_behavior(self):
     distribution0 = self._get_distribution0()
     num_steps_static = tensorshape_util.num_elements(
         distribution0.event_shape)
     if num_steps_static is not None:
         return num_steps_static
     return tf.reduce_prod(distribution0.event_shape_tensor())
Example #8
def variables_summary(variables, name=None):
    """Returns a list of summarizing `str`s."""
    trainable_size = collections.defaultdict(lambda: 0)
    lines = []
    if name is not None:
        lines.append(' '.join(['=' * 3, name, '=' * 50]))
    fmt = '{: >6} {:20} {:5} {:40}'
    lines.append(fmt.format(
        'SIZE',
        'SHAPE',
        'TRAIN',
        'NAME',
    ))
    for v in tf.nest.flatten(variables):
        num_elements = tensorshape_util.num_elements(v.shape)
        if v.trainable:
            trainable_size[v.dtype.base_dtype] += num_elements
        lines.append(
            fmt.format(
                num_elements,
                str(tensorshape_util.as_list(v.shape)),
                str(v.trainable),
                v.name,
            ))
    bytes_ = sum([k.size * v for k, v in trainable_size.items()])
    cnt = sum([v for v in trainable_size.values()])
    lines.append('trainable size: {}  /  {:.3f} MiB  /  {}'.format(
        cnt,
        bytes_ / 2**20,
        '{' + ', '.join(
            ['{}: {}'.format(k.name, v)
             for k, v in trainable_size.items()]) + '}',
    ))
    return '\n'.join(lines)
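
A usage sketch with hypothetical variables (note the function returns a single newline-joined string rather than a list):

import tensorflow as tf

v1 = tf.Variable(tf.zeros([3, 4]), name='weights')
v2 = tf.Variable(0, trainable=False, name='step')
print(variables_summary([v1, v2], name='model'))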
Example #9
def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Helper to validate block sizes."""
  block_sizes = ps.convert_to_shape_tensor(
      block_sizes, name='block_sizes', dtype_hint=tf.int32)
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes

  elif validate_args:
    message = ('`block_sizes` must be `None`, or a vector of the same length '
               'as `bijectors`.')
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.size(block_sizes), len(bijectors), message=message),
        assert_util.assert_equal(tf.rank(block_sizes), 1)
    ]):
      block_sizes = tf.identity(block_sizes)

  # Set the shape if missing to pass statically known structure to split.
  tensorshape_util.set_shape(block_sizes, [len(bijectors)])
  return block_sizes
Example #10
def _validate_block_sizes(block_sizes, bijectors, validate_args):
    """Helper to validate block sizes."""
    block_sizes_shape = block_sizes.shape
    if tensorshape_util.is_fully_defined(block_sizes_shape):
        if (tensorshape_util.rank(block_sizes_shape) != 1
                or (tensorshape_util.num_elements(block_sizes_shape) !=
                    len(bijectors))):
            raise ValueError(
                '`block_sizes` must be `None`, or a vector of the same length as '
                '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
                'length {}'.format(block_sizes_shape, len(bijectors)))
        return block_sizes
    elif validate_args:
        message = (
            '`block_sizes` must be `None`, or a vector of the same length '
            'as `bijectors`.')
        with tf.control_dependencies([
                assert_util.assert_equal(tf.size(input=block_sizes),
                                         len(bijectors),
                                         message=message),
                assert_util.assert_equal(tf.rank(block_sizes), 1)
        ]):
            return tf.identity(block_sizes)
    else:
        return block_sizes
Example #11
def _size(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  if not hasattr(input, 'shape'):
    x = np.array(input)
    input = tf.convert_to_tensor(input) if x.dtype == np.object_ else x
  n = tensorshape_util.num_elements(tf.TensorShape(input.shape))
  if n is None:
    return tf.size(input, out_type=out_type, name=name)
  return np.array(n).astype(_numpy_dtype(out_type))
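
An illustration of the static fast path (assumed behavior, given the `_numpy_dtype` helper referenced above): a fully known shape yields a NumPy scalar with no graph op, and only unknown shapes fall back to `tf.size`.

import tensorflow as tf

x = tf.ones([2, 3])
_size(x)          # ==> np.int32(6), computed statically
_size([1, 2, 3])  # ==> np.int32(3)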
Example #12
    def _forward(self, x):
        static_event_size = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(
                x.shape, self._event_ndims)[-self._event_ndims:])

        if self._unroll_loop:
            if not static_event_size:
                raise ValueError(
                    "The final {} dimensions of `x` must be known at graph "
                    "construction time if `unroll_loop=True`. `x.shape: {!r}`".
                    format(self._event_ndims, x.shape))
            y = tf.zeros_like(x, name="y0")

            for _ in range(static_event_size):
                shift, log_scale = self._shift_and_log_scale_fn(y)
                # next_y = scale * x + shift
                next_y = x
                if log_scale is not None:
                    next_y *= tf.exp(log_scale)
                if shift is not None:
                    next_y += shift
                y = next_y
            return y

        event_size = tf.reduce_prod(
            input_tensor=tf.shape(input=x)[-self._event_ndims:])
        y0 = tf.zeros_like(x, name="y0")
        # call the template once to ensure creation
        _ = self._shift_and_log_scale_fn(y0)

        def _loop_body(index, y0):
            """While-loop body for autoregression calculation."""
            # Set caching device to avoid re-getting the tf.Variable for every while
            # loop iteration.
            with tf.compat.v1.variable_scope(
                    tf.compat.v1.get_variable_scope()) as vs:
                if vs.caching_device is None and not tf.executing_eagerly():
                    vs.set_caching_device(lambda op: op.device)
                shift, log_scale = self._shift_and_log_scale_fn(y0)
            y = x
            if log_scale is not None:
                y *= tf.exp(log_scale)
            if shift is not None:
                y += shift
            return index + 1, y

        # If the event size is available at graph construction time, we can inform
        # the graph compiler of the maximum number of steps. If not,
        # static_event_size will be None, and the maximum_iterations argument will
        # have no effect.
        _, y = tf.while_loop(cond=lambda index, _: index < event_size,
                             body=_loop_body,
                             loop_vars=(0, y0),
                             maximum_iterations=static_event_size)
        return y
Example #13
def reshapes_of(draw, shape, max_ndims=4):
  """Strategy for valid reshapes of the given shape, rank at most max_ndims."""
  factors = draw(hps.permutations(
      prime_factors(tensorshape_util.num_elements(shape))))
  split_points = sorted(draw(
      hps.lists(hps.integers(min_value=0, max_value=len(factors)),
                min_size=0, max_size=max_ndims - 1)))
  result = ()
  for start, stop in zip([0] + split_points, split_points + [len(factors)]):
    result += (int(np.prod(factors[start:stop])),)
  return result
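
`prime_factors` is assumed but not shown; a minimal trial-division sketch consistent with how it is called above:

def prime_factors(n):
  """Returns the prime factorization of `n` as a list (empty for n == 1)."""
  factors = []
  d = 2
  while d * d <= n:
    while n % d == 0:
      factors.append(d)
      n //= d
    d += 1
  if n > 1:
    factors.append(n)
  return factors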
Example #14
def rank_from_shape(shape_tensor_fn, tensorshape=None):
  """Computes `rank` given a `Tensor`'s `shape`."""

  if tensorshape is None:
    shape_tensor = (shape_tensor_fn() if callable(shape_tensor_fn)
                    else shape_tensor_fn)
    if (hasattr(shape_tensor, 'shape') and
        hasattr(shape_tensor.shape, 'num_elements')):
      ndims_ = tensorshape_util.num_elements(shape_tensor.shape)
    else:
      ndims_ = len(shape_tensor)
    ndims_fn = lambda: tf.size(shape_tensor)
  else:
    ndims_ = tensorshape_util.rank(tensorshape)
    ndims_fn = lambda: tf.size(  # pylint: disable=g-long-lambda
        shape_tensor_fn() if callable(shape_tensor_fn) else shape_tensor_fn)
  return ndims_fn() if ndims_ is None else ndims_
Example #15
  def _sample_n(self, n, seed=None):
    # Get ids as a [n, batch_size]-shaped matrix if batch_shape is non-empty;
    # otherwise get ids as a [n]-shaped vector.
    distributions = self.poisson_and_mixture_distributions()
    dist, mixture_dist = distributions
    batch_size = tensorshape_util.num_elements(self.batch_shape)
    if batch_size is None:
      batch_size = tf.reduce_prod(
          self._batch_shape_tensor(distributions=distributions))
    # We need to 'sample extra' from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    mixture_seed, poisson_seed = samplers.split_seed(
        seed, salt='PoissonLogNormalQuadratureCompound')
    ids = mixture_dist.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                mixture_dist.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=mixture_seed)
    # We need to flatten batch dims in case mixture_dist has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids = ids + offset
    rate = tf.gather(tf.reshape(dist.rate_parameter(), shape=[-1]), ids)
    rate = tf.reshape(
        rate, shape=concat_vectors([n], self._batch_shape_tensor(
            distributions=distributions)))
    return samplers.poisson(
        shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
Example #16
    def _sample_n(self, n, seed=None):
        distribution0 = self._get_distribution0()

        if self._num_steps is not None:
            num_steps = tf.convert_to_tensor(self._num_steps)
            num_steps_static = tf.get_static_value(num_steps)
        else:
            num_steps_static = tensorshape_util.num_elements(
                distribution0.event_shape)
            if num_steps_static is None:
                num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

        stateless_seed = samplers.sanitize_seed(seed, salt='Autoregressive')
        stateful_seed = None
        try:
            samples = distribution0.sample(n, seed=stateless_seed)
            is_stateful_sampler = False
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            msg = (
                'Falling back to stateful sampling for `distribution_fn(sample0)` of '
                'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                'This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(type(distribution0), str(e)))
            stateful_seed = SeedStream(seed, salt='Autoregressive')()
            samples = distribution0.sample(n, seed=stateful_seed)
            is_stateful_sampler = True

        seed = stateful_seed if is_stateful_sampler else stateless_seed

        if num_steps_static is not None:
            for _ in range(num_steps_static):
                # pylint: disable=not-callable
                samples = self.distribution_fn(samples).sample(seed=seed)
        else:
            # pylint: disable=not-callable
            samples = tf.foldl(
                lambda s, _: self.distribution_fn(s).sample(seed=seed),
                elems=tf.range(0, num_steps),
                initializer=samples)
        return samples
Example #17
    def _forward(self, x, **kwargs):
        static_event_size = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(
                x.shape, self._event_ndims)[-self._event_ndims:])

        if self._unroll_loop:
            if not static_event_size:
                raise ValueError(
                    'The final {} dimensions of `x` must be known at graph '
                    'construction time if `unroll_loop=True`. `x.shape: {!r}`'.
                    format(self._event_ndims, x.shape))
            y = tf.zeros_like(x, name='y0')

            for _ in range(static_event_size):
                y = self._bijector_fn(y, **kwargs).forward(x)
            return y

        event_size = tf.reduce_prod(tf.shape(x)[-self._event_ndims:])
        y0 = tf.zeros_like(x, name='y0')
        # call the template once to ensure creation
        if not tf.executing_eagerly():
            _ = self._bijector_fn(y0, **kwargs).forward(y0)

        def _loop_body(index, y0):
            """While-loop body for autoregression calculation."""
            # Set caching device to avoid re-getting the tf.Variable for every while
            # loop iteration.
            with tf1.variable_scope(tf1.get_variable_scope()) as vs:
                if vs.caching_device is None and not tf.executing_eagerly():
                    vs.set_caching_device(lambda op: op.device)
                bijector = self._bijector_fn(y0, **kwargs)
            y = bijector.forward(x)
            return index + 1, y

        # If the event size is available at graph construction time, we can inform
        # the graph compiler of the maximum number of steps. If not,
        # static_event_size will be None, and the maximum_iterations argument will
        # have no effect.
        _, y = tf.while_loop(cond=lambda index, _: index < event_size,
                             body=_loop_body,
                             loop_vars=(0, y0),
                             maximum_iterations=static_event_size)
        return y
Example #18
def rank_from_shape(shape_tensor_fn, tensorshape=None):
  """Computes `rank` given a `Tensor`'s `shape`."""
  # Note: this function will implicitly interpret scalar "shapes" as length-1
  # vectors.
  if tensorshape is None:
    shape_tensor = (shape_tensor_fn() if callable(shape_tensor_fn)
                    else shape_tensor_fn)
    shape_tensor_ = tf.get_static_value(shape_tensor)
    if shape_tensor_ is not None:
      shape_tensor = np.int32(shape_tensor_)
    elif not hasattr(shape_tensor, 'shape'):
      shape_tensor = tf.convert_to_tensor(shape_tensor)
    ndims_ = tensorshape_util.num_elements(shape_tensor.shape)
    ndims_fn = lambda: tf.size(shape_tensor)
  else:
    ndims_ = tensorshape_util.rank(tensorshape)
    ndims_fn = lambda: tf.size(  # pylint: disable=g-long-lambda
        shape_tensor_fn() if callable(shape_tensor_fn) else shape_tensor_fn)
  return ndims_fn() if ndims_ is None else np.int32(ndims_)
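
Hypothetical usage, following the docstring: the rank of a `Tensor` is the number of elements in its shape vector.

import tensorflow as tf

rank_from_shape(tf.constant([2, 3, 4]))        # ==> np.int32(3), static
x = tf.ones([5, 2])
rank_from_shape(lambda: tf.shape(x), x.shape)  # ==> np.int32(2), static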
Example #19
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code
            # path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seed)
            stream = SeedStream(seed, salt='Mixture')

            for c in range(self.num_components):
                samples.append(self.components[c].sample(n, seed=stream()))
            stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
            x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
            npdt = dtype_util.as_numpy_dtype(x.dtype)
            mask = tf.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=npdt(1),
                off_value=npdt(0))  # [n, B, k]
            mask = distribution_util.pad_mixture_dimensions(
                mask, self, self._cat,
                tensorshape_util.rank(
                    self._static_event_shape))  # [n, B, k, [1]*e]
            return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        n = tf.convert_to_tensor(n, name='n')
        static_n = tf.get_static_value(n)
        n = int(static_n) if static_n is not None else n
        cat_samples = self.cat.sample(n, seed=seed)

        static_samples_shape = cat_samples.shape
        if tensorshape_util.is_fully_defined(static_samples_shape):
            samples_shape = tensorshape_util.as_list(static_samples_shape)
            samples_size = tensorshape_util.num_elements(static_samples_shape)
        else:
            samples_shape = tf.shape(cat_samples)
            samples_size = tf.size(cat_samples)
        static_batch_shape = self.batch_shape
        if tensorshape_util.is_fully_defined(static_batch_shape):
            batch_shape = tensorshape_util.as_list(static_batch_shape)
            batch_size = tensorshape_util.num_elements(static_batch_shape)
        else:
            batch_shape = tf.shape(cat_samples)[1:]
            batch_size = tf.reduce_prod(batch_shape)
        static_event_shape = self.event_shape
        if tensorshape_util.is_fully_defined(static_event_shape):
            event_shape = np.array(
                tensorshape_util.as_list(static_event_shape), dtype=np.int32)
        else:
            event_shape = None

        # Get indices into the raw cat sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                         samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = tf.dynamic_partition(
            data=samples_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]),
                                       samples_shape)

        # Explanation of the dynamic partitioning below:
        #   batch indices are, e.g., [0, 1, 0, 1, 0, 1].
        # Suppose partitions are:
        #     [1 1 0 0 1 1]
        # After partitioning, batch indices are cut as:
        #     [batch_indices[x] for x in 2, 3]
        #     [batch_indices[x] for x in 0, 1, 4, 5]
        # i.e.
        #     [1 1] and [0 0 0 0]
        # Now we sample n=2 from part 0 and n=4 from part 1.
        # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
        # and for part 1 we want samples from batch entries 0, 0, 0, 0
        #   (samples 0, 1, 2, 3).
        partitioned_batch_indices = tf.dynamic_partition(
            data=batch_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)
        samples_class = [None for _ in range(self.num_components)]

        stream = SeedStream(seed, salt='Mixture')

        for c in range(self.num_components):
            n_class = tf.size(partitioned_samples_indices[c])
            samples_class_c = self.components[c].sample(n_class, seed=stream())

            if event_shape is None:
                batch_ndims = prefer_static.rank_from_shape(batch_shape)
                event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            #
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1.
            lookup_partitioned_batch_indices = (
                batch_size * tf.range(n_class) + partitioned_batch_indices[c])
            samples_class_c = tf.reshape(
                samples_class_c,
                tf.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = tf.gather(samples_class_c,
                                        lookup_partitioned_batch_indices,
                                        name='samples_class_c_gather')
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the components.
        lhs_flat_ret = tf.dynamic_stitch(indices=partitioned_samples_indices,
                                         data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = tf.reshape(lhs_flat_ret,
                         tf.concat([samples_shape, event_shape], 0))
        tensorshape_util.set_shape(
            ret,
            tensorshape_util.concatenate(static_samples_shape,
                                         self.event_shape))
        return ret
Example #20
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # For `logits` and `probs`, we only want to have an assertion on what the
        # user actually passed. For now, we access the underlying categorical's
        # _logits and _probs directly. After the 2019-10-01 deprecation, it would
        # also work to use .logits() and .probs().
        logits = self._categorical._logits
        probs = self._categorical._probs
        outcomes = self._outcomes
        validate_args = self._validate_args

        # Build all shape and dtype checks during the `is_init` call.
        if is_init:

            def validate_equal_last_dim(tensor_a, tensor_b, message):
                event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
                event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
                if event_size_a is not None and event_size_b is not None:
                    if event_size_a != event_size_b:
                        raise ValueError(message)
                elif validate_args:
                    return assert_util.assert_equal(tf.shape(tensor_a)[-1],
                                                    tf.shape(tensor_b)[-1],
                                                    message=message)

            message = 'Size of outcomes must be greater than 0.'
            if tensorshape_util.num_elements(outcomes.shape) is not None:
                if tensorshape_util.num_elements(outcomes.shape) == 0:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    assert_util.assert_greater(
                        tf.size(outcomes), 0, message=message))

            if logits is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    # pylint: disable=protected-access
                    self._categorical._logits,
                    # pylint: enable=protected-access
                    message=
                    'Last dimension of outcomes and logits must be equal size.'
                )
                if maybe_assert:
                    assertions.append(maybe_assert)

            if probs is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    probs,
                    message=
                    'Last dimension of outcomes and probs must be equal size.')
                if maybe_assert:
                    assertions.append(maybe_assert)

            message = 'Rank of outcomes must be 1.'
            ndims = tensorshape_util.rank(outcomes.shape)
            if ndims is not None:
                if ndims != 1:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    assert_util.assert_rank(outcomes, 1, message=message))

        if not validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(outcomes):
            assertions.append(
                assert_util.assert_equal(
                    tf.math.is_strictly_increasing(outcomes),
                    True,
                    message='outcomes is not strictly increasing.'))

        return assertions
Example #21
def _replace_event_shape_in_tensorshape(
    input_tensorshape, event_shape_in, event_shape_out):
  """Replaces the event shape dims of a `TensorShape`.

  Args:
    input_tensorshape: a `TensorShape` instance in which to attempt replacing
      event shape.
    event_shape_in: `Tensor` shape representing the event shape expected to
      be present in (rightmost dims of) `input_tensorshape`. Must be
      compatible with the rightmost dims of `input_tensorshape`.
    event_shape_out: `Tensor` shape representing the new event shape, i.e.,
      the replacement of `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` with the rightmost `event_shape_in`
      replaced by `event_shape_out`. Might be partially defined, i.e.,
      `TensorShape(None)`.
    is_validated: Python `bool` indicating static validation happened.

  Raises:
    ValueError: if we can determine the event shape portion of
      `input_tensorshape` as well as `event_shape_in` both statically, and
      they are not compatible. "Compatible" here means that they are
      identical on any dims that are not -1 in `event_shape_in`.
  """
  event_shape_in_ndims = tensorshape_util.num_elements(event_shape_in.shape)
  if tensorshape_util.rank(
      input_tensorshape) is None or event_shape_in_ndims is None:
    return tf.TensorShape(None), False  # Not is_validated.

  input_non_event_ndims = tensorshape_util.rank(
      input_tensorshape) - event_shape_in_ndims
  if input_non_event_ndims < 0:
    raise ValueError(
        'Input has fewer ndims ({}) than event shape ndims ({}).'.format(
            tensorshape_util.rank(input_tensorshape), event_shape_in_ndims))

  input_non_event_tensorshape = input_tensorshape[:input_non_event_ndims]
  input_event_tensorshape = input_tensorshape[input_non_event_ndims:]

  # Check that `input_event_shape_` and `event_shape_in` are compatible in the
  # sense that they have equal entries in any position that isn't a `-1` in
  # `event_shape_in`. Note that our validations at construction time ensure
  # there is at most one such entry in `event_shape_in`.
  event_shape_in_ = tf.get_static_value(event_shape_in)
  is_validated = (
      tensorshape_util.is_fully_defined(input_event_tensorshape) and
      event_shape_in_ is not None)
  if is_validated:
    input_event_shape_ = np.int32(input_event_tensorshape)
    mask = event_shape_in_ >= 0
    explicit_input_event_shape_ = input_event_shape_[mask]
    explicit_event_shape_in_ = event_shape_in_[mask]
    if not all(explicit_input_event_shape_ == explicit_event_shape_in_):
      raise ValueError(
          'Input `event_shape` does not match `event_shape_in`. '
          '({} vs {}).'.format(input_event_shape_, event_shape_in_))

  event_tensorshape_out = tensorshape_util.constant_value_as_shape(
      event_shape_out)
  if tensorshape_util.rank(event_tensorshape_out) is None:
    output_tensorshape = tf.TensorShape(None)
  else:
    output_tensorshape = tensorshape_util.concatenate(
        input_non_event_tensorshape, event_tensorshape_out)

  return output_tensorshape, is_validated
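
A worked illustration with hypothetical values: replacing a rightmost `[6]` event block by `[2, 3]` keeps the leading batch dims and validates statically.

import tensorflow as tf

output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
    tf.TensorShape([4, 6]),
    event_shape_in=tf.constant([6]),
    event_shape_out=tf.constant([2, 3]))
# output_tensorshape ==> TensorShape([4, 2, 3]); is_validated ==> True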
Example #22
def _replace_event_shape_in_shape_tensor(
    input_shape, event_shape_in, event_shape_out, validate_args):
  """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers.
    event_shape_in: the event shape expected to be present in the rightmost
      dims of `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
    output_tensorshape: A `TensorShape` giving the static shape information
      corresponding to `output_shape`, when available.
  """
  output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
      tensorshape_util.constant_value_as_shape(input_shape),
      event_shape_in,
      event_shape_out)

  # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
  # correctly supports control_dependencies.
  validation_dependencies = (
      map(tf.identity, (event_shape_in, event_shape_out))
      if validate_args else ())

  if (tensorshape_util.is_fully_defined(output_tensorshape) and
      (is_validated or not validate_args)):
    with tf.control_dependencies(validation_dependencies):
      output_shape = tf.convert_to_tensor(
          output_tensorshape, name='output_shape', dtype_hint=tf.int32)
    return output_shape, output_tensorshape

  with tf.control_dependencies(validation_dependencies):
    event_shape_in_ndims = (
        tf.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None else
        tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

  additional_assertions = []
  if is_validated:
    pass
  elif validate_args:
    # Check that `input_event_shape` and `event_shape_in` are compatible in the
    # sense that they have equal entries in any position that isn't a `-1` in
    # `event_shape_in`. Note that our validations at construction time ensure
    # there is at most one such entry in `event_shape_in`.
    mask = event_shape_in >= 0
    explicit_input_event_shape = tf.boolean_mask(input_event_shape, mask=mask)
    explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
    additional_assertions.append(
        assert_util.assert_equal(
            explicit_input_event_shape,
            explicit_event_shape_in,
            message='Input `event_shape` does not match `event_shape_in`.'))
    # We don't explicitly additionally verify
    # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
    # already makes this assertion.

  with tf.control_dependencies(additional_assertions):
    output_shape = tf.concat([input_non_event_shape, event_shape_out], axis=0,
                             name='output_shape')

  return output_shape, output_tensorshape
Example #23
  def __init__(self, event_shape_out, event_shape_in=(-1,),
               validate_args=False, name=None):
    """Creates a `Reshape` bijector.

    Args:
      event_shape_out: An `int`-like vector-shaped `Tensor`
        representing the event shape of the transformed output.
      event_shape_in: An optional `int`-like vector-shaped `Tensor`
        representing the event shape of the input. This is required in
        order to define inverse operations; the default of (-1,)
        assumes a vector-shaped input.
      validate_args: Python `bool` indicating whether arguments should
        be checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: if either `event_shape_in` or `event_shape_out` has
        non-integer `dtype`.
      ValueError: if either of `event_shape_in` or `event_shape_out`
       has non-vector shape (`rank > 1`), or if their sizes do not
       match.
    """
    with tf.name_scope(name or 'reshape') as name:
      event_shape_out = tf.convert_to_tensor(
          event_shape_out, name='event_shape_out', dtype_hint=tf.int32)
      event_shape_in = tf.convert_to_tensor(
          event_shape_in, name='event_shape_in', dtype_hint=tf.int32)

      forward_min_event_ndims_ = tensorshape_util.num_elements(
          event_shape_in.shape)
      if forward_min_event_ndims_ is None:
        raise NotImplementedError(
            '`event_shape_in` `size` must be statically known. For dynamic '
            'support, please contact `tfprobability@tensorflow.org`.')

      inverse_min_event_ndims_ = tensorshape_util.num_elements(
          event_shape_out.shape)
      if inverse_min_event_ndims_ is None:
        raise NotImplementedError(
            '`event_shape_out` `size` must be statically known. For dynamic '
            'support, please contact `tfprobability@tensorflow.org`.')

      assertions = []
      assertions.extend(_maybe_check_valid_shape(
          event_shape_out, validate_args))
      assertions.extend(_maybe_check_valid_shape(
          event_shape_in, validate_args))

      if assertions:
        with tf.control_dependencies(assertions):
          event_shape_in = tf.identity(
              event_shape_in, name='validated_event_shape_in')
          event_shape_out = tf.identity(
              event_shape_out, name='validated_event_shape_out')

      self._event_shape_in = event_shape_in
      self._event_shape_out = event_shape_out

      super(Reshape, self).__init__(
          forward_min_event_ndims=forward_min_event_ndims_,
          inverse_min_event_ndims=inverse_min_event_ndims_,
          is_constant_jacobian=True,
          validate_args=validate_args,
          name=name or 'reshape')
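
Typical usage of the public bijector (standard `tfp.bijectors` API; the default `event_shape_in=(-1,)` accepts any vector input):

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

r = tfb.Reshape(event_shape_out=[2, 3])
r.forward(tf.range(6.))     # ==> shape [2, 3]
r.inverse(tf.ones([2, 3]))  # ==> shape [6]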
Example #24
    def _sample_n(self, n, seed=None):
        stream = seed_stream.SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(
            sample_shape=concat_vectors(
                [n], self.batch_shape_tensor(), self.event_shape_tensor()),
            seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix if batch_shape is
        # non-empty; otherwise get ids as a [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                input_tensor=self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(
            sample_shape=concat_vectors(
                [n],
                distribution_util.pick_vector(
                    self.is_scalar_batch(), np.int32([]),
                    [batch_size // mix_batch_size])),
            seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(
            ids,
            shape=concat_vectors(
                [n],
                distribution_util.pick_vector(
                    self.is_scalar_batch(), np.int32([]), np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
        if stride is None:
            stride = tf.reduce_prod(input_tensor=tf.shape(
                input=self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We should have already triggered this exception earlier. However,
            # as a policy we raise it wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
Example #25
def _safe_tensor_scatter_nd_update(tensor, indices, updates):
  if tensorshape_util.num_elements(tensor.shape) == 0:
    return tensor
  return tf.tensor_scatter_nd_update(tensor, indices, updates)
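
A quick sketch of the guard (the zero-element early-out presumably sidesteps degenerate scatter shapes):

import tensorflow as tf

t = tf.constant([1., 2., 3.])
_safe_tensor_scatter_nd_update(t, [[1]], [5.])  # ==> [1., 5., 3.]
empty = tf.zeros([0])
_safe_tensor_scatter_nd_update(empty, tf.zeros([0, 1], tf.int32),
                               tf.zeros([0]))   # ==> `empty`, unchanged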
Example #26
def batch_reshapes(draw,
                   batch_shape=None,
                   event_dim=None,
                   enable_vars=False,
                   depth=None,
                   eligibility_filter=lambda name: True,
                   validate_args=True):
    """Strategy for drawing `BatchReshape` distributions.

  The underlying distribution is drawn from the `distributions` strategy.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      `BatchReshape` distribution.  Note that the underlying distribution will
      in general have a different batch shape, to make the reshaping
      non-trivial.  Hypothesis will pick one if omitted.
    event_dim: Optional Python int giving the size of each of the underlying
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`
      `tfp.util.TransformedVariable`}
    depth: Python `int` giving maximum nesting depth of compound Distributions.
    eligibility_filter: Optional Python callable.  Blocks some Distribution
      class names so they will not be drawn.
    validate_args: Python `bool`; whether to enable runtime assertions.

  Returns:
    dists: A strategy for drawing `BatchReshape` distributions with the
      specified `batch_shape` (or an arbitrary one if omitted).
  """
    if depth is None:
        depth = draw(depths())

    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes(min_ndims=1, max_side=4))

    # TODO(b/142135119): Wanted to draw general input and output shapes like the
    # following, but Hypothesis complained about filtering out too many things.
    # underlying_batch_shape = draw(tfp_hps.shapes(min_ndims=1))
    # hp.assume(
    #   batch_shape.num_elements() == underlying_batch_shape.num_elements())
    underlying_batch_shape = [tensorshape_util.num_elements(batch_shape)]

    underlying = draw(
        distributions(batch_shape=underlying_batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars,
                      depth=depth - 1,
                      eligibility_filter=eligibility_filter,
                      validate_args=validate_args))
    hp.note('Forming BatchReshape with underlying dist {}; '
            'parameters {}; batch_shape {}'.format(underlying,
                                                   params_used(underlying),
                                                   batch_shape))
    result_dist = tfd.BatchReshape(underlying,
                                   batch_shape=batch_shape,
                                   validate_args=True)
    return result_dist
Example #27
def _event_size(d):
    if tensorshape_util.num_elements(d.event_shape) is not None:
        return tensorshape_util.num_elements(d.event_shape)
    return tf.reduce_prod(d.event_shape_tensor())
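
Usage sketch (assuming `tfd = tfp.distributions`): a statically known event shape yields a Python integer; otherwise a scalar `Tensor` is computed from `event_shape_tensor`.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

_event_size(tfd.MultivariateNormalDiag(loc=tf.zeros([3])))  # ==> 3
_event_size(tfd.Normal(0., 1.))                             # ==> 1 (scalar event)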