def _sample_n(self, n, seed=None):
    loc = tf.convert_to_tensor(self.loc)
    scale = tf.convert_to_tensor(self.scale)
    tailweight = tf.convert_to_tensor(self.tailweight)
    skewness = tf.convert_to_tensor(self.skewness)
    ig_seed, normal_seed = samplers.split_seed(
        seed, salt='normal_inverse_gaussian')
    batch_shape = self._batch_shape_tensor(
        loc=loc,
        scale=scale,
        tailweight=tailweight,
        skewness=skewness)
    w = tailweight * tf.math.exp(0.5 * tf.math.log1p(
        -tf.math.square(skewness / tailweight)))
    w = tf.broadcast_to(w, batch_shape)
    ig_samples = inverse_gaussian.InverseGaussian(
        scale / w, tf.math.square(scale)).sample(n, seed=ig_seed)

    sample_shape = ps.concat([[n], batch_shape], axis=0)
    normal_samples = samplers.normal(
        shape=ps.convert_to_shape_tensor(sample_shape),
        mean=0., stddev=1., dtype=self.dtype, seed=normal_seed)
    return (loc + tf.math.sqrt(ig_samples) * (
        skewness * tf.math.sqrt(ig_samples) + normal_samples))
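
# A minimal standalone NumPy sketch of the construction above (a normal
# mean-variance mixture over an inverse Gaussian mixing variable), assuming
# scalar parameters. `np.random.Generator.wald(mean, scale)` is NumPy's
# inverse Gaussian sampler and is assumed here to correspond to
# `InverseGaussian(loc=scale / w, concentration=scale**2)`. The helper name
# is illustrative only.
import numpy as np

def sample_nig_sketch(n, loc, scale, tailweight, skewness, rng=None):
  rng = np.random.default_rng() if rng is None else rng
  w = tailweight * np.sqrt(1. - (skewness / tailweight)**2)
  ig = rng.wald(mean=scale / w, scale=scale**2, size=n)  # Mixing variances.
  z = rng.normal(size=n)
  # X = loc + skewness * V + sqrt(V) * Z, with V ~ InverseGaussian.
  return loc + skewness * ig + np.sqrt(ig) * z
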
    def _sample_n(self, n, seed=None):
        temperature = tf.convert_to_tensor(self.temperature)
        logits = self._logits_parameter_no_checks()

        # Uniform variates must be sampled from the open-interval `(0, 1)` rather
        # than `[0, 1)`. To do so, we use
        # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` because it is the
        # smallest, positive, 'normal' number. A 'normal' number is such that the
        # mantissa has an implicit leading 1. Normal, positive numbers x, y have the
        # reasonable property that, `x + y >= max(x, y)`. In this case, a subnormal
        # number (i.e., np.nextafter) can cause us to sample 0.
        uniform_shape = ps.concat(
            [[n],
             self._batch_shape_tensor(temperature=temperature, logits=logits),
             self._event_shape_tensor(logits=logits)], 0)
        uniform = samplers.uniform(
            shape=uniform_shape,
            minval=np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny,
            maxval=1.,
            dtype=self.dtype,
            seed=seed)
        gumbel = -tf.math.log(-tf.math.log(uniform))
        noisy_logits = (gumbel + logits) / temperature[..., tf.newaxis]
        return tf.math.log_softmax(noisy_logits)
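
# A minimal NumPy sketch of the Gumbel-softmax (Concrete) relaxation sampled
# above: perturb the logits with Gumbel noise, divide by the temperature, and
# take a log-softmax over the last axis. The open-interval trick is mimicked
# by drawing uniforms in [tiny, 1). The helper name and float64 dtype are
# illustrative assumptions.
import numpy as np

def sample_relaxed_log_probs_sketch(logits, temperature, rng=None):
  rng = np.random.default_rng() if rng is None else rng
  logits = np.asarray(logits, dtype=np.float64)
  u = rng.uniform(low=np.finfo(logits.dtype).tiny, high=1., size=logits.shape)
  noisy = (-np.log(-np.log(u)) + logits) / temperature
  m = noisy.max(axis=-1, keepdims=True)
  # log_softmax(x) = x - logsumexp(x).
  return noisy - (m + np.log(np.exp(noisy - m).sum(axis=-1, keepdims=True)))
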
  def _sample_n(self, n, seed=None):
    low = tf.convert_to_tensor(self.low)
    high = tf.convert_to_tensor(self.high)
    peak = tf.convert_to_tensor(self.peak)

    seed = samplers.sanitize_seed(seed, salt='triangular')
    shape = ps.concat([[n], self._batch_shape_tensor(
        low=low, high=high, peak=peak)], axis=0)
    samples = samplers.uniform(shape=shape, dtype=self.dtype, seed=seed)
    # We use Inverse CDF sampling here. Because the CDF is a quadratic function,
    # we must use sqrts here.
    interval_length = high - low
    return tf.where(
        # Note the CDF on the left side of the peak is
        # (x - low) ** 2 / ((high - low) * (peak - low)).
        # If we plug in peak for x, we get that the CDF at the peak
        # is (peak - low) / (high - low). Because of this we decide
        # which part of the piecewise CDF we should use based on the cdf samples
        # we drew.
        samples < (peak - low) / interval_length,
        # Inverse of (x - low) ** 2 / ((high - low) * (peak - low)).
        low + tf.sqrt(samples * interval_length * (peak - low)),
        # Inverse of 1 - (high - x) ** 2 / ((high - low) * (high - peak))
        high - tf.sqrt((1. - samples) * interval_length * (high - peak)))
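
# A quick NumPy check of the same inverse-CDF construction for the triangular
# distribution: draw u ~ Uniform(0, 1) and invert the piecewise-quadratic CDF.
# The helper name is illustrative only.
import numpy as np

def sample_triangular_sketch(n, low, peak, high, rng=None):
  rng = np.random.default_rng() if rng is None else rng
  u = rng.uniform(size=n)
  interval_length = high - low
  left = low + np.sqrt(u * interval_length * (peak - low))
  right = high - np.sqrt((1. - u) * interval_length * (high - peak))
  return np.where(u < (peak - low) / interval_length, left, right)

# E.g., np.mean(sample_triangular_sketch(100000, 0., 1., 4.)) should be close
# to the closed-form mean (low + peak + high) / 3 = 5 / 3.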
    def _variance(self):
        probs = self.mixture_distribution.probs_parameter()  # [B, k] or [k]
        component_means = self.components_distribution.mean()  # [B, k, E]
        component_vars = self.components_distribution.variance()  # [B, k, E]
        event_ndims = self._event_ndims()

        # reshape probs to [B, k, [1]*e] or [k, [1]*e]
        probs = tf.reshape(
            probs,
            ps.concat(
                [ps.shape(probs),
                 ps.ones([event_ndims], dtype=tf.int32)],
                axis=0))

        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        mean_cond_var = tf.reduce_sum(probs * component_vars,
                                      axis=-1 - event_ndims)  # [B, E]
        mean = tf.reduce_sum(probs * component_means,
                             axis=-1 - event_ndims,
                             keepdims=True)  # [B, 1, E]
        var_cond_mean = tf.reduce_sum(
            probs * tf.math.squared_difference(component_means, mean),
            axis=-1 - event_ndims)  # [B, E]
        return mean_cond_var + var_cond_mean
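
# A scalar NumPy sketch of the law of total variance used above:
# Var(Y) = E[Var(Y|X)] + Var(E[Y|X]). Names are illustrative only.
import numpy as np

def mixture_variance_sketch(probs, component_means, component_vars):
  probs, means, variances = map(
      np.asarray, (probs, component_means, component_vars))
  mean_cond_var = np.sum(probs * variances)                  # E[Var(Y|X)]
  mixture_mean = np.sum(probs * means)                       # E[Y]
  var_cond_mean = np.sum(probs * (means - mixture_mean)**2)  # Var(E[Y|X])
  return mean_cond_var + var_cond_mean

# E.g., mixture_variance_sketch([0.3, 0.7], [0., 2.], [1., 0.5]) == 1.49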
    def _sample_n(self, n, seed):
        df = tf.convert_to_tensor(self.df)
        batch_shape = self._batch_shape_tensor(df)
        event_shape = self._event_shape_tensor()
        batch_ndims = ps.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = ps.concat([[n], batch_shape, event_shape], 0)
        normal_seed, gamma_seed = samplers.split_seed(seed, salt='Wishart')

        # Complexity: O(nbk**2)
        x = samplers.normal(shape=shape,
                            mean=0.,
                            stddev=1.,
                            dtype=self.dtype,
                            seed=normal_seed)

        # Complexity: O(nbk)
        # This parameterization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        expanded_df = df * tf.ones(self._scale.batch_shape_tensor(),
                                   dtype=dtype_util.base_dtype(df.dtype))

        g = gamma_lib.random_gamma(
            shape=[n],
            concentration=self._multi_gamma_sequence(0.5 * expanded_df,
                                                     self._dimension()),
            log_rate=tf.convert_to_tensor(np.log(0.5), self.dtype),
            seed=gamma_seed,
            log_space=True)

        # Complexity: O(nbk**2)
        x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = tf.linalg.set_diag(x, tf.math.exp(g * 0.5))

        # Make batch-op ready.
        # Complexity: O(nbk**2)
        perm = ps.concat([ps.range(1, ndims), [0]], 0)
        x = tf.transpose(a=x, perm=perm)
        shape = ps.concat(
            [batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
        x = tf.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
        # this step has complexity O(nbk^3).
        x = self._scale.matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = ps.concat([batch_shape, event_shape, [n]], 0)
        x = tf.reshape(x, shape)
        perm = ps.concat([[ndims - 1], ps.range(0, ndims - 1)], 0)
        x = tf.transpose(a=x, perm=perm)

        if not self.input_output_cholesky:
            # Complexity: O(nbk**3)
            x = tf.matmul(x, x, adjoint_b=True)

        return x
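
# A minimal NumPy sketch of the Bartlett-style construction used above for a
# single draw, assuming `scale` is the full (positive-definite) scale matrix:
# build a lower-triangular factor with sqrt(chi-square) entries on the
# diagonal and standard normals below it, then push it through the Cholesky
# factor of `scale`. The helper name is illustrative only.
import numpy as np

def sample_wishart_sketch(df, scale, rng=None):
  rng = np.random.default_rng() if rng is None else rng
  k = scale.shape[-1]
  a = np.tril(rng.normal(size=(k, k)), k=-1)
  # chi2(df - i) on the diagonal, matching Gamma(0.5 * (df - i), rate=0.5).
  a[np.diag_indices(k)] = np.sqrt(rng.chisquare(df - np.arange(k)))
  chol = np.linalg.cholesky(scale)
  factor = chol @ a
  return factor @ factor.T  # E[W] = df * scale
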
    def __init__(
            self,
            input_size,
            output_size,              # keras::Conv::filters
            # Conv specific.
            filter_shape,             # keras::Conv::kernel_size
            rank=2,                   # keras::Conv::rank
            strides=1,                # keras::Conv::strides
            padding='VALID',          # keras::Conv::padding; 'CAUSAL' not implemented.
                                      # keras::Conv::data_format is not implemented
            dilations=1,              # keras::Conv::dilation_rate
            # Weights
            init_kernel_fn=None,      # tfp.experimental.nn.initializers.glorot_uniform()
            init_bias_fn=None,        # tf.initializers.zeros()
            make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
            dtype=tf.float32,
            batch_shape=(),
            # Misc
            activation_fn=None,
            name=None):
        """Constructs layer.

    Note: `data_format` is not supported since all nn layers operate on
    the rightmost column. If your channel dimension is not rightmost, use
    `tf.transpose` before calling this layer. For example, if your channel
    dimension is second from the left, the following code will move it
    rightmost:

    ```python
    inputs = tf.transpose(inputs, tf.concat([
        [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
    ```

    Args:
      input_size: ...
        In Keras, this argument is inferred from the rightmost input shape,
        i.e., `tf.shape(inputs)[-1]`. This argument specifies the size of the
        second from the rightmost dimension of both `inputs` and `kernel`.
        Default value: `None`.
      output_size: ...
        In Keras, this argument is called `filters`. This argument specifies the
        rightmost dimension size of both `kernel` and `bias`.
      filter_shape: ...
        In Keras, this argument is called `kernel_size`. This argument specifies
        the leftmost `rank` dimensions' sizes of `kernel`.
      rank: An integer, the rank of the convolution, e.g. "2" for 2D
        convolution. This argument implies the number of `kernel` dimensions,
        i.e., `kernel.shape.rank == rank + 2`.
        In Keras, this argument has the same name and semantics.
        Default value: `2`.
      strides: An integer or tuple/list of n integers, specifying the stride
        length of the convolution.
        In Keras, this argument has the same name and semantics.
        Default value: `1`.
      padding: One of `"VALID"` or `"SAME"` (case-insensitive).
        In Keras, this argument has the same name and semantics (except we don't
        support `"CAUSAL"`).
        Default value: `'VALID'`.
      dilations: An integer or tuple/list of `rank` integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying any
        `dilations` value != 1 is incompatible with specifying any `strides`
        value != 1.
        In Keras, this argument is called `dilation_rate`.
        Default value: `1`.
      init_kernel_fn: ...
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      init_bias_fn: ...
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_kernel_bias_fn: ...
        Default value: `tfp.experimental.nn.util.make_kernel_bias`.
      dtype: ...
        Default value: `tf.float32`.
      batch_shape: ...
        Default value: `()`.
      activation_fn: ...
        Default value: `None`.
      name: ...
        Default value: `None` (i.e., `'Convolution'`).
    """
        filter_shape = prepare_tuple_argument(filter_shape,
                                              rank,
                                              arg_name='filter_shape')
        batch_shape = (np.array([], dtype=np.int32) if batch_shape is None else
                       prefer_static.reshape(batch_shape, shape=[-1]))
        batch_ndims = prefer_static.size(batch_shape)
        if tf.get_static_value(batch_ndims) == 0:
            # In this branch, we statically know there are no batch dims.
            kernel_shape = filter_shape + (input_size, output_size)
            bias_shape = [output_size]
            apply_kernel_fn = _make_convolution_fn(rank, strides, padding,
                                                   dilations)
        else:
            # In this branch, there are either static/dynamic batch dims or
            # dynamically no batch dims.
            kernel_shape = prefer_static.concat(
                [batch_shape, filter_shape, [input_size, output_size]], axis=0)
            bias_shape = prefer_static.concat([batch_shape, [output_size]],
                                              axis=0)
            apply_kernel_fn = lambda x, k: convolution_batch(  # pylint: disable=g-long-lambda
                x,
                k,
                rank=rank,
                strides=strides,
                padding=padding,
                data_format='NHWBC',
                dilations=dilations)
        kernel, bias = make_kernel_bias_fn(kernel_shape, bias_shape,
                                           init_kernel_fn, init_bias_fn,
                                           batch_ndims, batch_ndims, dtype)
        self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
        super(Convolution, self).__init__(kernel=kernel,
                                          bias=bias,
                                          apply_kernel_fn=apply_kernel_fn,
                                          dtype=dtype,
                                          activation_fn=activation_fn,
                                          name=name)
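
# A quick sketch of the parameter shapes this constructor builds, using a
# hypothetical helper (not part of the layer). With filter_shape=(3, 3),
# input_size=3, output_size=8 and batch_shape=(5,) it returns
# ((5, 3, 3, 3, 8), (5, 8)); with batch_shape=() the batch dims drop out,
# giving ((3, 3, 3, 8), (8,)).
def conv_param_shapes_sketch(filter_shape, input_size, output_size,
                             batch_shape=()):
  batch_shape = tuple(batch_shape)
  kernel_shape = batch_shape + tuple(filter_shape) + (input_size, output_size)
  bias_shape = batch_shape + (output_size,)
  return kernel_shape, bias_shape
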
 def reshape_sample_shape(t):
     batch_event_shape = ps.shape(t)[1:]
     final_shape = ps.concat([sample_shape, batch_event_shape], 0)
     return tf.reshape(t, final_shape)
    def _sample_n(self, n, seed=None, conditional_input=None, training=False):
        """Samples from the distribution, with optional conditional input.
        Args:
          n: `int`, number of samples desired.
          seed: `int`, seed for RNG. Setting a random seed enforces reproducibility
            of the samples between sessions (not within a single session).
          conditional_input: `Tensor` on which to condition the distribution (e.g.
            class labels), or `None`.
          training: `bool` or `None`. If `bool`, it controls the dropout layer,
            where `True` implies dropout is active. If `None`, it defers to Keras'
            handling of train/eval status.
        Returns:
          samples: a `Tensor` of shape `[n, height, width, num_channels]`.
        """
        if conditional_input is not None:
            conditional_input = tf.convert_to_tensor(conditional_input,
                                                     dtype=self.dtype)
            conditional_event_rank = tensorshape_util.rank(
                self.conditional_shape)
            conditional_input_shape = prefer_static.shape(conditional_input)
            conditional_sample_rank = prefer_static.rank(
                conditional_input) - conditional_event_rank

            # If `conditional_input` has no sample dimensions, prepend a sample
            # dimension
            if conditional_sample_rank == 0:
                conditional_input = conditional_input[tf.newaxis, ...]
                conditional_sample_rank = 1

            # Assert that the conditional event shape in the `PixelCnnNetwork` is the
            # same as that implied by `conditional_input`.
            conditional_event_shape = conditional_input_shape[
                conditional_sample_rank:]
            with tf.control_dependencies([
                    tf.assert_equal(self.conditional_shape,
                                    conditional_event_shape)
            ]):
                conditional_sample_shape = conditional_input_shape[
                    :conditional_sample_rank]
                repeat = n // prefer_static.reduce_prod(
                    conditional_sample_shape)
                h = tf.reshape(
                    conditional_input,
                    prefer_static.concat([(-1, ), self.conditional_shape],
                                         axis=0))
                h = tf.tile(
                    h,
                    prefer_static.pad([repeat],
                                      paddings=[[0, conditional_event_rank]],
                                      constant_values=1))

        samples_0 = tf.random.uniform(
            prefer_static.concat([(n,), self.event_shape], axis=0),
            minval=-1., maxval=1., dtype=self.dtype, seed=seed)
        inputs = samples_0 if conditional_input is None else [samples_0, h]
        params_0 = self.network(inputs, training=training)
        samples_0 = self._sample_channels(*params_0, seed=seed)

        image_height, image_width, _ = tensorshape_util.as_list(
            self.event_shape)

        def loop_body(index, samples):
            """Loop for iterative pixel sampling.
            Args:
            index: 0D `Tensor` of type `int32`. Index of the current pixel.
            samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
              pixel `[index]`, with dimensions `[batch_size, height, width,
              num_channels]`.
            Returns:
            samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
              and including pixel `[index]`, with dimensions `[batch_size, height,
              width, num_channels]`.
            """
            inputs = samples if conditional_input is None else [samples, h]
            params = self.network(inputs, training=training)
            samples_new = self._sample_channels(*params, seed=seed)

            # Update the current pixel
            samples = tf.transpose(samples, [1, 2, 3, 0])
            samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
            row, col = index // image_width, index % image_width
            updates = samples_new[row, col, ...][tf.newaxis, ...]
            samples = tf.tensor_scatter_nd_update(samples, [[row, col]],
                                                  updates)
            samples = tf.transpose(samples, [3, 0, 1, 2])

            return index + 1, samples

        index0 = tf.zeros([], dtype=tf.int32)

        # Construct the while loop for sampling
        total_pixels = image_height * image_width
        loop_cond = lambda ind, _: tf.less(ind, total_pixels)  # noqa: E731
        init_vars = (index0, samples_0)
        _, samples = tf.while_loop(loop_cond,
                                   loop_body,
                                   init_vars,
                                   parallel_iterations=1)

        transformed_samples = (self._low + 0.5 * (self._high - self._low) *
                               (samples + 1.))
        return tf.round(transformed_samples)
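
# A small NumPy sketch of the pixel update inside `loop_body` above: copy only
# the pixel at raster position `index` from the freshly sampled images into
# the running samples, leaving every other pixel unchanged. The helper name is
# illustrative only.
import numpy as np

def update_pixel_sketch(samples, samples_new, index):
  # samples, samples_new :: [batch, height, width, channels]
  width = samples.shape[2]
  row, col = index // width, index % width
  out = samples.copy()
  out[:, row, col, :] = samples_new[:, row, col, :]
  return out
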
 def _event_shape_tensor(self):
   return prefer_static.concat([
       self.sample_shape,
       self.distribution.event_shape_tensor(),
   ], axis=0)
 def _sample_n(self, n, seed=None):
   logits = self._logits_parameter_no_checks()
   new_shape = ps.concat([[n], ps.shape(logits)], axis=0)
   uniform = samplers.uniform(new_shape, seed=seed, dtype=logits.dtype)
   sample = self._quantile(uniform, logits)
   return tf.cast(sample, self.dtype)
  def _variance(self):
    num_states = self.transition_distribution.batch_shape_tensor()[-1]
    batch_shape = self.batch_shape_tensor()
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    observation_distribution = self.observation_distribution
    means = observation_distribution.mean()
    # means :: observation_batch_shape[:-2] num_steps num_states
    #          observation_event_shape
    # or
    # means :: observation_batch_shape[:-1] num_states
    #                                       observation_event_shape
    # The latter case happens for static observation distributions, and we need
    # to add in a steps dimension.
    if not self._time_varying_observation_distribution:
      means = tf.expand_dims(means, ps.rank(batch_shape) - 1)
    means_shape = ps.concat(
        [batch_shape,
         [self._num_steps, num_states],
         observation_distribution.event_shape_tensor()],
        axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_steps num_states observation_event_shape
    observation_event_shape = (
        observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(batch_shape)
    flat_probs_shape = [self._num_steps, batch_size, num_states]
    flat_means_shape = [
        batch_size, self._num_steps, num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size num_steps num_states observation_event_size
    flat_mean = tf.einsum('ijk,jikl->jil', flat_probs, flat_means)
    flat_mean = tf.expand_dims(flat_mean, 2)
    # flat_mean :: batch_size num_steps 1 observation_event_size

    variances = observation_distribution.variance()
    if not self._time_varying_observation_distribution:
      variances = tf.expand_dims(variances, tf.rank(batch_shape) - 1)
    variances = tf.broadcast_to(variances, means_shape)
    # variances :: batch_shape num_steps num_states observation_event_shape
    flat_variances = tf.reshape(variances, flat_means_shape)
    # flat_variances :: batch_size num_steps num_states observation_event_size

    # For a mixture of n distributions with mixture probabilities
    # p[i], and where the individual distributions have means and
    # variances given by mean[i] and var[i], the variance of
    # the mixture is given by:
    #
    # var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])

    flat_variance = tf.einsum('ijk,jikl->jil',
                              flat_probs,
                              (flat_means - flat_mean)**2 + flat_variances)
    # flat_variance :: batch_size num_steps observation_event_size
    unflat_mean_shape = ps.concat(
        [batch_shape,
         [self._num_steps],
         observation_event_shape],
        axis=0)

    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_variance, unflat_mean_shape)
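
# A shape-level NumPy sketch of the einsum contraction used above:
# 'ijk,jikl->jil' sums over the shared num_states axis, pairing
# flat_probs :: [num_steps, batch_size, num_states] with
# flat_means :: [batch_size, num_steps, num_states, event_size].
import numpy as np

num_steps, batch_size, num_states, event_size = 4, 2, 3, 5
flat_probs = np.random.rand(num_steps, batch_size, num_states)
flat_means = np.random.rand(batch_size, num_steps, num_states, event_size)
flat_mean = np.einsum('ijk,jikl->jil', flat_probs, flat_means)
assert flat_mean.shape == (batch_size, num_steps, event_size)
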
  def _log_prob(self, value):
    # The argument `value` is a tensor of sequences of observations.
    # `observation_batch_shape` is the shape of that tensor with the
    # sequence part removed.
    # `observation_batch_shape` is then broadcast to the full batch shape
    # to give the `batch_shape` that defines the shape of the result.
    observation_tensor_shape = ps.shape(value)
    observation_distribution = self.observation_distribution
    underlying_event_rank = ps.size(
        observation_distribution.event_shape_tensor())
    observation_batch_shape = observation_tensor_shape[
        :-1 - underlying_event_rank]
    # value :: observation_batch_shape num_steps observation_event_shape
    batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                             self.batch_shape_tensor())
    num_states = self.transition_distribution.batch_shape_tensor()[-1]
    log_init = _extract_log_probs(num_states,
                                  self.initial_distribution)
    # log_init :: batch_shape num_states
    log_init = tf.broadcast_to(log_init,
                               ps.concat([batch_shape,
                                          [num_states]], axis=0))
    log_transition = _extract_log_probs(num_states,
                                        self.transition_distribution)

    # `observation_event_shape` is the shape of each sequence of observations
    # emitted by the model.
    observation_event_shape = observation_tensor_shape[
        -1 - underlying_event_rank:]
    working_obs = tf.broadcast_to(value,
                                  ps.concat([batch_shape,
                                             observation_event_shape],
                                            axis=0))
    # working_obs :: batch_shape observation_event_shape
    r = underlying_event_rank

    # Move index into sequence of observations to front so we can apply
    # tf.foldl
    if self._time_varying_observation_distribution:
      working_obs = tf.expand_dims(working_obs, -1 - r)
      # working_obs :: batch_shape num_steps 1 underlying_event_shape
      observation_probs = observation_distribution.log_prob(working_obs)
      # observation_probs :: batch_shape num_steps num_states
      observation_probs = distribution_util.move_dimension(
          observation_probs, -2, 0)
      # observation_probs :: num_steps batch_shape num_states
    else:
      working_obs = distribution_util.move_dimension(working_obs, -1 - r, 0)
      # working_obs :: num_steps batch_shape underlying_event_shape
      working_obs = tf.expand_dims(working_obs, -1 - r)
      # working_obs :: num_steps batch_shape 1 underlying_event_shape

      observation_probs = observation_distribution.log_prob(working_obs)
      # observation_probs :: num_steps batch_shape num_states

    def forward_step(log_prev_step, log_prob_observation):
      return _log_vector_matrix(log_prev_step,
                                log_transition) + log_prob_observation

    # TODO(davmre): Delete this warning after Dec 31, 2020.
    warnings.warn(
        'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
        'in which the transition model was applied prior to the initial step. '
        'This bug has been fixed. You may observe a slight change in behavior.')
    fwd_prob = tf.foldl(forward_step, observation_probs[1:],
                        initializer=log_init + observation_probs[0])
    # fwd_prob :: batch_shape num_states

    log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
    # log_prob :: batch_shape

    return log_prob
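
# A minimal NumPy sketch of the same forward pass in log space for a single
# (unbatched) HMM: alpha = log_init + log_obs[0], then alternately "log-matvec"
# with the transition matrix and add the next observation log-likelihoods.
# Helper names are illustrative only.
import numpy as np

def hmm_log_prob_sketch(log_init, log_transition, observation_log_probs):
  # log_init :: [num_states], log_transition :: [num_states, num_states]
  # observation_log_probs :: [num_steps, num_states]
  def logsumexp(x, axis):
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))
  alpha = log_init + observation_log_probs[0]
  for log_obs in observation_log_probs[1:]:
    alpha = logsumexp(alpha[:, np.newaxis] + log_transition, axis=0) + log_obs
  return logsumexp(alpha, axis=-1)
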
  def _sample_n(self, n, seed=None):
    init_seed, scan_seed, observation_seed = samplers.split_seed(
        seed, n=3, salt='HiddenMarkovModel')

    transition_batch_shape = self.transition_distribution.batch_shape_tensor()
    num_states = transition_batch_shape[-1]

    batch_shape = self.batch_shape_tensor()
    batch_size = ps.reduce_prod(batch_shape)
    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        ps.reduce_prod(batch_shape) //
        ps.reduce_prod(self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=init_seed)
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    transition_repeat = (
        ps.reduce_prod(batch_shape) // ps.reduce_prod(
            transition_batch_shape[:-1]))

    init_shape = init_state.shape

    def generate_step(state_and_seed, _):
      """Take a single step in Markov chain."""
      state, seed = state_and_seed
      sample_seed, next_seed = samplers.split_seed(seed)

      gen = self._transition_distribution.sample(n * transition_repeat,
                                                 seed=sample_seed)
      # gen :: (n * transition_repeat) transition_batch

      new_states = tf.reshape(gen,
                              [n, batch_size, num_states])

      # new_states :: n batch_size num_states

      old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

      # old_states_one_hot :: n batch_size num_states

      result = tf.reduce_sum(old_states_one_hot * new_states, axis=-1)
      # We know that `generate_step` must preserve the shape of the
      # tensor of states because the transition matrix must be square.
      # But TensorFlow might not know this, so we explicitly tell it that
      # the result has the same shape.
      tensorshape_util.set_shape(result, init_shape)
      return result, next_seed

    def _scan_multiple_steps():
      """Take multiple steps with tf.scan."""
      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
      hidden_states, _ = tf.scan(generate_step, dummy_index,
                                 initializer=(init_state, scan_seed))

      # TODO(b/115618503): add/use prepend_initializer to tf.scan
      return tf.concat([[init_state],
                        hidden_states], axis=0)
    hidden_states = ps.cond(
        self._num_steps > 1,
        _scan_multiple_steps,
        lambda: init_state[tf.newaxis, ...])

    hidden_one_hot = tf.one_hot(hidden_states, num_states,
                                dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size so as with the initial and
    # transition distributions we generate more samples and
    # reshape.
    observation_repeat = tf.maximum(
        batch_size // ps.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]),
        1)

    if self._time_varying_observation_distribution:
      possible_observations = self._observation_distribution.sample(
          [observation_repeat * n], seed=observation_seed)
      # `possible_observations` needs to have `num_steps` moved to the beginning.
      possible_observations = distribution_util.move_dimension(
          possible_observations,
          -(tf.size(self._observation_distribution.event_shape_tensor()) + 2),
          0)
    else:
      possible_observations = self._observation_distribution.sample(
          [self._num_steps, observation_repeat * n], seed=observation_seed)

    inner_shape = self._observation_distribution.event_shape_tensor()

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        ps.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   inner_shape], axis=0))

    # possible_observations :: steps n batch_size num_states inner_shape

    hidden_one_hot = tf.reshape(hidden_one_hot,
                                ps.concat([[self._num_steps, n],
                                           batch_shape,
                                           [num_states],
                                           ps.ones_like(inner_shape)],
                                          axis=0))

    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    observations = tf.reduce_sum(
        hidden_one_hot * possible_observations,
        axis=-1 - ps.size(inner_shape))
    # observations :: steps n batch_size inner_shape

    observations = distribution_util.move_dimension(observations, 0,
                                                    1 + ps.size(batch_shape))
    # returned :: n batch_shape steps inner_shape

    return observations
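
# A small NumPy sketch of the hidden-state generation above for a single
# (unbatched) chain: draw the initial state, then repeatedly sample the next
# state from the transition row selected by the current state. The helper
# name is illustrative only.
import numpy as np

def sample_hidden_states_sketch(init_probs, transition_probs, num_steps,
                                rng=None):
  rng = np.random.default_rng() if rng is None else rng
  num_states = len(init_probs)
  states = [rng.choice(num_states, p=init_probs)]
  for _ in range(num_steps - 1):
    states.append(rng.choice(num_states, p=transition_probs[states[-1]]))
  return np.array(states)
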
 def _event_shape_tensor(self):
   return ps.concat([[self._num_steps],
                     self.observation_distribution.event_shape_tensor()],
                    axis=0)
 def _reshape_part(part, dtype, event_shape):
     part = tf.cast(part, dtype)
     new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
def resample_minimum_variance(log_probs,
                              event_size,
                              sample_shape,
                              seed=None,
                              name=None):
    """Minimum variance resampler for sequential Monte Carlo.

  This function is based on Algorithm #2 in [Maskell et al. (2006)][1].

  Args:
    log_probs: A tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws.
    seed: Python `int` used to seed calls to `tf.random.*`.
      Default value: None (i.e. no seed).
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_minimum_variance'`).

  Returns:
    resampled_indices: The result is similar to sampling with
    ```python
    expanded_sample_shape = tf.concat([[event_size], sample_shape], axis=-1)
    tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
    ```
    but with values sorted along the first axis. It can be considered to be
    sampling events made up of a length-`event_size` vector of draws from
    the `Categorical` distribution. However, although the elements of
    this event have the appropriate marginal distribution, they are not
    independent of each other. Instead they have been chosen so as to form
    a good representative sample, suitable for use with Sequential Monte
    Carlo algorithms.
    The sortedness is an unintended side effect of the algorithm that is
    harmless in the context of simple SMC algorithms.

  #### References
  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  """
    with tf.name_scope(name or 'resample_minimum_variance') as name:
        log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
        log_probs = dist_util.move_dimension(log_probs,
                                             source_idx=0,
                                             dest_idx=-1)

        batch_shape = prefer_static.shape(log_probs)[:-1]
        working_shape = prefer_static.concat([sample_shape, batch_shape],
                                             axis=-1)
        log_cdf = tf.math.cumulative_logsumexp(log_probs[..., :-1], axis=-1)
        # Each resampling requires a single uniform random variable
        offset = uniform.Uniform(low=tf.constant(0., log_cdf.dtype),
                                 high=tf.constant(1., log_cdf.dtype)).sample(
                                     working_shape, seed=seed)[..., tf.newaxis]
        # It is possible for numerical error to result in a cumulative
        # sum that exceeds 1 so we need to clip.
        markers = prefer_static.cast(
            tf.floor(event_size * tf.math.exp(log_cdf) + offset), tf.int32)
        indices = markers[..., tf.newaxis]
        updates = tf.ones(prefer_static.shape(indices)[:-1], dtype=tf.int32)
        scatter_shape = prefer_static.concat([working_shape, [event_size + 1]],
                                             axis=-1)
        batch_dims = (prefer_static.rank_from_shape(sample_shape) +
                      prefer_static.rank_from_shape(batch_shape))
        x = _scatter_nd_batch(indices,
                              updates,
                              scatter_shape,
                              batch_dims=batch_dims)

        resampled = tf.cumsum(x, axis=-1)[..., :-1]
        resampled = dist_util.move_dimension(resampled,
                                             source_idx=-1,
                                             dest_idx=0)
        return resampled
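
# A minimal NumPy sketch of the same low-variance idea in its more familiar
# "systematic resampling" form: place `event_size` evenly spaced points,
# jittered by one shared uniform, on the CDF of the weights and read off the
# selected indices. This is closely related to, but not line-for-line the
# same as, the scatter/cumsum formulation above. The helper name is
# illustrative only.
import numpy as np

def systematic_resample_sketch(log_probs, event_size, rng=None):
  rng = np.random.default_rng() if rng is None else rng
  probs = np.exp(log_probs - np.max(log_probs))
  probs /= probs.sum()
  positions = (np.arange(event_size) + rng.uniform()) / event_size
  indices = np.searchsorted(np.cumsum(probs), positions)
  return np.minimum(indices, len(probs) - 1)  # Guard against rounding past 1.
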
def _potential_scale_reduction_single_state(state, independent_chain_ndims,
                                            split_chains, validate_args):
    """potential_scale_reduction for one single state `Tensor`."""
    with tf.name_scope('potential_scale_reduction_single_state'):
        # We assume exactly one leading dimension indexes e.g. correlated samples
        # from each Markov chain.
        state = tf.convert_to_tensor(state, name='state')

        n_samples_ = tf.compat.dimension_value(state.shape[0])
        if n_samples_ is not None:  # If available statically.
            if split_chains and n_samples_ < 4:
                raise ValueError(
                    'Must provide at least 4 samples when splitting chains. '
                    'Found {}'.format(n_samples_))
            if not split_chains and n_samples_ < 2:
                raise ValueError(
                    'Must provide at least 2 samples.  Found {}'.format(
                        n_samples_))
        elif validate_args:
            if split_chains:
                assertions = [
                    assert_util.assert_greater(
                        tf.shape(state)[0],
                        4,
                        message=
                        'Must provide at least 4 samples when splitting chains.'
                    )
                ]
                with tf.control_dependencies(assertions):
                    state = tf.identity(state)
            else:
                assertions = [
                    assert_util.assert_greater(
                        tf.shape(state)[0],
                        2,
                        message='Must provide at least 2 samples.')
                ]
                with tf.control_dependencies(assertions):
                    state = tf.identity(state)

        # Define so it's not a magic number.
        # Warning!  `if split_chains` logic assumes this is 1!
        sample_ndims = 1

        if split_chains:
            # Split the sample dimension in half, doubling the number of
            # independent chains.

            # For odd number of samples, keep all but the last sample.
            state_shape = prefer_static.shape(state)
            n_samples = state_shape[0]
            state = state[:n_samples - n_samples % 2]

            # Suppose state = [0, 1, 2, 3, 4, 5]
            # Step 1: reshape into [[0, 1, 2], [3, 4, 5]]
            # E.g. reshape states of shape [a, b] into [2, a//2, b].
            state = tf.reshape(
                state,
                prefer_static.concat([[2, n_samples // 2], state_shape[1:]],
                                     axis=0))
            # Step 2: Put the size `2` dimension in the right place to be treated as a
            # chain, changing [[0, 1, 2], [3, 4, 5]] into [[0, 3], [1, 4], [2, 5]],
            # i.e., transposing shape [2, a//2, b] into [a//2, 2, b].
            state = tf.transpose(
                a=state,
                perm=prefer_static.concat(
                    [[1, 0], tf.range(2, tf.rank(state))], axis=0))

            # We're treating the new dim as indexing 2 chains, so increment.
            independent_chain_ndims += 1

        sample_axis = tf.range(0, sample_ndims)
        chain_axis = tf.range(sample_ndims,
                              sample_ndims + independent_chain_ndims)
        sample_and_chain_axis = tf.range(
            0, sample_ndims + independent_chain_ndims)

        n = _axis_size(state, sample_axis)
        m = _axis_size(state, chain_axis)

        # In the language of Brooks and Gelman (1998),
        # B / n is the between chain variance, the variance of the chain means.
        # W is the within sequence variance, the mean of the chain variances.
        b_div_n = _reduce_variance(tf.reduce_mean(state,
                                                  axis=sample_axis,
                                                  keepdims=True),
                                   sample_and_chain_axis,
                                   biased=False)
        w = tf.reduce_mean(_reduce_variance(state,
                                            sample_axis,
                                            keepdims=True,
                                            biased=True),
                           axis=sample_and_chain_axis)

        # sigma^2_+ is an estimate of the true variance, which would be unbiased if
        # each chain was drawn from the target.  c.f. "law of total variance."
        sigma_2_plus = w + b_div_n

        return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
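
# A compact NumPy cross-check of the same potential scale reduction
# computation for a state of shape [n_samples, n_chains], without the
# split-chains handling. The helper name is illustrative only.
import numpy as np

def potential_scale_reduction_sketch(state):
  n, m = state.shape
  chain_means = state.mean(axis=0)
  b_div_n = chain_means.var(ddof=1)     # Between-chain variance of the means.
  w = state.var(axis=0, ddof=0).mean()  # Mean within-chain variance.
  sigma_2_plus = w + b_div_n
  return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
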
  def _observation_log_probs(self, observations, mask):
    """Compute and shape tensor of log probs associated with observations.."""

    # Let E be the underlying event shape
    #     M the number of steps in the HMM
    #     N the number of states of the HMM
    #
    # Then the incoming observations have shape
    #
    # observations : batch_o [M] E
    #
    # and the mask (if present) has shape
    #
    # mask : batch_m [M]
    #
    # Let this HMM distribution have batch shape batch_d
    # We need to broadcast all three of these batch shapes together
    # into the shape batch.
    #
    # We need to move the step dimension to the first dimension to make
    # them suitable for folding or scanning over.
    #
    # When we call `log_prob` for our observations we need to
    # do this for each state the observation could correspond to.
    # We do this by expanding the dimensions by 1 so we end up with:
    #
    # observations : [M] batch [1] [E]
    #
    # After calling `log_prob` we get
    #
    # observation_log_probs : [M] batch [N]
    #
    # We wish to use `mask` to select from this so we also
    # reshape and broadcast it up to shape
    #
    # mask : [M] batch [N]

    observation_distribution = self.observation_distribution
    underlying_event_rank = ps.size(
        observation_distribution.event_shape_tensor())
    observation_tensor_shape = ps.shape(observations)
    observation_batch_shape = observation_tensor_shape[
        :-1 - underlying_event_rank]
    observation_event_shape = observation_tensor_shape[
        -1 - underlying_event_rank:]

    if mask is not None:
      mask_tensor_shape = ps.shape(mask)
      mask_batch_shape = mask_tensor_shape[:-1]

    batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                             self.batch_shape_tensor())

    if mask is not None:
      batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                               mask_batch_shape)
    observations = tf.broadcast_to(observations,
                                   ps.concat([batch_shape,
                                              observation_event_shape],
                                             axis=0))
    observation_rank = ps.rank(observations)
    observations = distribution_util.move_dimension(
        observations, observation_rank - underlying_event_rank - 1, 0)
    observations = tf.expand_dims(
        observations,
        observation_rank - underlying_event_rank)
    observation_log_probs = observation_distribution.log_prob(
        observations)

    if mask is not None:
      mask = tf.broadcast_to(mask,
                             ps.concat([batch_shape, [self._num_steps]],
                                       axis=0))
      mask = distribution_util.move_dimension(mask, -1, 0)
      observation_log_probs = tf.where(mask[..., tf.newaxis],
                                       tf.zeros_like(observation_log_probs),
                                       observation_log_probs)

    return observation_log_probs
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
    """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
  slightly biased.

  Args:
    x:  A numeric `Tensor` holding samples.
    y:  Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axes hold samples).
      Default value: `0` (leftmost dimension).
    event_axis:  Scalar or vector `Tensor`, or `None` (scalar events).
      Axis indexing random events, whose covariance we are interested in.
      If a vector, entries must form a contiguous block of dims. `sample_axis`
      and `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims:  Boolean.  Whether to keep the sample axis as singletons.
    name: Python `str` name prefixed to Ops created by this function.
          Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of the same `dtype` as `x`, and rank equal to
      `rank(x) - len(sample_axis) + len(event_axis)` (the sample axes are
      retained as singletons if `keepdims=True`).

  Raises:
    AssertionError:  If `x` and `y` are found to have different shape.
    ValueError:  If `sample_axis` and `event_axis` are found to overlap.
    ValueError:  If `event_axis` is found to not be contiguous.
  """

    with tf.name_scope(name or 'covariance'):
        x = tf.convert_to_tensor(x, name='x')
        # Covariance *only* uses the centered versions of x (and y).
        x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

        if y is None:
            y = x
        else:
            y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
            # If x and y have different shape, sample_axis and event_axis will likely
            # be wrong for one of them!
            tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
            y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

        if event_axis is None:
            return tf.reduce_mean(x * tf.math.conj(y),
                                  axis=sample_axis,
                                  keepdims=keepdims)

        if sample_axis is None:
            raise ValueError(
                'sample_axis was None, which means all axis hold events, and this '
                'overlaps with event_axis ({})'.format(event_axis))

        event_axis = _make_positive_axis(event_axis, ps.rank(x))
        sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

        # If we get lucky and axis is statically defined, we can do some checks.
        if _is_list_like(event_axis) and _is_list_like(sample_axis):
            event_axis = tuple(map(int, event_axis))
            sample_axis = tuple(map(int, sample_axis))
            if set(event_axis).intersection(sample_axis):
                raise ValueError(
                    'sample_axis ({}) and event_axis ({}) overlapped'.format(
                        sample_axis, event_axis))
            if (np.diff(np.array(sorted(event_axis))) > 1).any():
                raise ValueError(
                    'event_axis must be contiguous. Found: {}'.format(
                        event_axis))
            batch_axis = list(
                sorted(
                    set(range(tensorshape_util.rank(
                        x.shape))).difference(sample_axis + event_axis)))
        else:
            batch_axis = ps.setdiff1d(ps.range(0, ps.rank(x)),
                                      ps.concat((sample_axis, event_axis), 0))

        event_axis = ps.cast(event_axis, dtype=tf.int32)
        sample_axis = ps.cast(sample_axis, dtype=tf.int32)
        batch_axis = ps.cast(batch_axis, dtype=tf.int32)

        # Permute x/y until shape = B + E + S
        perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
        x_permed = tf.transpose(a=x, perm=perm_for_xy)
        y_permed = tf.transpose(a=y, perm=perm_for_xy)

        batch_ndims = ps.size(batch_axis)
        batch_shape = ps.shape(x_permed)[:batch_ndims]
        event_ndims = ps.size(event_axis)
        event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
        sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
        sample_ndims = ps.size(sample_shape)
        n_samples = ps.reduce_prod(sample_shape)
        n_events = ps.reduce_prod(event_shape)

        # Flatten the sample axes into one long dim and the event axes into
        # another, giving shape batch_shape + [n_events, n_samples].
        x_permed_flat = tf.reshape(
            x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))
        y_permed_flat = tf.reshape(
            y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))

        # After matmul, cov.shape = batch_shape + [n_events, n_events]
        cov = tf.matmul(x_permed_flat, y_permed_flat,
                        adjoint_b=True) / ps.cast(n_samples, x.dtype)

        # Insert some singletons to make
        # cov.shape = batch_shape + event_shape**2 + [1,...,1]
        # This is just like x_permed.shape, except the sample_axis is all 1's, and
        # the [n_events] became event_shape**2.
        cov = tf.reshape(
            cov,
            ps.concat(
                (
                    batch_shape,
                    # event_shape**2 used here because it is the same length as
                    # event_shape, and has the same number of elements as one
                    # batch of covariance.
                    event_shape**2,
                    ps.ones([sample_ndims], tf.int32)),
                0))
        # Permuting by the argsort inverts the permutation, making
        # cov.shape have ones in the position where there were samples, and
        # [n_events * n_events] in the event position.
        cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

        # Now expand event_shape**2 into event_shape + event_shape.
        # We here use (for the first time) the fact that we require event_axis to be
        # contiguous.
        e_start = event_axis[0]
        e_len = 1 + event_axis[-1] - event_axis[0]
        cov = tf.reshape(
            cov,
            ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                       ps.shape(cov)[e_start + e_len:]), 0))

        # tf.squeeze requires python ints for axis, not Tensor.  This is enough to
        # require our axis args to be constants.
        if not keepdims:
            squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                                    sample_axis + e_len)
            cov = _squeeze(cov, axis=squeeze_axis)

        return cov
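
# A quick NumPy cross-check of the estimator above in the simplest case,
# sample_axis=0 and event_axis=-1 with a 2-D input [n_samples, n_events]
# (dividing by N, as documented above):
import numpy as np

x = np.random.randn(100, 3)
xc = x - x.mean(axis=0, keepdims=True)
cov = xc.T @ xc / x.shape[0]
np.testing.assert_allclose(cov, np.cov(x, rowvar=False, bias=True))
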
  def __init__(
      self,
      input_size,
      output_size,          # keras::Conv::filters
      # Conv specific.
      filter_shape,         # keras::Conv::kernel_size
      rank=2,               # keras::Conv::rank
      strides=1,            # keras::Conv::strides
      padding='VALID',      # keras::Conv::padding; 'CAUSAL' not implemented.
                            # keras::Conv::data_format is not implemented
      dilations=1,          # keras::Conv::dilation_rate
      # Weights
      kernel_initializer=None,  # tfp.nn.initializers.glorot_uniform()
      bias_initializer=None,    # tf.initializers.zeros()
      make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias,
      dtype=tf.float32,
      index_dtype=tf.int32,
      batch_shape=(),
      # Misc
      activation_fn=None,
      validate_args=False,
      name=None):
    """Constructs layer.

    Note: `data_format` is not supported since all nn layers operate on
    the rightmost column. If your channel dimension is not rightmost, use
    `tf.transpose` before calling this layer. For example, if your channel
    dimension is second from the left, the following code will move it
    rightmost:

    ```python
    inputs = tf.transpose(inputs, tf.concat([
        [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
    ```

    Args:
      input_size: ...
        In Keras, this argument is inferred from the rightmost input shape,
        i.e., `tf.shape(inputs)[-1]`. This argument specifies the size of the
        second from the rightmost dimension of both `inputs` and `kernel`.
        Default value: `None`.
      output_size: ...
        In Keras, this argument is called `filters`. This argument specifies the
        rightmost dimension size of both `kernel` and `bias`.
      filter_shape: ...
        In Keras, this argument is called `kernel_size`. This argument specifies
        the leftmost `rank` dimensions' sizes of `kernel`.
      rank: An integer, the rank of the convolution, e.g. "2" for 2D
        convolution. This argument implies the number of `kernel` dimensions,
        i.e., `kernel.shape.rank == rank + 2`.
        In Keras, this argument has the same name and semantics.
        Default value: `2`.
      strides: An integer or tuple/list of n integers, specifying the stride
        length of the convolution.
        In Keras, this argument has the same name and semantics.
        Default value: `1`.
      padding: One of `"VALID"` or `"SAME"` (case-insensitive).
        In Keras, this argument has the same name and semantics (except we don't
        support `"CAUSAL"`).
        Default value: `'VALID'`.
      dilations: An integer or tuple/list of `rank` integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying any
        `dilations` value != 1 is incompatible with specifying any `strides`
        value != 1.
        In Keras, this argument is called `dilation_rate`.
        Default value: `1`.
      kernel_initializer: ...
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      bias_initializer: ...
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_kernel_bias_fn: ...
        Default value: `tfp.experimental.nn.util.make_kernel_bias`.
      dtype: ...
        Default value: `tf.float32`.
      index_dtype: ...
      batch_shape: ...
        Default value: `()`.
      activation_fn: ...
        Default value: `None`.
      validate_args: ...
      name: ...
        Default value: `None` (i.e., `'ConvolutionV2'`).
    """
    filter_shape = convolution_util.prepare_tuple_argument(
        filter_shape, rank, arg_name='filter_shape',
        validate_args=validate_args)
    batch_shape = (tf.constant([], dtype=tf.int32) if batch_shape is None
                   else ps.cast(ps.reshape(batch_shape, shape=[-1]), tf.int32))
    batch_ndims = ps.size(batch_shape)

    apply_kernel_fn = convolution_util.make_convolution_fn(
        filter_shape, rank=2, strides=strides, padding=padding,
        dilations=dilations, dtype=index_dtype, validate_args=validate_args)

    kernel_shape = ps.concat(
        [batch_shape, [ps.reduce_prod(filter_shape) * input_size, output_size]],
        axis=0)
    bias_shape = ps.concat(
        [batch_shape, tf.ones(rank, dtype=tf.int32), [output_size]], axis=0)
    kernel, bias = make_kernel_bias_fn(
        kernel_shape, bias_shape,
        kernel_initializer, bias_initializer,
        batch_ndims, batch_ndims,
        dtype)

    self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
    super(ConvolutionV2, self).__init__(
        kernel=kernel,
        bias=bias,
        apply_kernel_fn=apply_kernel_fn,
        dtype=dtype,
        activation_fn=activation_fn,
        validate_args=validate_args,
        name=name)
    def _log_prob(self,
                  value,
                  conditional_input=None,
                  training=None,
                  return_per_feature=False):
        """Log probability function with optional conditional input.
        Calculates the log probability of a batch of data under the modeled
        distribution (or conditional distribution, if conditional input is
        provided).
        Args:
          value: `Tensor` or Numpy array of image data. May have leading batch
            dimension(s), which must broadcast to the leading batch dimensions of
            `conditional_input`.
          conditional_input: `Tensor` on which to condition the distribution (e.g.
            class labels), or `None`. May have leading batch dimension(s), which
            must broadcast to the leading batch dimensions of `value`.
          training: `bool` or `None`. If `bool`, it controls the dropout layer,
            where `True` implies dropout is active. If `None`, it defaults to
            `tf.keras.backend.learning_phase()`.
          return_per_feature: `bool`. If True, return per pixel level log prob.
        Returns:
          log_prob_values: `Tensor`.
        """
        # Determine the batch shape of the input images
        image_batch_shape = prefer_static.shape(value)[:-3]

        # Broadcast `value` and `conditional_input` to the same batch_shape
        if conditional_input is None:
            image_batch_and_conditional_shape = image_batch_shape
        else:
            conditional_input = tf.convert_to_tensor(conditional_input)
            conditional_input_shape = prefer_static.shape(conditional_input)
            conditional_batch_rank = (
                prefer_static.rank(conditional_input) -
                tensorshape_util.rank(self.conditional_shape))
            conditional_batch_shape = conditional_input_shape[:
                                                              conditional_batch_rank]

            image_batch_and_conditional_shape = prefer_static.broadcast_shape(
                image_batch_shape, conditional_batch_shape)
            conditional_input = tf.broadcast_to(
                conditional_input,
                prefer_static.concat([
                    image_batch_and_conditional_shape, self.conditional_shape
                ],
                                     axis=0))
            value = tf.broadcast_to(
                value,
                prefer_static.concat(
                    [image_batch_and_conditional_shape, self.event_shape],
                    axis=0))

            # Flatten batch dimension for input to Keras model
            conditional_input = tf.reshape(
                conditional_input,
                prefer_static.concat([(-1, ), self.conditional_shape], axis=0))

        value = tf.reshape(
            value, prefer_static.concat([(-1, ), self.event_shape], axis=0))

        transformed_value = (2. * (value - self._low) /
                             (self._high - self._low)) - 1.
        inputs = transformed_value if conditional_input is None else [
            transformed_value, conditional_input
        ]

        params = self.network(inputs, training=training)

        num_channels = self.event_shape[-1]
        if num_channels == 1:
            component_logits, locs, scales = params
        else:
            # If there is more than one channel, we create a linear autoregressive
            # dependency among the location parameters of the channels of a single
            # pixel (the scale parameters within a pixel are independent). For a pixel
            # with R/G/B channels, the `r`, `g`, and `b` saturation values are
            # distributed as:
            #
            # r ~ Logistic(loc_r, scale_r)
            # g ~ Logistic(coef_rg * r + loc_g, scale_g)
            # b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
            #
            # where each `coef_*` is a learned coefficient encoding the linear
            # dependence of that channel's location on the preceding channels.
            component_logits, locs, scales, coeffs = params
            num_coeffs = num_channels * (num_channels - 1) // 2
            loc_tensors = tf.split(locs, num_channels, axis=-1)
            coef_tensors = tf.split(coeffs, num_coeffs, axis=-1)
            channel_tensors = tf.split(value, num_channels, axis=-1)

            coef_count = 0
            for i in range(num_channels):
                channel_tensors[i] = channel_tensors[i][..., tf.newaxis, :]
                for j in range(i):
                    loc_tensors[
                        i] += channel_tensors[j] * coef_tensors[coef_count]
                    coef_count += 1
            locs = tf.concat(loc_tensors, axis=-1)

        dist = self._make_mixture_dist(component_logits,
                                       locs,
                                       scales,
                                       return_per_feature=return_per_feature)
        log_px = dist.log_prob(value)
        if return_per_feature:
            return log_px
        else:
            return tf.reshape(log_px, image_batch_and_conditional_shape)
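
The channel loop above implements the linear autoregressive shift of the
per-channel location parameters described in the comment. Spelled out for a
single RGB pixel and a single mixture component (all values hypothetical):

# r's location is unchanged; g and b are shifted by the observed values of the
# channels that precede them, mirroring
# `loc_tensors[i] += channel_tensors[j] * coef_tensors[coef_count]`.
r, g, b = 0.2, -0.4, 0.7                   # observed channel values
loc_r, loc_g, loc_b = 0.0, 0.1, -0.1       # network-predicted locations
coef_rg, coef_rb, coef_gb = 0.5, 0.3, 0.2  # network-predicted coefficients

shifted_loc_g = loc_g + coef_rg * r
shifted_loc_b = loc_b + coef_rb * r + coef_gb * g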
def _vec(x):
    # Vec takes in a (batch) of matrices of shape B1 + [n, k] and returns
    # a (batch) of vectors of shape B1 + [n * k].
    return tf.reshape(x, ps.concat([ps.shape(x)[:-2], [-1]], axis=0))
Beispiel #23
0
    def build(self, input_shape):
        dtype = self.dtype
        if len(input_shape) == 2:
            batch_image_shape, batch_conditional_shape = input_shape
            conditional_input = tf.keras.layers.Input(
                shape=batch_conditional_shape[1:], dtype=dtype)
        else:
            batch_image_shape = input_shape
            conditional_input = None

        image_shape = batch_image_shape[1:]
        image_input = tf.keras.layers.Input(shape=image_shape, dtype=dtype)

        if self._resnet_activation == 'concat_elu':
            activation = tf.keras.layers.Lambda(
                lambda x: tf.nn.elu(tf.concat([x, -x], axis=-1)), dtype=dtype)
        else:
            activation = tf.keras.activations.get(self._resnet_activation)

        # Define layers with default inputs and layer wrapper applied
        Conv2D = functools.partial(  # pylint:disable=invalid-name
            self._layer_wrapper(tf.keras.layers.Convolution2D),
            filters=self._num_filters,
            padding='same',
            kernel_regularizer=tf.keras.regularizers.l2(self._l2_weight),
            dtype=dtype)

        Dense = functools.partial(  # pylint:disable=invalid-name
            self._layer_wrapper(tf.keras.layers.Dense),
            kernel_regularizer=tf.keras.regularizers.l2(self._l2_weight),
            dtype=dtype)

        Conv2DTranspose = functools.partial(  # pylint:disable=invalid-name
            self._layer_wrapper(tf.keras.layers.Conv2DTranspose),
            filters=self._num_filters,
            padding='same',
            strides=(2, 2),
            kernel_regularizer=tf.keras.regularizers.l2(self._l2_weight),
            dtype=dtype)

        rows, cols = self._receptive_field_dims

        # Define the dimensions of the valid (unmasked) areas of the layer kernels
        # for stride 1 convolutions in the internal layers.
        kernel_valid_dims = {
            'vertical': (rows - 1, cols),  # vertical stack
            'horizontal': (2, cols // 2 + 1)  # horizontal stack
        }

        # Define the size of the kernel necessary to center the current pixel
        # correctly for stride 1 convolutions in the internal layers.
        kernel_sizes = {
            'vertical': (2 * rows - 3, cols),
            'horizontal': (3, cols)
        }
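        # For example (hypothetical sizes), with `receptive_field_dims = (3, 5)`
        # the valid regions are (2, 5) for the vertical stack and (2, 3) for the
        # horizontal stack, carved out of (3, 5) kernels by the constraints
        # built below.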

        # Make the kernel constraint functions for stride 1 convolutions in internal
        # layers.
        kernel_constraints = {
            k: _make_kernel_constraint(kernel_sizes[k], (0, v[0]), (0, v[1]))
            for k, v in kernel_valid_dims.items()
        }

        # Build the initial vertical stack/horizontal stack convolutional layers,
        # as shown in Figure 1 of [2]. The receptive field of the initial vertical
        # stack layer is a rectangular area centered above the current pixel.
        vertical_stack_init = Conv2D(kernel_size=(2 * rows - 1, cols),
                                     kernel_constraint=_make_kernel_constraint(
                                         (2 * rows - 1, cols), (0, rows - 1),
                                         (0, cols)))(image_input)

        # In Figure 1 [2], the receptive field of the horizontal stack is
        # illustrated as the pixels in the same row and to the left of the current
        # pixel. [1] increases the height of this receptive field from one pixel to
        # two (`horizontal_stack_left`) and additionally includes a subset of the
        # row of pixels centered above the current pixel (`horizontal_stack_up`).
        horizontal_stack_up = Conv2D(kernel_size=(3, cols),
                                     kernel_constraint=_make_kernel_constraint(
                                         (3, cols), (0, 1),
                                         (0, cols)))(image_input)

        horizontal_stack_left = Conv2D(
            kernel_size=(3, cols),
            kernel_constraint=_make_kernel_constraint(
                (3, cols), (0, 2), (0, cols // 2)))(image_input)

        horizontal_stack_init = tf.keras.layers.add(
            [horizontal_stack_up, horizontal_stack_left], dtype=dtype)

        layer_stacks = {
            'vertical': [vertical_stack_init],
            'horizontal': [horizontal_stack_init]
        }

        # Build the downward pass of the U-net (left-hand half of Figure 2 of [1]).
        # Each `i` iteration builds one of the highest-level blocks (identified as
        # 'Sequence of 6 layers' in the figure, consisting of `num_resnet=5` stride-
        # 1 layers, and one stride-2 layer that contracts the height/width
        # dimensions). The `_` iterations build the stride 1 layers. The layers of
        # the downward pass are stored in lists, since we'll later need them to make
        # skip-connections to layers in the upward pass of the U-net (the skip-
        # connections are represented by curved lines in Figure 2 [1]).
        for i in range(self._num_hierarchies):
            for _ in range(self._num_resnet):
                # Build a layer shown in Figure 2 of [2]. The 'vertical' iteration
                # builds the layers in the left half of the figure, and the 'horizontal'
                # iteration builds the layers in the right half.
                for stack in ['vertical', 'horizontal']:
                    input_x = layer_stacks[stack][-1]
                    x = activation(input_x)
                    x = Conv2D(kernel_size=kernel_sizes[stack],
                               kernel_constraint=kernel_constraints[stack])(x)

                    # Add the vertical-stack layer to the horizontal-stack layer
                    if stack == 'horizontal':
                        h = activation(layer_stacks['vertical'][-1])
                        h = Dense(self._num_filters)(h)
                        x = tf.keras.layers.add([h, x], dtype=dtype)

                    x = activation(x)
                    x = tf.keras.layers.Dropout(self._dropout_p,
                                                dtype=dtype)(x)
                    x = Conv2D(filters=2 * self._num_filters,
                               kernel_size=kernel_sizes[stack],
                               kernel_constraint=kernel_constraints[stack])(x)

                    if conditional_input is not None:
                        h_projection = _build_and_apply_h_projection(
                            conditional_input, self._num_filters, dtype=dtype)
                        x = tf.keras.layers.add([x, h_projection], dtype=dtype)

                    x = _apply_sigmoid_gating(x)

                    # Add a residual connection from the layer's input.
                    out = tf.keras.layers.add([input_x, x], dtype=dtype)
                    layer_stacks[stack].append(out)

            if i < self._num_hierarchies - 1:
                # Build convolutional layers that contract the height/width dimensions
                # on the downward pass between each set of layers (e.g. contracting from
                # 32x32 to 16x16 in Figure 2 of [1]).
                for stack in ['vertical', 'horizontal']:
                    # Define kernel dimensions/masking to maintain the
                    # autoregressive property.
                    x = layer_stacks[stack][-1]
                    h, w = kernel_valid_dims[stack]
                    kernel_height = 2 * h
                    if stack == 'vertical':
                        kernel_width = w + 1
                    else:
                        kernel_width = 2 * w

                    kernel_size = (kernel_height, kernel_width)
                    kernel_constraint = _make_kernel_constraint(
                        kernel_size, (0, h), (0, w))
                    x = Conv2D(strides=(2, 2),
                               kernel_size=kernel_size,
                               kernel_constraint=kernel_constraint)(x)
                    layer_stacks[stack].append(x)

        # Upward pass of the U-net (right-hand half of Figure 2 of [1]). We stored
        # the layers of the downward pass in a list, in order to access them to make
        # skip-connections to the upward pass. For the upward pass, we need to keep
        # track of only the current layer, so we maintain a reference to the
        # current layer of the horizontal/vertical stack in the `upward_pass` dict.
        # The upward pass begins with the last layer of the downward pass.
        upward_pass = {key: stack.pop() for key, stack in layer_stacks.items()}

        # As with the downward pass, each `i` iteration builds a highest level block
        # in Figure 2 [1], and the `_` iterations build individual layers within the
        # block.
        for i in range(self._num_hierarchies):
            num_resnet = self._num_resnet if i == 0 else self._num_resnet + 1

            for _ in range(num_resnet):
                # Build a layer as shown in Figure 2 of [2], with a skip-connection
                # from the symmetric layer in the downward pass.
                for stack in ['vertical', 'horizontal']:
                    input_x = upward_pass[stack]
                    x_symmetric = layer_stacks[stack].pop()

                    x = activation(input_x)
                    x = Conv2D(kernel_size=kernel_sizes[stack],
                               kernel_constraint=kernel_constraints[stack])(x)

                    # Include the vertical-stack layer of the upward pass in the layers
                    # to be added to the horizontal layer.
                    if stack == 'horizontal':
                        x_symmetric = tf.keras.layers.Concatenate(
                            axis=-1, dtype=dtype)(
                                [upward_pass['vertical'], x_symmetric])

                    # Add a skip-connection from the symmetric layer in the downward
                    # pass to the layer `x` in the upward pass.
                    h = activation(x_symmetric)
                    h = Dense(self._num_filters)(h)
                    x = tf.keras.layers.add([h, x], dtype=dtype)

                    x = activation(x)
                    x = tf.keras.layers.Dropout(self._dropout_p,
                                                dtype=dtype)(x)
                    x = Conv2D(filters=2 * self._num_filters,
                               kernel_size=kernel_sizes[stack],
                               kernel_constraint=kernel_constraints[stack])(x)

                    if conditional_input is not None:
                        h_projection = _build_and_apply_h_projection(
                            conditional_input, self._num_filters, dtype=dtype)
                        x = tf.keras.layers.add([x, h_projection], dtype=dtype)

                    x = _apply_sigmoid_gating(x)
                    upward_pass[stack] = tf.keras.layers.add([input_x, x],
                                                             dtype=dtype)

            # Define deconvolutional layers that expand height/width dimensions on the
            # upward pass (e.g. expanding from 8x8 to 16x16 in Figure 2 of [1]), with
            # the correct kernel dimensions/masking to maintain the autoregressive
            # property.
            if i < self._num_hierarchies - 1:
                for stack in ['vertical', 'horizontal']:
                    h, w = kernel_valid_dims[stack]
                    kernel_height = 2 * h - 2
                    if stack == 'vertical':
                        kernel_width = w + 1
                        kernel_constraint = _make_kernel_constraint(
                            (kernel_height, kernel_width),
                            (h - 2, kernel_height), (0, w))
                    else:
                        kernel_width = 2 * w - 2
                        kernel_constraint = _make_kernel_constraint(
                            (kernel_height, kernel_width),
                            (h - 2, kernel_height), (w - 2, kernel_width))

                    x = upward_pass[stack]
                    x = Conv2DTranspose(kernel_size=(kernel_height,
                                                     kernel_width),
                                        kernel_constraint=kernel_constraint)(x)
                    upward_pass[stack] = x

        x_out = tf.keras.layers.ELU(dtype=dtype)(upward_pass['horizontal'])

        # Build final Dense/Reshape layers to output the correct number of
        # parameters per pixel.
        num_channels = tensorshape_util.as_list(image_shape)[-1]
        num_coeffs = num_channels * (
            num_channels - 1) // 2  # alpha, beta, gamma in eq.3 of paper
        num_out = num_channels * 2 + num_coeffs + 1  # mu, s + alpha, beta, gamma + 1 (mixture weight)
        num_out_total = num_out * self._num_logistic_mix
        params = Dense(num_out_total)(x_out)
        params = tf.reshape(
            params,
            prefer_static.concat(  # [-1,H,W,nb mixtures, params per mixture]
                [[-1], image_shape[:-1], [self._num_logistic_mix, num_out]],
                axis=0))

        # If there is one color channel, split the parameters into a list of three
        # output `Tensor`s: (1) component logits for the Quantized Logistic mixture
        # distribution, (2) location parameters for each component, and (3) scale
        # parameters for each component. If there is more than one color channel,
        # return a fourth `Tensor` for the coefficients for the linear dependence
        # among color channels (e.g. alpha, beta, gamma).
        # [logits, mu, s, linear dependence]
        splits = 3 if num_channels == 1 else [
            1, num_channels, num_channels, num_coeffs
        ]
        outputs = tf.split(params, splits, axis=-1)

        # Squeeze singleton dimension from component logits
        outputs[0] = tf.squeeze(outputs[0], axis=-1)

        # Ensure scales are positive and do not collapse to near-zero
        outputs[2] = tf.nn.softplus(outputs[2]) + tf.cast(
            tf.exp(-7.), self.dtype)

        inputs = image_input if conditional_input is None else [
            image_input, conditional_input
        ]
        self._network = tf.keras.Model(inputs=inputs, outputs=outputs)
        super(_PixelCNNNetwork, self).build(input_shape)
def _unvec(x, matrix_shape):
    # Unvec takes in a (batch) of matrices of shape B1 + [n * k] and returns
    # a (batch) of vectors of shape B1 + [n, k], where n and k are specified
    # by matrix_shape.
    return tf.reshape(x, ps.concat([ps.shape(x)[:-1], matrix_shape], axis=0))
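
Together with `_vec` defined earlier, `_unvec` forms an exact round trip once the
trailing matrix shape is supplied. A quick sanity sketch (assumes TensorFlow is
imported and both helpers are in scope):

import tensorflow as tf

x = tf.reshape(tf.range(2 * 3 * 4, dtype=tf.float32), [2, 3, 4])  # 2 matrices of shape [3, 4]
flat = _vec(x)                   # shape [2, 12]
restored = _unvec(flat, [3, 4])  # shape [2, 3, 4]
# tf.reduce_all(tf.equal(x, restored))  -> True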
def convolution_batch(x,
                      kernel,
                      rank,
                      strides,
                      padding,
                      data_format=None,
                      dilations=None,
                      name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
    if rank != 2:
        raise NotImplementedError(
            'Argument `rank` currently only supports `2`; '
            'saw "{}".'.format(rank))
    if data_format is not None and data_format.upper() != 'NHWBC':
        raise ValueError(
            'Argument `data_format` currently only supports "NHWBC"; '
            'saw "{}".'.format(data_format))
    with tf.name_scope(name or 'conv2d_nhwbc'):
        # Prepare arguments.
        [
            rank,
            _,  # strides
            padding,
            dilations,
            data_format,
        ] = prepare_conv_args(rank, strides, padding, dilations)
        strides = prepare_strides(strides, rank + 2, arg_name='strides')

        dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

        # Step 1: Transpose and double flatten kernel.
        # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
        kernel_shape = prefer_static.shape(kernel)
        kernel_batch_shape, kernel_event_shape = prefer_static.split(
            kernel_shape, num_or_size_splits=[-1, rank + 2])
        kernel_batch_size = prefer_static.reduce_prod(kernel_batch_shape)
        kernel_ndims = prefer_static.rank(kernel)
        kernel_batch_ndims = kernel_ndims - rank - 2
        perm = prefer_static.concat([
            prefer_static.range(kernel_batch_ndims, kernel_batch_ndims + rank),
            prefer_static.range(0, kernel_batch_ndims),
            prefer_static.range(kernel_batch_ndims + rank, kernel_ndims),
        ],
                                    axis=0)  # Eg, [1, 2, 0, 3, 4]
        kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
        kernel = tf.reshape(kernel,
                            shape=prefer_static.concat([
                                kernel_event_shape[:rank],
                                [
                                    kernel_batch_size * kernel_event_shape[-2],
                                    kernel_event_shape[-1]
                                ],
                            ],
                                                       axis=0))  # F + [bc, c']

        # Step 2: Double flatten x.
        # x.shape = N + D + B + [c]
        x_shape = prefer_static.shape(x)
        [
            x_sample_shape,
            x_rank_shape,
            x_batch_shape,
            x_channel_shape,
        ] = prefer_static.split(
            x_shape, num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
        x = tf.reshape(
            x,  # N + D + B + [c]
            shape=prefer_static.concat([
                [prefer_static.reduce_prod(x_sample_shape)],
                x_rank_shape,
                [
                    prefer_static.reduce_prod(x_batch_shape) *
                    prefer_static.reduce_prod(x_channel_shape)
                ],
            ],
                                       axis=0))  # [n] + D + [bc]

        # Step 3: Apply convolution.
        y = tf.nn.depthwise_conv2d(x,
                                   kernel,
                                   strides=strides,
                                   padding=padding,
                                   data_format='NHWC',
                                   dilations=dilations)
        #  SAME: y.shape = [n, h,      w,      bcc']
        # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

        # Step 4: Reshape/reduce for output.
        y_shape = prefer_static.shape(y)
        y = tf.reshape(y,
                       shape=prefer_static.concat(
                           [
                               x_sample_shape,
                               y_shape[1:-1],
                               kernel_batch_shape,
                               kernel_event_shape[-2:],
                           ],
                           axis=0))  # N + D' + B + [c, c']
        y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

        return y
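
A hedged shape walk-through for `convolution_batch` (NHWBC layout), assuming the
helpers it relies on (`prepare_conv_args`, `prepare_strides`, `prefer_static`,
`dtype_util`) are importable as in this module; the concrete sizes are made up:

import tensorflow as tf

n, h, w = 4, 8, 8         # sample and spatial sizes
b, c, c_out = 3, 2, 5     # kernel batch, input channels, output channels
fh, fw = 3, 3             # filter height/width

x = tf.random.normal([n, h, w, b, c])             # N + D + B + [c]
kernel = tf.random.normal([b, fh, fw, c, c_out])  # B + F + [c, c']

y = convolution_batch(
    x, kernel, rank=2, strides=1, padding='SAME', dilations=1)
# y.shape -> [4, 8, 8, 3, 5], i.e. N + D' + B + [c']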
Beispiel #26
0
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_criterion_fn=ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        num_steps_state_history_to_pass=None,
        num_steps_observation_history_to_pass=None,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.mcmc.experimental.reconstruct_trajectories`.

  ${particle_filter_arg_str}
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights, axis=-1)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        num_observation_steps = prefer_static.shape(
            tf.nest.flatten(observations)[0])[0]
        num_timesteps = (1 + num_transitions_per_observation *
                         (num_observation_steps - 1))

        # If no criterion is specified, default is to resample at every step.
        if not resample_criterion_fn:
            resample_criterion_fn = lambda _: True

        # Dress up the prior and prior proposal as a fake `transition_fn` and
        # `proposal_fn` respectively.
        prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_prior, num_particles)
        prior_proposal_fn = (
            None if initial_state_proposal is None else
            lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
                initial_state_proposal, num_particles))

        # Initially the particles all have the same weight, `1. / num_particles`.
        broadcast_batch_shape = tf.convert_to_tensor(functools.reduce(
            prefer_static.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()), []),
                                                     dtype=tf.int32)
        log_uniform_weights = prefer_static.zeros(
            prefer_static.concat([[num_particles], broadcast_batch_shape],
                                 axis=0),
            dtype=tf.float32) - prefer_static.log(num_particles)

        # Initialize from the prior, and incorporate the first observation.
        initial_step_results = _filter_one_step(
            step=0,
            # `previous_particles` at the first step is a dummy quantity, used only
            # to convey state structure and num_particles to an optional
            # proposal fn.
            previous_particles=prior_fn(0, []).sample(),
            log_weights=log_uniform_weights,
            observation=tf.nest.map_structure(lambda x: tf.gather(x, 0),
                                              observations),
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            resample_criterion_fn=resample_criterion_fn,
            seed=seed)

        def _loop_body(step, previous_step_results, accumulated_step_results,
                       state_history):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                prefer_static.equal(num_transitions_per_observation, 1) |
                prefer_static.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            history_to_pass_into_fns = {}
            if num_steps_observation_history_to_pass:
                history_to_pass_into_fns[
                    'observation_history'] = _gather_history(
                        observations, observation_idx,
                        num_steps_observation_history_to_pass)
            if num_steps_state_history_to_pass:
                history_to_pass_into_fns['state_history'] = state_history

            new_step_results = _filter_one_step(
                step=step,
                previous_particles=previous_step_results.particles,
                log_weights=previous_step_results.log_weights,
                observation=current_observation,
                transition_fn=functools.partial(transition_fn,
                                                **history_to_pass_into_fns),
                observation_fn=functools.partial(observation_fn,
                                                 **history_to_pass_into_fns),
                proposal_fn=(None
                             if proposal_fn is None else functools.partial(
                                 proposal_fn, **history_to_pass_into_fns)),
                resample_criterion_fn=resample_criterion_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(step, new_step_results,
                                          accumulated_step_results,
                                          state_history)

        loop_results = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=_initialize_loop_variables(
                initial_step_results, num_steps_state_history_to_pass,
                num_timesteps))

        results = tf.nest.map_structure(lambda ta: ta.stack(),
                                        loop_results.accumulated_step_results)
        if num_transitions_per_observation != 1:
            # Return a log-prob for each observed step.
            observed_steps = prefer_static.range(
                0, num_timesteps, num_transitions_per_observation)
            results = results._replace(step_log_marginal_likelihood=tf.gather(
                results.step_log_marginal_likelihood, observed_steps))
        return results
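
A hedged usage sketch for the particle filter above: a 1-D Gaussian random walk
observed with noise. The model, sizes, and seed are illustrative only, `tfd`
stands for `tfp.distributions`, and the module's private helpers
(`SeedStream`, `SampleParticles`, `_filter_one_step`, ...) are assumed available:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

observations = tf.random.normal([30])
results = particle_filter(
    observations=observations,
    initial_state_prior=tfd.Normal(loc=0., scale=1.),
    transition_fn=lambda _, x: tfd.Normal(loc=x, scale=0.1),
    observation_fn=lambda _, x: tfd.Normal(loc=x, scale=0.5),
    num_particles=1024,
    seed=42)
# Per the Returns section, the filtered particles and log_weights both have
# shape [num_timesteps, num_particles] = [30, 1024] here.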
Beispiel #27
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'one_step')):
            variance_parts = previous_kernel_results.running_variance
            diags = [
                variance_part.variance() for variance_part in variance_parts
            ]
            # Set the momentum.
            batch_ndims = ps.rank(
                unnest.get_innermost(previous_kernel_results,
                                     'target_log_prob'))
            state_parts = tf.nest.flatten(current_state)
            new_momentum_distribution = _make_momentum_distribution(
                diags, state_parts, batch_ndims)
            inner_results = self.momentum_distribution_setter_fn(
                previous_kernel_results.inner_results,
                new_momentum_distribution)

            # Step the inner kernel.
            inner_kwargs = {} if seed is None else dict(seed=seed)
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results, **inner_kwargs)
            new_state_parts = tf.nest.flatten(new_state)
            new_variance_parts = []
            for variance_part, diag, state_part in zip(variance_parts, diags,
                                                       new_state_parts):
                # Compute new variance for each variance part, accounting for partial
                # batching of the variance calculation across chains (ie, some, all, or
                # none of the chains may share the estimated mass matrix).
                #
                # For example, say
                #
                # state_part has shape       [2, 3, 4] + [5, 6]  (batch + event)
                # variance_part has shape          [4] + [5, 6]
                # log_prob has shape         [2, 3, 4]
                #
                # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
                # matrices, each being shared across a [2, 3]-batch of chains. Note this
                # division is inferred from the shapes of the state part, the log_prob,
                # and the user-provided initial running variances.
                #
                # Until RunningVariance supports rank > 1 chunking, we need to flatten
                # the states that go into updating the variance estimates. In the above
                # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
                # fed to `RunningVariance.update(state_part, axis=0)`, recording
                # 6 new observations in the running variance calculation.
                # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
                # the resulting momentum distribution will have batch shape of
                # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
                state_rank = ps.rank(state_part)
                variance_rank = ps.rank(diag)
                num_reduce_dims = state_rank - variance_rank

                state_part_shape = ps.shape(state_part)
                # This reshape adds a 1 when reduce_dims==0, and collapses all the lead
                # dimensions to a single one otherwise.
                reshaped_state = ps.reshape(
                    state_part,
                    ps.concat(
                        [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
                         state_part_shape[num_reduce_dims:]],
                        axis=0))

                # The `axis=0` here removes the leading dimension we got from the
                # reshape above, so the new_variance_parts have the correct shape again.
                new_variance_parts.append(
                    variance_part.update(reshaped_state, axis=0))

            new_kernel_results = previous_kernel_results._replace(
                inner_results=new_inner_results,
                running_variance=new_variance_parts)

            return new_state, new_kernel_results
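
A minimal sketch of the chunking reshape described in the comment above, using
the shapes from that example (a [2, 3, 4] batch of chains, event shape [5, 6],
and 4 shared variance estimates); the leading [2, 3] chain axes collapse into a
single axis of 6 observations:

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

state_part = tf.zeros([2, 3, 4, 5, 6])
diag = tf.zeros([4, 5, 6])  # stands in for `variance_part.variance()`

num_reduce_dims = ps.rank(state_part) - ps.rank(diag)  # 2
state_part_shape = ps.shape(state_part)
reshaped_state = ps.reshape(
    state_part,
    ps.concat([[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
               state_part_shape[num_reduce_dims:]], axis=0))
# reshaped_state.shape -> [6, 4, 5, 6]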
Beispiel #28
0
    def _batch_shape_tensor(self, **kwargs):
        return tf.nest.map_structure(
            lambda b: prefer_static.concat([[self.num_particles], b], axis=0),
            self.distribution.batch_shape_tensor(**kwargs))
Beispiel #29
0
    def _log_prob(self, x):
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        df = tf.convert_to_tensor(self.df)
        batch_shape = self._batch_shape_tensor(df)
        event_shape = self._event_shape_tensor()
        dimension = self._dimension()
        x_ndims = ps.rank(x_sqrt)
        num_singleton_axes_to_prepend = (
            ps.maximum(ps.size(batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = ps.concat([
            ps.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            ps.shape(x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = ps.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - ps.size(batch_shape) - 2
        sample_shape = ps.shape(x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimension * number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        perm = ps.concat(
            [ps.range(sample_ndims, ndims),
             ps.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        last_dim_size = (
            ps.cast(dimension, dtype=tf.int32) *
            ps.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
        shape = ps.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [ps.cast(dimension, dtype=tf.int32), last_dim_size]
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self._scale.solve(scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = ps.concat(
            [ps.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
            axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        perm = ps.concat([
            ps.range(ndims - sample_ndims, ndims),
            ps.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(tf.square(scale_sqrt_inv_x_sqrt),
                                          axis=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(tf.math.log(
            tf.linalg.diag_part(x_sqrt)),
                                       axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((df - dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x -
                    self._log_normalization(df=df, scale=self._scale))

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if tensorshape_util.rank(
                x.shape) is not None and tensorshape_util.rank(
                    self.batch_shape) is not None:
            tensorshape_util.set_shape(
                log_prob,
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob
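
A numerical spot-check (illustrative only) of the trace identity used above:
with V = S S' and X = L L', tr[inv(V) X] equals the sum of squares of the
entries of inv(S) L:

import numpy as np

rng = np.random.default_rng(0)
k = 3
S = np.tril(rng.normal(size=(k, k))) + k * np.eye(k)  # lower-triangular, well-conditioned
L = np.tril(rng.normal(size=(k, k))) + k * np.eye(k)
V, X = S @ S.T, L @ L.T

lhs = np.trace(np.linalg.inv(V) @ X)
rhs = np.sum(np.linalg.solve(S, L) ** 2)
assert np.allclose(lhs, rhs)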
Beispiel #30
0
def resample_independent(log_probs,
                         event_size,
                         sample_shape,
                         seed=None,
                         name=None):
    """Categorical resampler for sequential Monte Carlo.

  This function is based on Algorithm #1 in the paper
  [Maskell et al. (2006)][1].

  Args:
    log_probs: A tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws.
    seed: Python `int` used to seed calls to `tf.random.*`.
      Default value: None (i.e. no seed).
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_independent'`).

  Returns:
    resampled_indices: The result is similar to sampling with
    ```python
    expanded_sample_shape = tf.concat([[event_size], sample_shape], axis=-1)
    tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
    ```
    but with values sorted along the first axis. It can be considered to be
    sampling events made up of a length-`event_size` vector of draws from
    the `Categorical` distribution. For large input values this function should
    give better performance than using `Categorical`.
    The sortedness is an unintended side effect of the algorithm that is
    harmless in the context of simple SMC algorithms.

  #### References

  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  """
    with tf.name_scope(name or 'resample_independent') as name:
        log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
        log_probs = dist_util.move_dimension(log_probs,
                                             source_idx=0,
                                             dest_idx=-1)

        batch_shape = prefer_static.shape(log_probs)[:-1]
        num_markers = prefer_static.shape(log_probs)[-1]

        # `working_shape` specifies the total number of events
        # we will be generating.
        working_shape = prefer_static.concat([sample_shape, batch_shape],
                                             axis=0)
        # `points_shape` is the shape of the final result.
        points_shape = prefer_static.concat([working_shape, [event_size]],
                                            axis=0)
        # `markers_shape` is the shape of the markers we temporarily insert.
        markers_shape = prefer_static.concat([working_shape, [num_markers]],
                                             axis=0)
        # Generate one real point for each particle.
        log_points = -exponential.Exponential(
            rate=tf.constant(1.0, dtype=log_probs.dtype)).sample(points_shape,
                                                                 seed=seed)

        # We divide up the unit interval [0, 1] according to the provided
        # probability distributions using `cumsum`.
        # At the end of each division we place a 'marker'.
        # We generate random points on the unit interval.
        # We sort the combination of points and markers. The number
        # of points between the markers defining a division gives the number
        # of samples we require in that division.
        # For example, suppose `probs` is `[0.2, 0.3, 0.5]`.
        # We divide up `[0, 1]` using 3 markers:
        #
        #     |     |          |
        # 0.  0.2   0.5        1.0  <- markers
        #
        # Suppose we generate four points: [0.1, 0.25, 0.9, 0.75]
        # After sorting the combination we get:
        #
        # 0.1  0.25     0.75 0.9    <- points
        #  *  | *   |    *    *|
        # 0.   0.2 0.5         1.0  <- markers
        #
        # We have one sample in the first category, one in the second and
        # two in the last.
        #
        # All of these computations are carried out in batched form.
        markers = prefer_static.concat([
            tf.zeros(points_shape, dtype=tf.int32),
            tf.ones(markers_shape, dtype=tf.int32)
        ],
                                       axis=-1)
        log_marker_positions = tf.broadcast_to(
            tf.math.cumulative_logsumexp(log_probs, axis=-1), markers_shape)
        log_points_and_markers = prefer_static.concat(
            [log_points, log_marker_positions], axis=-1)
        indices = tf.argsort(log_points_and_markers, axis=-1, stable=False)
        sorted_markers = tf.gather_nd(
            markers,
            indices[..., tf.newaxis],
            batch_dims=(prefer_static.rank_from_shape(sample_shape) +
                        prefer_static.rank_from_shape(batch_shape)))
        markers_and_samples = prefer_static.cast(tf.cumsum(sorted_markers,
                                                           axis=-1),
                                                 dtype=tf.int32)
        markers_and_samples = tf.minimum(markers_and_samples, num_markers - 1)
        # Collect up samples, omitting markers.
        resampled = tf.reshape(
            markers_and_samples[tf.equal(sorted_markers, 0)], points_shape)
        resampled = dist_util.move_dimension(resampled,
                                             source_idx=-1,
                                             dest_idx=0)
        return resampled
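
A hedged usage sketch: draw 8 resampled particle indices per sample from a
single set of 3 particle log-weights (values illustrative):

import tensorflow as tf

log_probs = tf.math.log([0.2, 0.3, 0.5])
indices = resample_independent(log_probs, event_size=8, sample_shape=[2])
# `indices` has shape [event_size] + sample_shape = [8, 2], with values in
# {0, 1, 2}; as noted in the docstring, draws come back sorted along the
# first axis.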