Example #1
  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []

    sample_shape = tf.concat(
        [self._batch_shape_tensor(), self._event_shape_tensor()], axis=0)

    low = None if self._low is None else tf.convert_to_tensor(self._low)
    high = None if self._high is None else tf.convert_to_tensor(self._high)

    assertions = []
    if self._low is not None and is_init != tensor_util.is_ref(self._low):
      low_shape = ps.shape(low)
      broadcast_shape = ps.broadcast_shape(sample_shape, low_shape)
      assertions.extend(
          [distribution_util.assert_integer_form(
              low, message='`low` has non-integer components.'),
           assert_util.assert_equal(
               tf.reduce_prod(broadcast_shape),
               tf.reduce_prod(sample_shape),
               message=('Shape of `low` adds extra batch dimensions to '
                        'sample shape.'))])
    if self._high is not None and is_init != tensor_util.is_ref(self._high):
      high_shape = ps.shape(high)
      broadcast_shape = ps.broadcast_shape(sample_shape, high_shape)
      assertions.extend(
          [distribution_util.assert_integer_form(
              high, message='`high` has non-integer components.'),
           assert_util.assert_equal(
               tf.reduce_prod(broadcast_shape),
               tf.reduce_prod(sample_shape),
               message=('Shape of `high` adds extra batch dimensions to '
                        'sample shape.'))])
    if (self._low is not None and self._high is not None and
        (is_init != (tensor_util.is_ref(self._low)
                     or tensor_util.is_ref(self._high)))):
      assertions.append(assert_util.assert_less(
          low, high,
          message='`low` must be strictly less than `high`.'))

    return assertions
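
# A standalone sketch of the "no extra batch dimensions" check above, using plain
# TF ops in place of the tfp-internal `ps`/`assert_util` helpers (illustrative only):
import tensorflow as tf

sample_shape = tf.constant([3, 2])
low = tf.constant([[0.], [1.], [2.]])  # shape [3, 1] broadcasts against [3, 2]
broadcast_shape = tf.broadcast_dynamic_shape(sample_shape, tf.shape(low))
tf.debugging.assert_equal(
    tf.reduce_prod(broadcast_shape), tf.reduce_prod(sample_shape),
    message='Shape of `low` adds extra batch dimensions to sample shape.')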
Example #2
    def _calculate_mean_and_var(self, x, axes, keep_dims):

        with backend.name_scope('moments'):
            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
            replica_ctx = tf.distribute.get_replica_context()
            if replica_ctx:
                local_sum = tf.reduce_sum(y, axis=axes, keepdims=True)
                local_squared_sum = tf.reduce_sum(tf.square(y),
                                                  axis=axes,
                                                  keepdims=True)
                batch_size = tf.cast(tf.shape(y)[axes[0]], tf.float32)
                # TODO(b/163099951): batch the all-reduces once we sort out the ordering
                # issue for NCCL. We don't have a mechanism to launch NCCL in the same
                # order in each replica nowadays, so we limit NCCL to batch all-reduces.
                y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                               local_sum)
                y_squared_sum = replica_ctx.all_reduce(
                    tf.distribute.ReduceOp.SUM, local_squared_sum)
                global_batch_size = replica_ctx.all_reduce(
                    tf.distribute.ReduceOp.SUM, batch_size)

                axes_vals = [(tf.shape(y))[axes[i]]
                             for i in range(1, len(axes))]
                multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
                multiplier = multiplier * global_batch_size

                mean = y_sum / multiplier
                y_squared_mean = y_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                variance = y_squared_mean - tf.square(mean)
            else:
                # Compute true mean while keeping the dims for proper broadcasting.
                mean = tf.reduce_mean(y, axes, keepdims=True, name='mean')
                # sample variance, not unbiased variance
                # Note: stop_gradient does not change the gradient that gets
                #       backpropagated to the mean from the variance calculation,
                #       because that gradient is zero
                variance = tf.reduce_mean(tf.math.squared_difference(
                    y, tf.stop_gradient(mean)),
                                          axes,
                                          keepdims=True,
                                          name='variance')
            if not keep_dims:
                mean = tf.compat.v1.squeeze(mean, axes)
                variance = tf.compat.v1.squeeze(variance, axes)
            if x.dtype == tf.float16:
                return (tf.cast(mean,
                                tf.float16), tf.cast(variance, tf.float16))
            else:
                return (mean, variance)
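
# A minimal sketch (outside any distribution strategy) of the identity the replica
# branch relies on, var = E[x^2] - E[x]^2, computed from sums as above:
import tensorflow as tf

x = tf.random.normal([8, 4])
n = tf.cast(tf.shape(x)[0], tf.float32)
mean = tf.reduce_sum(x, axis=[0], keepdims=True) / n
variance = tf.reduce_sum(tf.square(x), axis=[0], keepdims=True) / n - tf.square(mean)
# Agrees with tf.nn.moments(x, axes=[0], keepdims=True) up to float error.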
Example #3
    def _entropy(self, **kwargs):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError('`entropy` is not implemented.')
        if not self.bijector._is_injective:  # pylint: disable=protected-access
            raise NotImplementedError('`entropy` is not implemented when '
                                      '`bijector` is not injective.')
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        override_event_shape = tf.convert_to_tensor(self._override_event_shape)
        override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
        base_batch_shape_tensor = self.distribution.batch_shape_tensor()
        base_event_shape_tensor = self.distribution.event_shape_tensor()
        # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
        # can be shown that:
        #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
        # If is_constant_jacobian then:
        #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
        # where c can be anything.
        entropy = self.distribution.entropy(**distribution_kwargs)
        if self._is_maybe_event_override:
            # H[X] = sum_i H[X_i] if X_i are mutually independent.
            # This means that a reduce_sum is a simple rescaling.
            entropy = entropy * tf.cast(tf.reduce_prod(override_event_shape),
                                        dtype=dtype_util.base_dtype(
                                            entropy.dtype))
        if self._is_maybe_batch_override:
            new_shape = tf.concat([
                prefer_static.ones_like(override_batch_shape),
                base_batch_shape_tensor
            ], 0)
            entropy = tf.reshape(entropy, new_shape)
            multiples = tf.concat([
                override_batch_shape,
                prefer_static.ones_like(base_batch_shape_tensor)
            ], 0)
            entropy = tf.tile(entropy, multiples)

        # Create a dummy event of zeros to pass to
        # `bijector.inverse_log_det_jacobian` to extract the constant Jacobian.
        event_shape_tensor = self._event_shape_tensor(override_event_shape,
                                                      base_event_shape_tensor)
        event_ndims = tf.nest.map_structure(prefer_static.rank_from_shape,
                                            event_shape_tensor,
                                            self.event_shape)
        dummy = tf.nest.map_structure(prefer_static.zeros, event_shape_tensor,
                                      self.dtype)

        ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)

        entropy = entropy - tf.cast(ildj, entropy.dtype)
        tensorshape_util.set_shape(entropy, self.batch_shape)
        return entropy
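
# Hedged numeric check of the constant-Jacobian identity used above,
# H[g(X)] = H[X] + log|det J_g|, with a simple Scale bijector (assumes a recent
# TensorFlow Probability is available):
import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

base = tfd.Normal(loc=0., scale=1.)
scaled = tfd.TransformedDistribution(distribution=base, bijector=tfb.Scale(3.))
# scaled.entropy() should equal base.entropy() + log(3).
delta = scaled.entropy() - (base.entropy() + tf.math.log(3.))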
Example #4
def get_jacobian_fn_mat(jacobian_fn, ode_fn_vec, state_shape, use_pfor):
    """Returns a wrapper around the user-specified `jacobian_fn` argument.

  `jacobian_fn` is an optional argument that can either be a constant `Tensor`
  or a function of the form `jacobian_fn(time, state)`. This function returns a
  wrapper `jacobian_fn_mat(time, state_vec)` whose second argument and output
  are 1-D and 2-D `Tensor`s, respectively, corresponding to reshaped versions
  of `state` and `jacobian_fn(time, state)`.

  Args:
    jacobian_fn: User-specified `jacobian_fn` passed to `solve`.
    ode_fn_vec: User-specified `ode_fn` passed to `solve`.
    state_shape: The shape of the second argument and output of `ode_fn`.
    use_pfor: User-specified `use_pfor` passed to `solve`.

  Returns:
    The wrapper described above.
  """
    if jacobian_fn is None:

        def automatic_jacobian_fn_mat(time, state_vec):
            with tf.GradientTape(watch_accessed_variables=False,
                                 persistent=not use_pfor) as tape:
                tape.watch(state_vec)
                outputs = ode_fn_vec(time, state_vec)
            jacobian_mat = tape.jacobian(outputs,
                                         state_vec,
                                         experimental_use_pfor=use_pfor)
            if jacobian_mat is None:
                return tf.zeros([tf.size(state_vec)] * 2,
                                dtype=state_vec.dtype)
            return jacobian_mat

        return automatic_jacobian_fn_mat

    if not callable(jacobian_fn):
        constant_jacobian_mat = tf.reshape(
            tf.convert_to_tensor(jacobian_fn),
            [-1, tf.reduce_prod(state_shape)])

        def constant_jacobian_fn_mat(*_):
            return constant_jacobian_mat

        return constant_jacobian_fn_mat

    def jacobian_fn_mat(time, state_vec):
        state = tf.reshape(state_vec, state_shape)
        jacobian_mat = tf.reshape(jacobian_fn(time, state),
                                  [-1, tf.size(state)])
        return jacobian_mat

    return jacobian_fn_mat
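
# Small sketch of the automatic-Jacobian path above for a hypothetical linear
# right-hand side dy/dt = A y (the names `a` and `ode_fn_vec` are illustrative):
import tensorflow as tf

a = tf.constant([[0., 1.], [-1., 0.]])
ode_fn_vec = lambda time, state_vec: tf.linalg.matvec(a, state_vec)

state_vec = tf.constant([1., 0.])
with tf.GradientTape() as tape:
  tape.watch(state_vec)
  outputs = ode_fn_vec(0., state_vec)
jacobian_mat = tape.jacobian(outputs, state_vec)  # recovers A, shape [2, 2]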
Example #5
  def _event_shape_tensor(self):
    event_sizes = tf.nest.map_structure(tensorshape_util.num_elements,
                                        self._distribution.event_shape)

    if any(s is None for s in tf.nest.flatten(event_sizes)):
      event_sizes = tf.nest.map_structure(
          lambda static_size, shape_tensor:  # pylint: disable=g-long-lambda
          (tf.reduce_prod(shape_tensor)
           if static_size is None else static_size),
          event_sizes,
          self._distribution.event_shape_tensor())

    return tf.reduce_sum(tf.nest.flatten(event_sizes))[tf.newaxis]
Example #6
def _get_leftmost_dim_size(x, name=None):
  """Returns the size of the left most dimension, statically if possible."""
  with tf.name_scope(name or 'get_leftmost_dim_size'):
    x = tf.convert_to_tensor(value=x, name='x')
    if x.shape.ndims is None:
      # If tf.shape(x) is scalar, the [:1] will produce the empty list, whose
      # reduce_prod is 1 as desired.  Otherwise, the [:1] will select the first
      # dimension, and reduce_prod will not alter it.
      return tf.reduce_prod(input_tensor=tf.shape(input=x)[:1])
    if x.shape.ndims == 0:
      return 1
    leftmost = tf.compat.dimension_value(x.shape[0])
    return leftmost if leftmost is not None else tf.shape(input=x)[0]
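
# Hedged usage sketch, assuming the function above is in scope:
import tensorflow as tf

_get_leftmost_dim_size(tf.zeros([4, 3]))  # 4, taken from the static shape
_get_leftmost_dim_size(tf.constant(7.))   # 1 for a scalar input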
Example #7
def make_2d(tensor, split_dim):
    """Reshapes an N-dimensional tensor into a 2D tensor.

  Dimensions before (excluding) and after (including) `split_dim` are grouped
  together.

  Args:
    tensor: a tensor of shape `(d0, ..., d(N-1))`.
    split_dim: an integer from 1 to N-1, index of the dimension to group
      dimensions before (excluding) and after (including).

  Returns:
    Tensor of shape
    `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
  """
    shape = tf.compat.v1.shape(tensor)
    in_dims = shape[:split_dim]
    out_dims = shape[split_dim:]

    in_size = tf.reduce_prod(in_dims)
    out_size = tf.reduce_prod(out_dims)

    return tf.reshape(tensor, (in_size, out_size))
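
# Usage sketch for make_2d (assuming the function above is in scope): grouping
# the dimensions of a rank-3 tensor around split_dim=1.
import tensorflow as tf

t = tf.zeros([2, 3, 4])
flat = make_2d(t, split_dim=1)  # shape [2, 12]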
Example #8
    def _mean(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ],
                                    axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size num_states observation_event_size
            flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps observation_event_size
            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ],
                                          axis=0)
            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_mean, unflat_mean_shape)
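
# Minimal sketch of the einsum contraction used above: probs indexed as
# [steps i, batch j, states k] against means [batch j, states k, event l]
# gives per-step means [batch j, steps i, event l].
import tensorflow as tf

probs = tf.random.uniform([7, 2, 3])
means = tf.random.normal([2, 3, 4])
per_step_mean = tf.einsum('ijk,jkl->jil', probs, means)  # shape [2, 7, 4]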
Example #9
  def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims,
                                 **distribution_kwargs):
    """Finish computation of prob on one element of the inverse image."""
    x = self._maybe_rotate_dims(x, rotate_right=True)
    prob = self.distribution.prob(x, **distribution_kwargs)
    if self._is_maybe_event_override:
      prob = tf.reduce_prod(prob, axis=self._reduce_event_indices)
    prob = prob * tf.exp(tf.cast(ildj, prob.dtype))
    if self._is_maybe_event_override and isinstance(event_ndims, int):
      tensorshape_util.set_shape(
          prob,
          tf.broadcast_static_shape(
              tensorshape_util.with_rank_at_least(y.shape, 1)[:-event_ndims],
              self.batch_shape))
    return prob
Example #10
    def _mean(self):
        observation_distribution = self.observation_distribution
        batch_shape = self.batch_shape_tensor()
        num_states = self.transition_distribution.batch_shape_tensor()[-1]
        probs = self._marginal_hidden_probs()
        # probs :: num_steps batch_shape num_states
        means = observation_distribution.mean()
        # means :: observation_batch_shape[:-1] num_states
        #          observation_event_shape
        means_shape = tf.concat([
            batch_shape, [num_states],
            observation_distribution.event_shape_tensor()
        ],
                                axis=0)
        means = tf.broadcast_to(means, means_shape)
        # means :: batch_shape num_states observation_event_shape

        observation_event_shape = (
            observation_distribution.event_shape_tensor())
        batch_size = tf.reduce_prod(batch_shape)
        flat_probs_shape = [self._num_steps, batch_size, num_states]
        flat_means_shape = [
            batch_size, num_states,
            tf.reduce_prod(observation_event_shape)
        ]

        flat_probs = tf.reshape(probs, flat_probs_shape)
        # flat_probs :: num_steps batch_size num_states
        flat_means = tf.reshape(means, flat_means_shape)
        # flat_means :: batch_size num_states observation_event_size
        flat_mean = tf.einsum('ijk,jkl->jil', flat_probs, flat_means)
        # flat_mean :: batch_size num_steps observation_event_size
        unflat_mean_shape = tf.concat(
            [batch_shape, [self._num_steps], observation_event_shape], axis=0)
        # returns :: batch_shape num_steps observation_event_shape
        return tf.reshape(flat_mean, unflat_mean_shape)
Example #11
    def test_docstring_example(self):
        # Produce the first 1000 members of the Halton sequence in 3 dimensions.
        num_results = 1000
        dim = 3
        sample, params = random.halton.sample(dim,
                                              num_results=num_results,
                                              seed=127)

        # Evaluate the integral of x_1 * x_2^2 * x_3^3  over the three dimensional
        # hypercube.
        powers = tf.range(1., limit=dim + 1)
        integral = tf.reduce_mean(
            input_tensor=tf.reduce_prod(input_tensor=sample**powers, axis=-1))
        true_value = 1. / tf.reduce_prod(input_tensor=powers + 1.)

        # Produces a relative absolute error of 1.7%.
        self.assertAllClose(self.evaluate(integral),
                            self.evaluate(true_value),
                            rtol=0.02)

        # Now skip the first 1000 samples and recompute the integral with the next
        # thousand samples. The sequence_indices argument can be used to do this.

        sequence_indices = tf.range(start=1000,
                                    limit=1000 + num_results,
                                    dtype=tf.int32)
        sample_leaped, _ = random.halton.sample(
            dim,
            sequence_indices=sequence_indices,
            randomization_params=params)

        integral_leaped = tf.reduce_mean(input_tensor=tf.reduce_prod(
            input_tensor=sample_leaped**powers, axis=-1))
        self.assertAllClose(self.evaluate(integral_leaped),
                            self.evaluate(true_value),
                            rtol=0.05)
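
# Worked check of the closed form the test compares against (standalone sketch):
# with powers = [1., 2., 3.], the integral of x1 * x2**2 * x3**3 over [0, 1]**3 is
# (1/2) * (1/3) * (1/4) = 1/24.
import tensorflow as tf

powers = tf.range(1., limit=4.)
true_value = 1. / tf.reduce_prod(powers + 1.)  # ~0.0416667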
Example #12
def _make_flatten_unflatten_fns_tf(batch_shape):
  """Returns functions to flatten and unflatten a batch shape."""
  batch_shape = tf.cast(batch_shape, dtype=tf.int32)
  batch_rank = batch_shape.shape[0]
  batch_size = tf.reduce_prod(batch_shape)

  @tf.function
  def flatten_fn(x):
    flat_shape = tf.concat([[batch_size], tf.shape(x)[batch_rank:]], axis=0)
    return tf.reshape(x, flat_shape)

  @tf.function
  def unflatten_fn(x):
    full_shape = tf.concat([batch_shape, tf.shape(x)[1:]], axis=0)
    return tf.reshape(x, full_shape)
  return flatten_fn, unflatten_fn
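
# Hedged usage sketch for the pair above with a [2, 3] batch shape:
import tensorflow as tf

flatten_fn, unflatten_fn = _make_flatten_unflatten_fns_tf([2, 3])
x = tf.zeros([2, 3, 5])
flat = flatten_fn(x)           # shape [6, 5]
restored = unflatten_fn(flat)  # shape [2, 3, 5]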
Example #13
    def _mode(self, samples=None):
        # The sample count can vary by batch member. Use map_fn to compute the
        # mode for each batch member separately.
        def _get_mode(samples):
            _, idx, count = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0])
            # TODO(b/161402486): Remove this hack for fixing the wrong static shape
            # of `idx` in graph mode.
            idx = tf.vectorized_map(lambda x: tf.reshape(x, [-1])[0], idx)
            # NOTE:
            #  - `count` has shape `[K]`, where `K` is the number of unique elements,
            #    and `count[j]` is the number of times the j-th unique element occurs
            #    in `samples`.
            #  - `idx` has shape `[samples.shape[0]]`, and `idx[i] == j` means that
            #    `samples[i]` is equal to the `j`-th unique element.
            max_count_idx = tf.argmax(count, output_type=tf.int32)
            # Return an index `i` for which `idx[i] == max_count_idx`.
            return tf.argmax(tf.cast(tf.math.equal(idx, max_count_idx),
                                     dtype=tf.int32),
                             output_type=tf.int32)

        if samples is None:
            samples = tf.convert_to_tensor(self._samples)
        num_samples = self._compute_num_samples(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            flattened_samples = tf.reshape(samples, [-1, num_samples])
            mode_shape = self._batch_shape_tensor(samples)
        else:
            event_size = tf.reduce_prod(self._event_shape_tensor(samples))
            mode_shape = ps.concat([
                self._batch_shape_tensor(samples),
                self._event_shape_tensor(samples)
            ],
                                   axis=0)
            flattened_samples = tf.reshape(samples,
                                           [-1, num_samples, event_size])

        indices = tf.map_fn(_get_mode,
                            flattened_samples,
                            fn_output_signature=tf.int32)
        full_indices = tf.stack([tf.range(tf.shape(indices)[0]), indices],
                                axis=1)

        mode = tf.gather_nd(flattened_samples, full_indices)
        return tf.reshape(mode, mode_shape)
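
# Standalone sketch of the per-batch mode computation above: count unique rows,
# then return one index at which the most frequent row occurs.
import tensorflow as tf

samples = tf.constant([[1.], [2.], [2.], [3.]])
_, idx, count = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0])
max_count_idx = tf.argmax(count, output_type=tf.int32)
i = tf.argmax(tf.cast(tf.equal(idx, max_count_idx), tf.int32), output_type=tf.int32)
mode = samples[i]  # [2.]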
Example #14
  def _entropy(self, **kwargs):
    if not self.bijector.is_constant_jacobian:
      raise NotImplementedError("entropy is not implemented")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
      raise NotImplementedError("entropy is not implemented when "
                                "bijector is not injective.")
    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
    # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
    # can be shown that:
    #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
    # If is_constant_jacobian then:
    #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
    # where c can be anything.
    entropy = self.distribution.entropy(**distribution_kwargs)
    if self._is_maybe_event_override:
      # H[X] = sum_i H[X_i] if X_i are mutually independent.
      # This means that a reduce_sum is a simple rescaling.
      entropy = entropy * tf.cast(
          tf.reduce_prod(self._override_event_shape),
          dtype=dtype_util.base_dtype(entropy.dtype))
    if self._is_maybe_batch_override:
      new_shape = tf.concat([
          prefer_static.ones_like(self._override_batch_shape),
          self.distribution.batch_shape_tensor()
      ], 0)
      entropy = tf.reshape(entropy, new_shape)
      multiples = tf.concat([
          self._override_batch_shape,
          prefer_static.ones_like(self.distribution.batch_shape_tensor())
      ], 0)
      entropy = tf.tile(entropy, multiples)
    dummy = prefer_static.zeros(
        shape=tf.concat(
            [self.batch_shape_tensor(), self.event_shape_tensor()],
            0),
        dtype=self.dtype)
    event_ndims = (
        tensorshape_util.rank(self.event_shape)
        if tensorshape_util.rank(self.event_shape) is not None else tf.size(
            self.event_shape_tensor()))
    ildj = self.bijector.inverse_log_det_jacobian(
        dummy, event_ndims=event_ndims, **bijector_kwargs)

    entropy = entropy - tf.cast(ildj, entropy.dtype)
    tensorshape_util.set_shape(entropy, self.batch_shape)
    return entropy
Example #15
  def _sample_n(self, n, seed=None):
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[], in which
    # case ids is a [n]-shaped vector.
    distributions = self.poisson_and_mixture_distributions()
    dist, mixture_dist = distributions
    batch_size = tensorshape_util.num_elements(self.batch_shape)
    if batch_size is None:
      batch_size = tf.reduce_prod(
          self._batch_shape_tensor(distributions=distributions))
    # We need to 'sample extra' from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    mixture_seed, poisson_seed = samplers.split_seed(
        seed, salt='PoissonLogNormalQuadratureCompound')
    ids = mixture_dist.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                mixture_dist.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=mixture_seed)
    # We need to flatten batch dims in case mixture_dist has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids = ids + offset
    rate = tf.gather(tf.reshape(dist.rate_parameter(), shape=[-1]), ids)
    rate = tf.reshape(
        rate, shape=concat_vectors([n], self._batch_shape_tensor(
            distributions=distributions)))
    return samplers.poisson(
        shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
Example #16
    def _sample_n(self, n, seed=None):
        distribution0 = self._get_distribution0()

        if self._num_steps is not None:
            num_steps = tf.convert_to_tensor(self._num_steps)
            num_steps_static = tf.get_static_value(num_steps)
        else:
            num_steps_static = tensorshape_util.num_elements(
                distribution0.event_shape)
            if num_steps_static is None:
                num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

        stateless_seed = samplers.sanitize_seed(seed, salt='Autoregressive')
        stateful_seed = None
        try:
            samples = distribution0.sample(n, seed=stateless_seed)
            is_stateful_sampler = False
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            msg = (
                'Falling back to stateful sampling for `distribution_fn(sample0)` of '
                'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                'This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(msg.format(type(distribution0), str(e)))
            stateful_seed = SeedStream(seed, salt='Autoregressive')()
            samples = distribution0.sample(n, seed=stateful_seed)
            is_stateful_sampler = True

        seed = stateful_seed if is_stateful_sampler else stateless_seed

        if num_steps_static is not None:
            for _ in range(num_steps_static):
                # pylint: disable=not-callable
                samples = self.distribution_fn(samples).sample(seed=seed)
        else:
            # pylint: disable=not-callable
            samples = tf.foldl(
                lambda s, _: self.distribution_fn(s).sample(seed=seed),
                elems=tf.range(0, num_steps),
                initializer=samples)
        return samples
Example #17
  def __call__(self, x):
    """Computes regularization given an input ed.RandomVariable."""
    if not isinstance(x, random_variable.RandomVariable):
      raise ValueError('Input must be an ed.RandomVariable.')
    # variance = (tr( sigma_q + mu_q mu_q^T ) + 2*beta) / (omega + 2*alpha + 2)
    trace_covariance = tf.reduce_sum(x.distribution.variance())
    trace_mean_outer_product = tf.reduce_sum(x.distribution.mean()**2)
    num_weights = tf.cast(tf.reduce_prod(x.shape), x.dtype)
    variance = ((trace_covariance + trace_mean_outer_product) +
                2. * self.variance_scale)
    variance /= num_weights + 2. * self.variance_concentration + 2.
    self.stddev = tf.sqrt(variance)

    variance_prior = generated_random_variables.InverseGamma(
        self.variance_concentration, self.variance_scale)
    regularization = super(NormalEmpiricalBayesKLDivergence, self).__call__(x)
    regularization -= (self.scale_factor *
                       variance_prior.distribution.log_prob(variance))
    return regularization
Example #18
    def _forward(self, x, **kwargs):
        static_event_size = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(
                x.shape, self._event_ndims)[-self._event_ndims:])

        if self._unroll_loop:
            if not static_event_size:
                raise ValueError(
                    'The final {} dimensions of `x` must be known at graph '
                    'construction time if `unroll_loop=True`. `x.shape: {!r}`'.
                    format(self._event_ndims, x.shape))
            y = tf.zeros_like(x, name='y0')

            for _ in range(static_event_size):
                y = self._bijector_fn(y, **kwargs).forward(x)
            return y

        event_size = tf.reduce_prod(tf.shape(x)[-self._event_ndims:])
        y0 = tf.zeros_like(x, name='y0')
        # call the template once to ensure creation
        if not tf.executing_eagerly():
            _ = self._bijector_fn(y0, **kwargs).forward(y0)

        def _loop_body(index, y0):
            """While-loop body for autoregression calculation."""
            # Set caching device to avoid re-getting the tf.Variable for every while
            # loop iteration.
            with tf1.variable_scope(tf1.get_variable_scope()) as vs:
                if vs.caching_device is None and not tf.executing_eagerly():
                    vs.set_caching_device(lambda op: op.device)
                bijector = self._bijector_fn(y0, **kwargs)
            y = bijector.forward(x)
            return index + 1, y

        # If the event size is available at graph construction time, we can inform
        # the graph compiler of the maximum number of steps. If not,
        # static_event_size will be None, and the maximum_iterations argument will
        # have no effect.
        _, y = tf.while_loop(cond=lambda index, _: index < event_size,
                             body=_loop_body,
                             loop_vars=(0, y0),
                             maximum_iterations=static_event_size)
        return y
Example #19
    def update_state(self, data):
        if self.input_mean is not None:
            raise ValueError(
                "Cannot `adapt` a Normalization layer that is initialized with "
                "static `mean` and `variance`, "
                "you passed mean {} and variance {}.".format(
                    self.input_mean, self.input_variance
                )
            )

        if not self.built:
            raise RuntimeError("`build` must be called before `update_state`.")

        data = self._standardize_inputs(data)
        data = tf.cast(data, self.adapt_mean.dtype)
        batch_mean, batch_variance = tf.nn.moments(data, axes=self._reduce_axis)
        batch_shape = tf.shape(data, out_type=self.count.dtype)
        if self._reduce_axis:
            batch_reduce_shape = tf.gather(batch_shape, self._reduce_axis)
            batch_count = tf.reduce_prod(batch_reduce_shape)
        else:
            batch_count = 1

        total_count = batch_count + self.count
        batch_weight = tf.cast(batch_count, dtype=self.compute_dtype) / tf.cast(
            total_count, dtype=self.compute_dtype
        )
        existing_weight = 1.0 - batch_weight

        total_mean = (
            self.adapt_mean * existing_weight + batch_mean * batch_weight
        )
        # The variance is computed using the lack-of-fit sum of squares
        # formula (see
        # https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
        total_variance = (
            self.adapt_variance + (self.adapt_mean - total_mean) ** 2
        ) * existing_weight + (
            batch_variance + (batch_mean - total_mean) ** 2
        ) * batch_weight
        self.adapt_mean.assign(total_mean)
        self.adapt_variance.assign(total_variance)
        self.count.assign(total_count)
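
# Small numeric check (plain NumPy) of the pooled-moment update used above: the
# lack-of-fit decomposition reproduces the mean/variance of the concatenated data.
import numpy as np

old = np.array([1., 2., 3.])       # stands in for previously adapted data
new = np.array([4., 5., 6., 7.])   # stands in for the new batch
batch_weight = new.size / (old.size + new.size)
existing_weight = 1. - batch_weight
total_mean = old.mean() * existing_weight + new.mean() * batch_weight
total_variance = ((old.var() + (old.mean() - total_mean) ** 2) * existing_weight
                  + (new.var() + (new.mean() - total_mean) ** 2) * batch_weight)
# total_mean == 4.0 and total_variance == 4.0, matching the mean and var of
# np.concatenate([old, new]).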
Example #20
    def _split_and_reshape_event(self, x):
        """Splits and reshapes of a vector-valued event `x`."""
        splits = [
            tf.maximum(1, tf.reduce_prod(s))
            for s in tf.nest.flatten(self._target_density.event_shape)
        ]
        x = tf.nest.pack_sequence_as(self._target_density.event_shape,
                                     tf.split(x, splits, axis=-1))

        def _reshape_part(part, dtype, event_shape):
            part = tf.cast(part, dtype)
            rank = event_shape.rank
            if rank == 1:
                return part
            new_shape = tf.concat([tf.shape(part)[:-1], event_shape], axis=-1)
            return tf.reshape(part, tf.cast(new_shape, tf.int32))

        x = tf.nest.map_structure(_reshape_part, x, self._target_density.dtype,
                                  self._target_density.event_shape)
        return x
Example #21
    def sample(self, sample_shape=(), seed=None, name=None):
        with tf.name_scope(name or 'sample'):
            # Grab the required number of values from the provided tensors.
            sample_shape = dist_util.expand_to_vector(sample_shape)
            n = tf.cast(tf.reduce_prod(sample_shape), dtype=tf.int32)

            # Check that we're not trying to draw too many samples.
            assertions = []
            will_overflow_ = tf.get_static_value(n > self.max_num_samples)
            if will_overflow_:
                raise ValueError(
                    'Trying to draw {} samples from a '
                    '`DeterministicEmpirical` instance for which only {} '
                    'samples were provided.'.format(
                        tf.get_static_value(n),
                        tf.get_static_value(self.max_num_samples)))
            elif (will_overflow_ is None  # Couldn't determine statically.
                  and self.validate_args):
                assertions.append(
                    tf.debugging.assert_less_equal(
                        n,
                        self.max_num_samples,
                        message='Number of samples to draw '
                        'from a `DeterministicEmpirical` instance must not exceed the '
                        'number provided at construction.'))

            # Extract the appropriate number of sampled values.
            with tf.control_dependencies(assertions):
                sampled = tf.nest.map_structure(lambda x: x[:n, ...],
                                                self.values_with_sample_dim)

            # Reshape the values to the appropriate sample shape.
            return tf.nest.map_structure(
                lambda x: tf.reshape(
                    x,  # pylint: disable=g-long-lambda
                    tf.concat([
                        tf.cast(sample_shape, tf.int32),
                        tf.cast(tf.shape(x)[1:], tf.int32)
                    ],
                              axis=0)),
                sampled)
Example #22
  def _entropy(self):
    # Use map_fn to compute entropy for each batch separately.
    def _get_entropy(samples):
      # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
      count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
      prob = count / self.num_samples
      entropy = tf.reduce_sum(input_tensor=-prob * tf.math.log(prob))
      return entropy

    # Flatten samples for each batch.
    if self._event_ndims == 0:
      samples = tf.reshape(self.samples, [-1, self.num_samples])
    else:
      event_size = tf.reduce_prod(input_tensor=self.event_shape_tensor())
      samples = tf.reshape(self.samples, [-1, self.num_samples, event_size])

    entropy = tf.map_fn(_get_entropy, samples)
    entropy_shape = self.batch_shape_tensor()
    if self.dtype.is_floating:
      entropy = tf.cast(entropy, self.dtype)
    return tf.reshape(entropy, entropy_shape)
Example #23
    def basis(sample_paths):
        """Computes polynomial basis expansion at the given sample points.

    Args:
      sample_paths: A `Tensor` of either `float32` or `float64` dtype and of
        shape `[num_samples, dim]` where `dim` has to be statically known.

    Returns:
      A `Tensor` of shape `[(degree + 1)**dim, num_samples]`.
    """
        samples = tf.convert_to_tensor(sample_paths)
        dim = samples.shape.as_list()[-1]
        grid = tf.range(0, degree + 1, dtype=samples.dtype)

        samples_centered = samples - tf.math.reduce_mean(samples, axis=0)
        samples_centered = tf.expand_dims(samples_centered, -2)
        grid = tf.meshgrid(*(dim * [grid]))
        grid = tf.reshape(tf.stack(grid, -1), [-1, dim])
        # Shape [num_samples, (degree + 1)**dim]
        basis_expansion = tf.reduce_prod(samples_centered**grid, -1)
        return tf.transpose(basis_expansion)
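
# Standalone sketch of the power-grid expansion above for dim = 2, degree = 2
# (variable names here are illustrative, not part of the library):
import tensorflow as tf

degree, dim = 2, 2
samples = tf.random.normal([5, dim])
grid = tf.range(0, degree + 1, dtype=samples.dtype)
grid = tf.reshape(tf.stack(tf.meshgrid(*(dim * [grid])), -1), [-1, dim])  # [9, 2]
centered = tf.expand_dims(samples - tf.reduce_mean(samples, axis=0), -2)  # [5, 1, 2]
expansion = tf.transpose(tf.reduce_prod(centered ** grid, -1))  # [(degree+1)**dim, 5]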
Example #24
    def _entropy(self):
        samples = tf.convert_to_tensor(self.samples)
        num_samples = self._compute_num_samples(samples)
        entropy_shape = self._batch_shape_tensor(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            samples = tf.reshape(samples, [-1, num_samples])
        else:
            event_size = tf.reduce_prod(self.event_shape_tensor())
            samples = tf.reshape(samples, [-1, num_samples, event_size])

        # Use map_fn to compute entropy for each batch separately.
        def _get_entropy(samples):
            count = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0]).count
            prob = tf.cast(count / num_samples, dtype=self.dtype)
            entropy = tf.reduce_sum(-prob * tf.math.log(prob))
            return entropy

        entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
        return tf.reshape(entropy, entropy_shape)
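
# Standalone sketch of the per-batch empirical entropy above: -sum_j p_j log p_j
# with p_j the relative frequency of each distinct sample value.
import tensorflow as tf

samples = tf.constant([0., 0., 1., 2.])
_, _, count = tf.unique_with_counts(samples)
prob = tf.cast(count, tf.float32) / tf.cast(tf.size(samples), tf.float32)
entropy = tf.reduce_sum(-prob * tf.math.log(prob))  # ~1.04 nats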
Example #25
    def _sample_n(self, n, seed=None):
        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[], in
        # which case ids is a [n]-shaped vector.
        batch_size = self.batch_shape.num_elements()
        if batch_size is None:
            batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
        # We need to "sample extra" from the mixture distribution if it doesn't
        # already specify a probs vector for each batch coordinate.
        # We only support this kind of reduced broadcasting, i.e., there is exactly
        # one probs vector for all batch dims or one for each.
        stream = seed_stream.SeedStream(
            seed, salt="PoissonLogNormalQuadratureCompound")
        ids = self._mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.mixture_distribution.is_scalar_batch(), [batch_size],
                np.int32([]))),
                                                seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `quadrature_size` for `batch_size` number of times.
        offset = tf.range(start=0,
                          limit=batch_size * self._quadrature_size,
                          delta=self._quadrature_size,
                          dtype=ids.dtype)
        ids += offset
        rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
        rate = tf.reshape(rate,
                          shape=concat_vectors([n], self.batch_shape_tensor()))
        return tf.random.poisson(lam=rate,
                                 shape=[],
                                 dtype=self.dtype,
                                 seed=seed)
Example #26
  def basis(sample_paths, time_index):
    """Computes polynomial basis expansion at the given sample points.

    Args:
      sample_paths: A `Tensor` of either `float32` or `float64` dtype and of
        either shape `[num_samples, num_times, dim]` or
        `[batch_size, num_samples, num_times, dim]`.
      time_index: An integer scalar `Tensor` that corresponds to the time
        coordinate at which the basis function is computed.

    Returns:
      A `Tensor` of shape `[batch_size, (degree + 1)**dim, num_samples]`.
    """
    sample_paths = tf.convert_to_tensor(sample_paths,
                                        name="sample_paths")
    if sample_paths.shape.rank == 3:
      sample_paths = tf.expand_dims(sample_paths, axis=0)
    shape = tf.shape(sample_paths)
    num_samples = shape[1]
    batch_size = shape[0]
    dim = sample_paths.shape[-1]  # Dimension should be statically known.
    # Shape [batch_size, num_samples, 1, dim]
    slice_samples = tf.slice(sample_paths, [0, 0, time_index, 0],
                             [batch_size, num_samples, 1, dim])
    # Shape [batch_size, num_samples, 1, dim]
    samples_centered = slice_samples - tf.math.reduce_mean(
        slice_samples, axis=1, keepdims=True)
    grid = tf.range(degree + 1, dtype=samples_centered.dtype)
    # Creates a grid of 'power' expansions, i.e., a `Tensor` of shape
    # [(degree + 1)**dim, dim] with entries [k_1, ..., k_dim] where
    # 0 <= k_i <= degree.
    grid = tf.meshgrid(*(dim * [grid]))
    # Shape [(degree + 1)**dim, dim]
    grid = tf.reshape(tf.stack(grid, -1), [-1, dim])
    # `samples_centered` has shape [batch_size, num_samples, 1, dim],
    # `samples_centered**grid` has shape
    # `[batch_size, num_samples, (degree + 1)**dim, dim]`
    # so that the output shape is `[batch_size, num_samples, (degree + 1)**dim]`
    basis_expansion = tf.reduce_prod(samples_centered**grid, axis=-1)
    return tf.transpose(basis_expansion, [0, 2, 1])
Example #27
    def testRank1ResNetV1(self, alpha_initializer, gamma_initializer,
                          random_sign_init, ensemble_size):
        tf.random.set_seed(83922)
        dataset_size = 10
        batch_size = 6
        input_shape = (32, 32, 2)  # TODO(dusenberrymw): (32, 32, 1) doesn't work...
        num_classes = 2

        features = tf.random.normal((dataset_size, ) + input_shape)
        coeffs = tf.random.normal([tf.reduce_prod(input_shape), num_classes])
        net = tf.reshape(features, [dataset_size, -1])
        logits = tf.matmul(net, coeffs)
        labels = tf.random.categorical(logits, 1)
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        dataset = dataset.repeat().shuffle(dataset_size).batch(batch_size)

        model = resnet_cifar_model.rank1_resnet_v1(
            input_shape=input_shape,
            depth=8,
            num_classes=num_classes,
            width_multiplier=1,
            alpha_initializer=alpha_initializer,
            gamma_initializer=gamma_initializer,
            alpha_regularizer=None,
            gamma_regularizer=None,
            use_additive_perturbation=False,
            ensemble_size=ensemble_size,
            random_sign_init=-0.5,
            dropout_rate=0.)
        model.compile('adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(
                          from_logits=True))
        history = model.fit(dataset,
                            steps_per_epoch=dataset_size // batch_size,
                            epochs=2)

        loss_history = history.history['loss']
        self.assertAllGreaterEqual(loss_history, 0.)
Example #28
def _quasi_uniform(
    dim,
    sample_shape,
    random_type,
    dtype,
    seed=None,
    **kwargs):
  """Quasi random draws from a uniform distribution on [0, 1)."""
  # Shape of the output
  output_shape = tf.concat([sample_shape] + [[dim]], -1)
  # Number of quasi random samples
  num_samples = tf.reduce_prod(sample_shape)
  # Number of initial low discrepancy sequence numbers to skip
  if 'skip' in kwargs:
    skip = kwargs['skip']
  else:
    skip = 0
  if random_type == RandomType.SOBOL:
    # Shape [num_samples, dim] of the Sobol samples
    low_discrepancy_seq = sobol.sample(
        dim=dim, num_results=num_samples, skip=skip,
        dtype=dtype)
    # TODO(b/148005344): Remove the tf.reshape below after the bug is fixed.
    low_discrepancy_seq = tf.reshape(low_discrepancy_seq, [num_samples, dim])
  else:  # HALTON or HALTON_RANDOMIZED random_type
    if 'randomization_params' in kwargs:
      randomization_params = kwargs['randomization_params']
    else:
      randomization_params = None
    randomized = random_type == RandomType.HALTON_RANDOMIZED
    # Shape [num_samples, dim] of the Halton samples
    low_discrepancy_seq, _ = halton.sample(
        dim=dim,
        sequence_indices=tf.range(skip, skip + num_samples),
        randomized=randomized,
        randomization_params=randomization_params,
        seed=seed,
        dtype=dtype)
  return tf.reshape(low_discrepancy_seq, output_shape)
Example #29
  def _entropy(self):
    samples = tf.convert_to_tensor(self.samples)
    num_samples = self._compute_num_samples(samples)
    entropy_shape = self._batch_shape_tensor(samples)

    # Flatten samples for each batch.
    if self._event_ndims == 0:
      samples = tf.reshape(samples, [-1, num_samples])
    else:
      event_size = tf.reduce_prod(self.event_shape_tensor())
      samples = tf.reshape(samples, [-1, num_samples, event_size])

    # Use map_fn to compute entropy for each batch separately.
    def _get_entropy(samples):
      # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
      count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
      prob = tf.cast(count / num_samples, dtype=self.dtype)
      entropy = tf.reduce_sum(-prob * tf.math.log(prob))
      return entropy

    entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
    return tf.reshape(entropy, entropy_shape)
Example #30
        def easom(z):
            """The value of the two dimensional Easom function.

      The Easom function is a standard optimization test function. It has
      a single global minimum at (pi, pi) which is located inside a deep
      funnel. The expression for the function is:

      ```None
      f(x, y) = -cos(x) cos(y) exp(-(x-pi)**2 - (y-pi)**2)
      ```

      Args:
        z: `Tensor` of shape [2] and real dtype. The argument at which to
          evaluate the function.

      Returns:
        value: Scalar real `Tensor`. The value of the Easom function at the
          supplied argument.
      """
            f1 = tf.reduce_prod(tf.cos(z), axis=-1)
            f2 = tf.exp(-tf.reduce_sum((z - np.pi)**2, axis=-1))
            return -f1 * f2
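
# Quick numeric check of the global minimum mentioned in the docstring:
# at z = (pi, pi), cos(pi) * cos(pi) = 1 and the exponential term is 1, so f = -1.
import numpy as np
import tensorflow as tf

z = tf.constant([np.pi, np.pi])
value = -tf.reduce_prod(tf.cos(z), axis=-1) * tf.exp(
    -tf.reduce_sum((z - np.pi) ** 2, axis=-1))  # -1.0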