Example No. 1
 def _log_prob(self, x):
   batch_ndims = prefer_static.rank_from_shape(
       self.distribution.batch_shape_tensor,
       self.distribution.batch_shape)
   extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
   event_ndims = prefer_static.rank_from_shape(
       self.distribution.event_shape_tensor,
       self.distribution.event_shape)
   ndims = prefer_static.rank(x)
   # (1) Expand x's dims.
   d = ndims - batch_ndims - extra_sample_ndims - event_ndims
   x = tf.reshape(x, shape=tf.pad(
       tensor=tf.shape(input=x),
       paddings=[[prefer_static.maximum(0, -d), 0]],
       constant_values=1))
   sample_ndims = prefer_static.maximum(0, d)
   # (2) Transpose x's dims.
   sample_dims = prefer_static.range(0, sample_ndims)
   batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
   extra_sample_dims = prefer_static.range(
       sample_ndims + batch_ndims,
       sample_ndims + batch_ndims + extra_sample_ndims)
   event_dims = prefer_static.range(
       sample_ndims + batch_ndims + extra_sample_ndims,
       ndims)
   perm = prefer_static.concat(
       [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
   x = tf.transpose(a=x, perm=perm)
   # (3) Compute x's log_prob.
   lp = self.distribution.log_prob(x)
   # (4) Make the final reduction in x.
   axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
   return tf.reduce_sum(input_tensor=lp, axis=axis)
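Step (1) above is a recurring prefer_static idiom: when `x` is missing leading sample dimensions (`d < 0`), its shape vector is left-padded with ones so the later transpose and reduction see a fixed layout. A minimal standalone sketch of just that step, with hypothetical shapes and assuming `ps` aliases `tensorflow_probability.python.internal.prefer_static`:

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

# Hypothetical input: rank-2 `x`, while batch/extra-sample/event dims add up to 4.
x = tf.zeros([2, 3])
d = ps.rank(x) - 4  # == -2: two leading singleton dims are missing.
x = tf.reshape(x, shape=ps.pad(ps.shape(x),
                               paddings=[[ps.maximum(0, -d), 0]],
                               constant_values=1))
print(x.shape)  # (1, 1, 2, 3)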
Example No. 2
 def _prepare_for_underlying(self, x):
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                      self.distribution.event_shape)
     ndims = ps.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
     ndims = ps.rank(x)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(x, perm=perm)
     return x, (sample_ndims, extra_sample_ndims, batch_ndims)
Example No. 3
        def loop_body(i_, event_ind):
            i = i_ // strides
            j = i_ % strides

            i_ind = ps.range(i * fw,
                             ps.maximum(i, fh) * fw,
                             delta=strides * fw,
                             dtype=dtype)
            j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)

            nc = cartesian_add([i_ind, j_ind])
            ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

            k = ps.reshape(cartesian_add([
                ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
                ps.range(ps.shape(nc)[1], dtype=dtype)
            ]),
                           shape=[-1])
            last_j = strides - (fw - j - 1) % strides - 1
            last_i = strides - (fh - i - 1) % strides - 1
            kernel_ind = ps.stack(
                [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
            event_ind = ps.tensor_scatter_nd_update(event_ind, ind[...,
                                                                   tf.newaxis],
                                                    kernel_ind)

            return i_ + 1, event_ind
Example No. 4
  def _split_and_reshape_event(self, x):
    event_tensors = self._distribution.event_shape_tensor()
    splits = [
        ps.maximum(1, ps.reduce_prod(s))
        for s in tf.nest.flatten(event_tensors)
    ]
    x = tf.nest.pack_sequence_as(event_tensors, tf.split(x, splits, axis=-1))

    def _reshape_part(part, dtype, event_shape):
      part = tf.cast(part, dtype)
      static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
      if static_rank == 1:
        return part
      new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
      return tf.reshape(part, ps.cast(new_shape, tf.int32))

    if all(
        tensorshape_util.is_fully_defined(s)
        for s in tf.nest.flatten(self._distribution.event_shape)):
      x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                                self._distribution.event_shape)
    else:
      x = tf.nest.map_structure(_reshape_part, x, self._distribution.dtype,
                                self._distribution.event_shape_tensor())
    return x
Example No. 5
            def log_joint_fn(*param_vals, **param_kwargs):
                """Generated log-density function."""

                if param_kwargs:
                    if param_vals:
                        raise ValueError(
                            'log_joint_fn saw both positional args ({}) and named args ({}). '
                            'This is not supported: you have to choose!'.
                            format(param_vals, param_kwargs))
                    param_vals = [
                        param_kwargs[p.name] for p in self.parameters
                    ]

                param_lp = parameter_prior.log_prob(*param_vals)

                # Build a linear Gaussian state space model and evaluate the marginal
                # log_prob on observations.
                lgssm = self.make_state_space_model(
                    param_vals=param_vals, num_timesteps=num_timesteps)
                observation_lp = lgssm.log_prob(observed_time_series,
                                                mask=mask)

                # Sum over likelihoods from iid observations. Without this sum,
                # adding `param_lp + observation_lp` would broadcast the param priors
                # over the sample shape, which incorrectly multi-counts the param
                # priors.
                sample_ndims = ps.maximum(
                    0,
                    ps.rank(observation_lp) - ps.rank(param_lp))
                observation_lp = tf.reduce_sum(observation_lp,
                                               axis=ps.range(sample_ndims))

                return param_lp + observation_lp
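The sample-dimension reduction at the end keeps `param_lp` from being broadcast (and hence double-counted) over iid observation dims. A small sketch of that bookkeeping with hypothetical shapes, assuming `ps` is TFP's `prefer_static` module:

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

observation_lp = tf.zeros([7, 4])  # hypothetical: 7 iid series, batch of 4
param_lp = tf.zeros([4])           # one prior log-prob per batch member
sample_ndims = ps.maximum(0, ps.rank(observation_lp) - ps.rank(param_lp))  # 1
observation_lp = tf.reduce_sum(observation_lp, axis=ps.range(sample_ndims))
print(observation_lp.shape)  # (4,) -- now adds to param_lp without broadcasting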
Example No. 6
    def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
        if event_shape is None:
            event_shape = self.event_shape_tensor()

        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

        # Continuing the example from `_augment_sample_shape`, suppose we have:
        #   - sample shape of `[n]`,
        #   - underlying distribution batch shape of `[2, 1]`,
        #   - final broadcast batch shape of `[4, 2, 3]`.
        # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
        # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

        # First, we reshape to expand out the batch elements:
        # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
        # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
        # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
        underlying_bcast_shp = ps.concat([
            ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                    dtype=underlying_batch_shape.dtype), underlying_batch_shape
        ],
                                         axis=0)
        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
        x_with_doubled_batch = tf.reshape(
            x,
            ps.concat([
                sample_shape,
                ps.where(is_dim_bcast, batch_shape, 1), underlying_bcast_shp,
                event_shape
            ],
                      axis=0))

        # Next, construct the permutation that interleaves the batch dimensions,
        # resulting in samples with shape
        # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
        # Note that each interleaved pair of batch dimensions contains exactly one
        # dim of size `1` and one of size `>= 1`.
        sample_ndims = ps.rank_from_shape(sample_shape)
        x_with_interleaved_batch = tf.transpose(
            x_with_doubled_batch,
            perm=ps.concat([
                ps.range(sample_ndims), sample_ndims + ps.reshape(
                    ps.stack([
                        ps.range(batch_rank),
                        ps.range(batch_rank) + batch_rank
                    ],
                             axis=-1), [-1]), sample_ndims + 2 * batch_rank +
                ps.range(ps.rank_from_shape(event_shape))
            ],
                           axis=0))

        # Final reshape to remove the spurious `1` dimensions.
        return tf.reshape(
            x_with_interleaved_batch,
            ps.concat([sample_shape, batch_shape, event_shape], axis=0))
Example No. 7
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    expand_ndims = ps.maximum(rank - ps.rank(x), 0)
    expand_shape = ps.concat(
        [ps.shape(x),
         ps.ones(shape=[expand_ndims], dtype=tf.int32)],
        axis=0)
    return tf.reshape(x, expand_shape)
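A hypothetical usage of the helper above: a shape-`[3]` tensor gets trailing singleton dims appended up to the requested rank, so it lines up against batch dimensions on the left (assuming `ps` aliases TFP's `prefer_static`):

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

x = tf.constant([1., 2., 3.])                 # shape [3]
y = left_justified_expand_dims_to(x, rank=3)
print(y.shape)                                # (3, 1, 1)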
Example No. 8
def left_justified_expand_dims_to(x, rank, name=None):
    """Right pads `x` with `rank - rank(x)` ones."""
    with tf.name_scope(name or 'left_justified_expand_dims_to'):
        rank = tf.convert_to_tensor(rank, dtype=tf.int32)
        expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
        expand_shape = prefer_static.pad(prefer_static.shape(x),
                                         paddings=[[0, expand_ndims]],
                                         constant_values=1)
        return prefer_static.reshape(x, expand_shape)
Example No. 9
        def loop_body(i_, kernels_ind):
            i = i_ // sw
            j = i_ % sw

            i_ind = ps.range(i * fw,
                             ps.maximum(i, fh) * fw,
                             delta=sh * fw,
                             dtype=dtype)
            j_ind = ps.range(j, ps.maximum(j, fw), delta=sw, dtype=dtype)

            last_j = sw - (fw - j - 1) % sw - 1
            last_i = sh - (fh - i - 1) % sh - 1
            pos = last_i * sw + last_j

            nc = cartesian_add([i_ind, j_ind])
            kernels_ind = kernels_ind.write(
                pos, ps.reverse(ps.reverse(nc, [0]), [1]))
            return i_ + 1, kernels_ind
Example No. 10
 def _prepare_for_underlying(self, x):
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                      self.distribution.event_shape)
     ndims = ps.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
     sample_ndims = ps.maximum(0, d)
     x = tf.transpose(x,
                      perm=ps.invert_permutation(
                          self._sampling_permutation(sample_ndims)))
     return x, (sample_ndims, extra_sample_ndims, batch_ndims)
Example No. 11
def _rightmost_expand_to_rank(tensor, new_rank):
  """Expands `tensor`'s rank by `new_rank - tensor.rank` rightmost dims."""
  return tf.reshape(
      tensor,
      shape=prefer_static.pad(
          prefer_static.shape(tensor),
          paddings=[[0,
                     prefer_static.maximum(
                         new_rank - prefer_static.rank(tensor), 0)]],
          constant_values=1))
Example No. 12
 def update_event_ndims(input_event_ndims, input_min_event_ndims,
                        output_min_event_ndims):
     """Returns output_event_ndims and updates rolling_offset as needed."""
     nonlocal rolling_offset
     ldj_reduce_ndims = bijector_lib.ldj_reduction_ndims(
         input_event_ndims, input_min_event_ndims)
     # Update rolling_offset when batch_ndims are negative.
     rolling_offset = ps.maximum(rolling_offset, -ldj_reduce_ndims)
     return nest.map_structure(lambda nd: ldj_reduce_ndims + nd,
                               output_min_event_ndims)
Example No. 13
    def _sample_n(self, n, seed=None):
        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)
        n_batch = ps.reduce_prod(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
        underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

        # Left pad underlying shape with any necessary ones.
        underlying_bcast_shp = ps.concat([
            ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                    dtype=underlying_batch_shape.dtype), underlying_batch_shape
        ],
                                         axis=0)

        # Determine how many underlying samples to produce.
        n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
        samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

        event_shape = self.event_shape_tensor()
        event_rank = ps.rank_from_shape(event_shape)
        shp = ps.concat([[n],
                         ps.where(is_dim_bcast, batch_shape, 1),
                         underlying_bcast_shp, event_shape],
                        axis=0)
        # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
        samps = tf.reshape(samps, shp)
        # Interleave broadcast and underlying axis indices for transpose.
        interleaved_batch_axes = ps.reshape(
            ps.stack([ps.range(batch_rank),
                      ps.range(batch_rank) + batch_rank],
                     axis=-1), [-1]) + 1

        event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
        perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
        samps = tf.transpose(samps, perm=perm)
        # Finally, reshape to the fully-broadcast batch shape.
        return tf.reshape(samps,
                          ps.concat([[n], batch_shape, event_shape], axis=0))
Example No. 14
 def _log_prob(self, x, **kwargs):
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                      self.distribution.event_shape)
     ndims = ps.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
     ndims = ps.rank(x)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
     #     full sample shape in the sample axes, before we reduce.
     bcast_lp_shape = ps.broadcast_shape(
         ps.shape(lp),
         ps.concat([
             ps.ones([sample_ndims], tf.int32),
             ps.reshape(self.sample_shape, shape=[-1]),
             ps.ones([batch_ndims], tf.int32)
         ],
                   axis=0))
     lp = tf.broadcast_to(lp, bcast_lp_shape)
     # (5) Make the final reduction in x.
     axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
Example No. 15
def left_justified_expand_dims_to(x, rank, name=None):
    """Right pads `x` with `rank - rank(x)` ones."""
    with tf.name_scope(name or 'left_justified_expand_dims_to'):
        rank = tf.convert_to_tensor(rank, dtype=tf.int32)
        expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
        expand_shape = prefer_static.concat([
            prefer_static.shape(x),
            prefer_static.ones(shape=[expand_ndims], dtype=tf.int32)
        ],
                                            axis=0)
        return prefer_static.reshape(x, expand_shape)
Example No. 16
def _get_transpose_conv_dilated_padding(filter_dim, stride, dilation, padding):
    """Zero-padding for inputs dilated by strides."""
    tot_filter_dim = filter_dim + (filter_dim - 1) * (dilation - 1)
    if padding == 'VALID':
        tot_pad = tot_filter_dim + stride - 2 + ps.maximum(
            tot_filter_dim - stride, 0)
    elif padding == 'SAME':
        tot_pad = tot_filter_dim + stride - 2
    return ps.cond(filter_dim >= stride, lambda:
                   (tot_pad - tot_pad // 2 - stride + 1, tot_pad // 2), lambda:
                   (filter_dim - stride, tot_pad - filter_dim + 1))
Example No. 17
 def augmented_fn(step, *args, **kwargs):
     with tf.name_scope('augment_with_observation_history'):
         observation_idx = step // num_transitions_per_observation
         observation_history_indices = ps.range(
             ps.maximum(0, observation_idx - history_size),
             observation_idx)
         return fn(step,
                   *args,
                   observation_history=tf.gather(
                       observations, observation_history_indices),
                   **kwargs)
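The `ps.maximum(0, ...)` clamp above just shortens the history window at early steps. A sketch of that index computation in isolation, with hypothetical values and `ps` assumed to be TFP's `prefer_static`:

from tensorflow_probability.python.internal import prefer_static as ps

observation_idx, history_size = 2, 5
observation_history_indices = ps.range(
    ps.maximum(0, observation_idx - history_size), observation_idx)
print(observation_history_indices)  # [0 1] -- only two past observations exist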
Example No. 18
  def _get_reinterpreted_batch_ndims(self,
                                     distribution_batch_shape_tensor=None):
    if self._static_reinterpreted_batch_ndims is not None:
      return self._static_reinterpreted_batch_ndims
    if self._reinterpreted_batch_ndims is not None:
      return tf.convert_to_tensor(self._reinterpreted_batch_ndims)

    if distribution_batch_shape_tensor is None:
      distribution_batch_shape_tensor = self.distribution.batch_shape_tensor()
    return ps.cast(
        ps.maximum(0, ps.size(distribution_batch_shape_tensor) - 1),
        np.int32)
Example No. 19
 def _augment_sample_shape(self, sample_shape):
     # Suppose we have:
     #   - sample shape of `[n]`,
     #   - underlying distribution batch shape of `[2, 1]`,
     #   - final broadcast batch shape of `[4, 2, 3]`.
     # Then we must draw `sample_shape + [12]` samples, where
     # `12 == n_batch // underlying_n_batch`.
     batch_shape = self.batch_shape_tensor()
     n_batch = ps.reduce_prod(batch_shape)
     underlying_batch_shape = self.distribution.batch_shape_tensor()
     underlying_n_batch = ps.reduce_prod(underlying_batch_shape)
     return ps.concat(
         [sample_shape, [ps.maximum(0, n_batch // underlying_n_batch)]],
         axis=0)
Example No. 20
def _split_and_reshape_event(x, model):
  """Splits and reshapes a flat event `x` to match the structure of `model`."""
  splits = [
      ps.maximum(1, ps.reduce_prod(s))
      for s in tf.nest.flatten(model.event_shape)
  ]
  x = tf.nest.pack_sequence_as(model.event_shape, tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, ps.cast(new_shape, tf.int32))

  x = tf.nest.map_structure(_reshape_part, x, model.dtype, model.event_shape)
  return x
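A hypothetical usage sketch for the free function above, assuming `model` is a joint distribution whose flattened event shapes are `[2]` and `[3, 4]`; a flat trailing dimension of 2 + 12 is split and reshaped back into the model's event structure:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static as ps

tfd = tfp.distributions
model = tfd.JointDistributionSequential([
    tfd.MultivariateNormalDiag(loc=tf.zeros([2]), scale_diag=tf.ones([2])),  # event [2]
    tfd.Independent(tfd.Normal(tf.zeros([3, 4]), 1.),
                    reinterpreted_batch_ndims=2),                            # event [3, 4]
])
flat_x = tf.zeros([5, 14])  # batch of 5 flat events; 14 == 2 + 3 * 4
parts = _split_and_reshape_event(flat_x, model)
print([p.shape for p in parts])  # [TensorShape([5, 2]), TensorShape([5, 3, 4])]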
Example No. 21
    def _split_and_reshape_event(self, x):
        splits = [
            ps.maximum(1, ps.reduce_prod(s))
            for s in tf.nest.flatten(self._model.event_shape)
        ]
        x = tf.nest.pack_sequence_as(self._model.event_shape,
                                     tf.split(x, splits, axis=-1))

        def _reshape_part(part, dtype, event_shape):
            part = tf.cast(part, dtype)
            new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
            return tf.reshape(part, ps.cast(new_shape, tf.int32))

        x = tf.nest.map_structure(_reshape_part, x, self._model.dtype,
                                  self._model.event_shape)
        return x
Example No. 22
def _compute_fans_from_shape(shape, batch_ndims=0):
  """Extracts `fan_in, fan_out` from specified shape `Tensor`."""
  # Ensure shape is a vector of length >=2.
  num_pad = prefer_static.maximum(0, 2 - prefer_static.size(shape))
  shape = prefer_static.pad(
      shape, paddings=[[0, num_pad]], constant_values=1)
  (
      batch_shape,  # pylint: disable=unused-variable
      extra_shape,
      fan_in,
      fan_out,
  ) = prefer_static.split(shape, [batch_ndims, -1, 1, 1])
  # The following logic is primarily intended for convolutional layers which
  # have spatial semantics in addition to input/output channels.
  receptive_field_size = prefer_static.reduce_prod(extra_shape)
  fan_in = fan_in[0] * receptive_field_size
  fan_out = fan_out[0] * receptive_field_size
  return fan_in, fan_out
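A small worked usage with a hypothetical convolution kernel shape `[3, 3, 8, 16]` (height, width, in-channels, out-channels) and no batch dims: the receptive field size is 3 * 3 = 9, so fan_in = 9 * 8 = 72 and fan_out = 9 * 16 = 144.

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static

kernel_shape = tf.constant([3, 3, 8, 16], dtype=tf.int32)
fan_in, fan_out = _compute_fans_from_shape(kernel_shape, batch_ndims=0)
print(fan_in, fan_out)  # 72 144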
Example No. 23
 def _log_prob(self, x, **kwargs):
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     extra_batch_ndims = prefer_static.rank_from_shape(self.batch_stack)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     ndims = prefer_static.rank(x)
     # (1) Expand x's dims.
     d = ndims - extra_batch_ndims - batch_ndims - event_ndims
     x = tf.reshape(
         x,
         shape=tf.pad(tf.shape(x),
                      paddings=[[prefer_static.maximum(0, -d), 0]],
                      constant_values=1),
     )
     # (2) Compute x's log_prob.
     return self.distribution.log_prob(x, **kwargs)
Example No. 24
    def _log_prob(self, x):
        assertions = []
        message = 'Input must have at least one dimension.'
        if tensorshape_util.rank(x.shape) is not None:
            if tensorshape_util.rank(x.shape) == 0:
                raise ValueError(message)
        elif self.validate_args:
            assertions.append(
                assert_util.assert_rank_at_least(x, 1, message=message))
        with tf.control_dependencies(assertions):
            event_tensors = self._distribution.event_shape_tensor()
            splits = [
                ps.maximum(1, ps.reduce_prod(s))
                for s in tf.nest.flatten(event_tensors)
            ]
            x = tf.nest.pack_sequence_as(event_tensors,
                                         tf.split(x, splits, axis=-1))

            def _reshape_part(part, dtype, event_shape):
                part = tf.cast(part, dtype)
                static_rank = tf.get_static_value(
                    ps.rank_from_shape(event_shape))
                if static_rank == 1:
                    return part
                new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                                      axis=-1)
                return tf.reshape(part, ps.cast(new_shape, tf.int32))

            if all(
                    tensorshape_util.is_fully_defined(s)
                    for s in tf.nest.flatten(self._distribution.event_shape)):
                x = tf.nest.map_structure(_reshape_part, x,
                                          self._distribution.dtype,
                                          self._distribution.event_shape)
            else:
                x = tf.nest.map_structure(
                    _reshape_part, x, self._distribution.dtype,
                    self._distribution.event_shape_tensor())

            return self._distribution.log_prob(x)
Example No. 25
 def _calculate_new_shape(self):
     # Try to get the old shape statically if available.
     original_shape = self._distribution.batch_shape
     if not tensorshape_util.is_fully_defined(original_shape):
         original_shape = self._distribution.batch_shape_tensor()
     # This is not a check for falseness, it's a check for exactly that shape.
     if original_shape == ():  # pylint: disable=g-explicit-bool-comparison
         # Force the size to be an integer, not a float, when the shape contains no
         # dtype information.
         original_size = 1
     else:
         original_size = ps.reduce_prod(original_shape)
     original_size = ps.cast(original_size, tf.int32)
     # Compute the new shape, filling in the `-1` dimension if present.
     new_shape = self._batch_shape_unexpanded
     implicit_dim_mask = ps.equal(new_shape, -1)
     size_implicit_dim = (original_size //
                          ps.maximum(1, -ps.reduce_prod(new_shape)))
     expanded_new_shape = ps.where(  # Assumes exactly one `-1`.
         implicit_dim_mask, size_implicit_dim, new_shape)
     # Return the original size on the side because one caller would otherwise
     # have to recompute it.
     return expanded_new_shape, original_size
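The `-1` handling mirrors `tf.reshape`: the implicit dimension is the original batch size divided by the (negated) product of the requested shape. A standalone sketch of that arithmetic with hypothetical values, `ps` again standing in for TFP's `prefer_static`:

from tensorflow_probability.python.internal import prefer_static as ps

original_size = 6
new_shape = [-1, 3]                            # one implicit dimension
implicit_dim_mask = ps.equal(new_shape, -1)    # [True, False]
size_implicit_dim = original_size // ps.maximum(1, -ps.reduce_prod(new_shape))  # 6 // 3 == 2
expanded_new_shape = ps.where(implicit_dim_mask, size_implicit_dim, new_shape)
print(expanded_new_shape)  # [2 3]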
Example No. 26
def _deconv_output_length(input_size, filter_size, padding, output_padding,
                          stride, dilation):
    """Determines output length of a transposed convolution given input length.

  Args:
    input_size: `int`.
    filter_size: `int`.
    padding: one of `"SAME"`, `"VALID"`, `"FULL"`.
    output_padding: `int`, amount of padding along the output dimension. Can
      be set to `None` in which case the output length is inferred.
    stride: `int`.
    dilation: `int`.

  Returns:
    output_length: The output length (`int`).
  """
    assert padding in {'SAME', 'VALID', 'FULL'}
    if input_size is None:
        return None
    # Get the dilated kernel size
    filter_size = filter_size + (filter_size - 1) * (dilation - 1)
    # Infer length if output padding is None, else compute the exact length
    if output_padding is None:
        if padding == 'VALID':
            return input_size * stride + ps.maximum(filter_size - stride, 0)
        elif padding == 'FULL':
            return input_size * stride - (stride + filter_size - 2)
        elif padding == 'SAME':
            return input_size * stride
    if padding == 'SAME':
        pad = filter_size // 2
    elif padding == 'VALID':
        pad = 0
    elif padding == 'FULL':
        pad = filter_size - 1
    return (input_size - 1) * stride + filter_size - 2 * pad + output_padding
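As a worked check with hypothetical values (`input_size=4`, `filter_size=3`, `stride=2`, `dilation=1`, inferred output padding): 'SAME' gives 4 * 2 = 8, 'VALID' gives 4 * 2 + max(3 - 2, 0) = 9, and 'FULL' gives 4 * 2 - (2 + 3 - 2) = 5.

from tensorflow_probability.python.internal import prefer_static as ps  # used inside the helper

for padding in ('SAME', 'VALID', 'FULL'):
    print(padding,
          _deconv_output_length(input_size=4, filter_size=3, padding=padding,
                                output_padding=None, stride=2, dilation=1))
# SAME 8
# VALID 9
# FULL 5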
Example No. 27
    def one_step(self, state, kernel_results, seed=None):
        """Takes one Sequential Monte Carlo inference step.

    Args:
      state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
        the current particles with (log) weights. The `log_weights` must be
        a float `Tensor` of shape `[num_particles, b1, ..., bN]`. The
        `particles` may be any structure of `Tensor`s, each of which
        must have shape `concat([log_weights.shape, event_shape])` for some
        `event_shape`, which may vary across components.
      kernel_results: instance of
        `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
        from a previous step.
      seed: Optional seed for reproducible sampling.

    Returns:
      state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
        new particles with (log) weights.
      kernel_results: instance of
        `tfp.experimental.mcmc.SequentialMonteCarloResults`.
    """
        with tf.name_scope(self.name):
            with tf.name_scope('one_step'):
                seed = samplers.sanitize_seed(seed)
                proposal_seed, resample_seed = samplers.split_seed(seed)

                state = WeightedParticles(*state)  # Canonicalize.
                num_particles = ps.size0(state.log_weights)

                # Propose new particles and update weights for this step, unless it's
                # the initial step, in which case, use the user-provided initial
                # particles and weights.
                proposed_state = self.propose_and_update_log_weights_fn(
                    # Propose state[t] from state[t - 1].
                    ps.maximum(0, kernel_results.steps - 1),
                    state,
                    seed=proposal_seed)
                is_initial_step = ps.equal(kernel_results.steps, 0)
                # TODO(davmre): this `where` assumes the state size didn't change.
                state = tf.nest.map_structure(
                    lambda a, b: tf.where(is_initial_step, a, b), state,
                    proposed_state)

                normalized_log_weights = tf.nn.log_softmax(state.log_weights,
                                                           axis=0)
                # Every entry of `log_weights` differs from `normalized_log_weights`
                # by the same normalizing constant. We extract that constant by
                # examining an arbitrary entry.
                incremental_log_marginal_likelihood = (
                    state.log_weights[0] - normalized_log_weights[0])

                do_resample = self.resample_criterion_fn(state)

                # Some batch elements may require resampling and others not, so
                # we first do the resampling for all elements, then select whether to
                # use the resampled values for each batch element according to
                # `do_resample`. If there were no batching, we might prefer to use
                # `tf.cond` to avoid the resampling computation on steps where it's not
                # needed---but we're ultimately interested in adaptive resampling
                # for statistical (not computational) purposes, so this isn't a
                # dealbreaker.
                resampled_particles, resample_indices = weighted_resampling.resample(
                    state.particles,
                    state.log_weights,
                    self.resample_fn,
                    seed=resample_seed)
                uniform_weights = tf.fill(
                    ps.shape(state.log_weights),
                    value=-tf.math.log(
                        tf.cast(num_particles, state.log_weights.dtype)))
                (resampled_particles, resample_indices,
                 log_weights) = tf.nest.map_structure(
                     lambda r, p: ps.where(do_resample, r, p),
                     (resampled_particles, resample_indices, uniform_weights),
                     (state.particles, _dummy_indices_like(resample_indices),
                      normalized_log_weights))

            return (
                WeightedParticles(particles=resampled_particles,
                                  log_weights=log_weights),
                SequentialMonteCarloResults(
                    steps=kernel_results.steps + 1,
                    parent_indices=resample_indices,
                    incremental_log_marginal_likelihood=(
                        incremental_log_marginal_likelihood),
                    accumulated_log_marginal_likelihood=(
                        kernel_results.accumulated_log_marginal_likelihood +
                        incremental_log_marginal_likelihood),
                    seed=seed))
Example No. 28
    def _log_prob(self, x):
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        df = tf.convert_to_tensor(self.df)
        batch_shape = self._batch_shape_tensor(df)
        event_shape = self._event_shape_tensor()
        dimension = self._dimension()
        x_ndims = ps.rank(x_sqrt)
        num_singleton_axes_to_prepend = (
            ps.maximum(ps.size(batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = ps.concat([
            ps.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            ps.shape(x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = ps.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - ps.size(batch_shape) - 2
        sample_shape = ps.shape(x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimensions*number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        perm = ps.concat(
            [ps.range(sample_ndims, ndims),
             ps.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        last_dim_size = (
            ps.cast(dimension, dtype=tf.int32) *
            ps.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
        shape = ps.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [ps.cast(dimension, dtype=tf.int32), last_dim_size]
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self._scale.solve(scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = ps.concat(
            [ps.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
            axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        perm = ps.concat([
            ps.range(ndims - sample_ndims, ndims),
            ps.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(tf.square(scale_sqrt_inv_x_sqrt),
                                          axis=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(tf.math.log(
            tf.linalg.diag_part(x_sqrt)),
                                       axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((df - dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x -
                    self._log_normalization(df=df, scale=self._scale))

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if tensorshape_util.rank(
                x.shape) is not None and tensorshape_util.rank(
                    self.batch_shape) is not None:
            tensorshape_util.set_shape(
                log_prob,
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob
Example No. 29
 def infected(new_infections, new_recoveries):
     return tfd.Deterministic(
         prefer_static.maximum(
             0., previous_state['infected'] + new_infections -
             new_recoveries))
Example No. 30
 def susceptible(new_infections):
     return tfd.Deterministic(
         prefer_static.maximum(
             0., previous_state['susceptible'] - new_infections))
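Examples 29 and 30 use `prefer_static.maximum(0., ...)` to keep SIR-style compartment counts from going negative. A minimal sketch of the same idiom with hypothetical state values, assuming `tfd` is `tfp.distributions` and `prefer_static` is TFP's prefer_static module:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static

tfd = tfp.distributions
previous_state = {'susceptible': tf.constant([100., 3.])}
new_infections = tf.constant([5., 7.])  # would push batch element 1 below zero

susceptible = tfd.Deterministic(
    prefer_static.maximum(0., previous_state['susceptible'] - new_infections))
print(susceptible.loc)  # [95., 0.] -- clamped at zero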