def consume(running_stat, elems, chunk_axis=None):
  """Folds each element of `elems` into `running_stat` via its `update` method."""
  def body(running_stat, elem):
    if chunk_axis is None:
      return running_stat.update(elem)
    else:
      return running_stat.update(elem, axis=chunk_axis)
  return tf.foldl(body, elems, running_stat)
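
Usage sketch (not from the source): `consume` works with any accumulator whose `update` returns a new accumulator of the same nested-tensor structure, since `tf.foldl` carries that structure through the loop. The `RunningMean` namedtuple below is purely hypothetical.

import collections
import tensorflow as tf

class RunningMean(collections.namedtuple('RunningMean', ['count', 'total'])):
  """Hypothetical running-mean accumulator compatible with tf.foldl."""

  def update(self, x, axis=None):
    # Fold in either a single sample or a whole chunk along `axis`.
    if axis is None:
      return RunningMean(self.count + 1., self.total + x)
    n = tf.cast(tf.shape(x)[axis], x.dtype)
    return RunningMean(self.count + n, self.total + tf.reduce_sum(x, axis=axis))

  @property
  def mean(self):
    return self.total / self.count

elems = tf.random.normal([100, 3])
stat = consume(RunningMean(tf.constant(0.), tf.zeros([3])), elems)
print(stat.mean)  # Approximately zero in each of the 3 components.
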
    def _sample_n(self, n, seed=None):
        distribution0 = self._get_distribution0()

        if self._num_steps is not None:
            num_steps = tf.convert_to_tensor(self._num_steps)
            num_steps_static = tf.get_static_value(num_steps)
        else:
            num_steps_static = tensorshape_util.num_elements(
                distribution0.event_shape)
            if num_steps_static is None:
                num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

        seed = SeedStream(seed, salt='Autoregressive')()
        samples = distribution0.sample(n, seed=seed)
        if num_steps_static is not None:
            for _ in range(num_steps_static):
                # pylint: disable=not-callable
                samples = self.distribution_fn(samples).sample(seed=seed)
        else:
            samples = tf.foldl(
                # pylint: disable=not-callable
                lambda s, _: self.distribution_fn(s).sample(seed=seed),
                elems=tf.range(0, num_steps),
                initializer=samples)
        return samples
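
For context (not part of the source), here is a minimal way this `_sample_n` is typically exercised through the public `tfd.Autoregressive` API; the `0.5 * sample` recursion, the shapes, and the seed are illustrative only.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def next_step(sample):
  # Each step conditions a diagonal Normal on the previous sample.
  return tfd.Independent(tfd.Normal(loc=0.5 * sample, scale=1.),
                         reinterpreted_batch_ndims=1)

ar = tfd.Autoregressive(next_step, sample0=tf.zeros([3]), num_steps=4)
x = ar.sample(2, seed=42)  # shape [2, 3]: four sequential refinements per draw.
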
Example #3
def no_pivot_ldl(matrix, name='no_pivot_ldl'):
  """Non-pivoted batched LDL factorization.

  Performs the LDL factorization, using the outer product algorithm from [1]. No
  pivoting (or block pivoting) is done, so this should be less stable than
  e.g. Bunch-Kaufman sytrf. This is implemented as a tf.foldl, so should have
  gradients and be accelerator-friendly, but is not particularly performant.

  If compiling with XLA, make sure any surrounding GradientTape is also
  XLA-compiled (b/193584244).

  #### References
  [1]: Gene H. Golub, Charles F. Van Loan. Matrix Computations, 4th ed., 2013.

  Args:
    matrix: A batch of symmetric square matrices, with shape `[..., n, n]`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: 'no_pivot_ldl'.

  Returns:
    triangular_factor: The unit lower triangular L factor of the LDL
      factorization of `matrix`, with the same shape `[..., n, n]`. Callers
      should check for `nans` and other indicators of instability.
    diag: The diagonal from the LDL factorization, with shape `[..., n]`.
  """
  with tf.name_scope(name) as name:
    matrix = tf.convert_to_tensor(matrix)
    triangular_factor = tf.linalg.band_part(matrix, num_lower=-1, num_upper=0)
    # TODO(b/182276317) Deal with dynamic ranks better.
    # `_Slice2Idx` (defined elsewhere) converts batched slice expressions like
    # `slix[..., i+1:, i]` into index tensors for tensor_scatter_nd_{update,sub}.
    slix = _Slice2Idx(triangular_factor)

    def fn(triangular_factor, i):
      # Pivot (diagonal entry) and subdiagonal tail of column i.
      column_head = triangular_factor[..., i, i, tf.newaxis]
      column_tail = triangular_factor[..., i+1:, i]
      # Divide the tail by the pivot to obtain L[..., i+1:, i].
      rescaled_tail = column_tail / column_head
      triangular_factor = tf.tensor_scatter_nd_update(
          triangular_factor,
          slix[..., i+1:, i],
          rescaled_tail)
      # Rank-1 (outer product) update of the trailing lower-triangular block,
      # i.e. the Schur-complement step of the outer-product LDL algorithm.
      triangular_factor = tf.tensor_scatter_nd_sub(
          triangular_factor,
          slix[..., i+1:, i+1:],
          tf.linalg.band_part(
              tf.einsum('...i,...j->...ij', column_tail, rescaled_tail),
              num_lower=-1, num_upper=0))
      return triangular_factor

    triangular_factor = tf.foldl(
        fn=fn,
        elems=tf.range(tf.shape(triangular_factor)[-1]),
        initializer=triangular_factor)

    diag = tf.linalg.diag_part(triangular_factor)
    triangular_factor = tf.linalg.set_diag(
        triangular_factor, tf.ones_like(diag))

    return triangular_factor, diag
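
Usage sketch (not part of the source): the returned factors should reproduce the input as `L @ diag(d) @ L^T` when the factorization is stable.

matrix = tf.constant([[4., 2.],
                      [2., 3.]])
triangular_factor, diag = no_pivot_ldl(matrix)
# triangular_factor ~= [[1., 0.], [0.5, 1.]], diag ~= [4., 2.]
reconstruction = tf.einsum('...ij,...j,...kj->...ik',
                           triangular_factor, diag, triangular_factor)
# reconstruction should be close to `matrix`.
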
Example #4
  def _do_flips():
    state = sampler._initialize_sampler_state(
        targets=targets,
        nonzeros=initial_nonzeros,
        observation_noise_variance=1.)
    def _do_flip(state, i):
      new_state = sampler._flip_feature(state, tf.gather(flip_idxs, i))
      return mcmc_util.choose(tf.gather(should_flip, i), new_state, state)
    return tf.foldl(_do_flip, elems=tf.range(num_flips), initializer=state)
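
The flip loop above combines `tf.foldl` with a conditional state update. A self-contained sketch of the same pattern, using hypothetical tensors and `tf.where` standing in for `mcmc_util.choose` on a flat scalar state:

import tensorflow as tf

values = tf.constant([1., 2., 3., 4.])
should_add = tf.constant([True, False, True, True])

def step(total, i):
  proposed = total + tf.gather(values, i)
  # Keep the proposed state only where the flag is set; otherwise carry forward.
  return tf.where(tf.gather(should_add, i), proposed, total)

total = tf.foldl(step, elems=tf.range(4), initializer=tf.constant(0.))
# total == 8.  (1. + 3. + 4.)
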
Example #5
    def _log_prob(self, value):
        # The argument `value` is a tensor of sequences of observations.
        # `observation_batch_shape` is the shape of that tensor with the
        # sequence part removed.
        # `observation_batch_shape` is then broadcast to the full batch shape
        # to give the `batch_shape` that defines the shape of the result.

        observation_tensor_shape = tf.shape(value)
        observation_distribution = self.observation_distribution
        underlying_event_rank = tf.size(
            observation_distribution.event_shape_tensor())
        observation_batch_shape = observation_tensor_shape[
            :-1 - underlying_event_rank]
        # value :: observation_batch_shape num_steps observation_event_shape
        batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())
        num_states = self.transition_distribution.batch_shape_tensor()[-1]
        log_init = _extract_log_probs(num_states, self.initial_distribution)
        # log_init :: batch_shape num_states
        log_init = tf.broadcast_to(
            log_init, tf.concat([batch_shape, [num_states]], axis=0))
        log_transition = _extract_log_probs(num_states,
                                            self.transition_distribution)

        # `observation_event_shape` is the shape of each sequence of observations
        # emitted by the model.
        observation_event_shape = observation_tensor_shape[
            -1 - underlying_event_rank:]
        working_obs = tf.broadcast_to(
            value, tf.concat([batch_shape, observation_event_shape], axis=0))
        # working_obs :: batch_shape observation_event_shape
        r = underlying_event_rank

        # Move index into sequence of observations to front so we can apply
        # tf.foldl
        working_obs = distribution_util.move_dimension(working_obs, -1 - r, 0)
        # working_obs :: num_steps batch_shape underlying_event_shape
        working_obs = tf.expand_dims(working_obs, -1 - r)
        # working_obs :: num_steps batch_shape 1 underlying_event_shape

        observation_probs = observation_distribution.log_prob(working_obs)

        # observation_probs :: num_steps batch_shape num_states

        def forward_step(log_prev_step, log_prob_observation):
            return _log_vector_matrix(log_prev_step,
                                      log_transition) + log_prob_observation

        fwd_prob = tf.foldl(forward_step,
                            observation_probs,
                            initializer=log_init)
        # fwd_prob :: batch_shape num_states

        log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
        # log_prob :: batch_shape

        return log_prob
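
`_log_vector_matrix` is not shown in this snippet; judging from how it is used, it is presumably a log-space vector-matrix product over the state dimension, along the lines of the sketch below.

def _log_vector_matrix(vs, ms):
  """Log-space `vs @ ms`: logsumexp over the shared state index."""
  # vs :: [..., num_states], ms :: [..., num_states, num_states]
  return tf.reduce_logsumexp(vs[..., tf.newaxis] + ms, axis=-2)
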
Example #6
    def _log_prob(self, value):
        with tf.control_dependencies(self._runtime_assertions):
            # The argument `value` is a tensor of sequences of observations.
            # `observation_batch_shape` is the shape of that tensor with the
            # sequence part removed.
            # `observation_batch_shape` is then broadcast to the full batch shape
            # to give the `batch_shape` that defines the shape of the result.

            observation_tensor_shape = tf.shape(input=value)
            observation_batch_shape = observation_tensor_shape[
                :-1 - self._underlying_event_rank]
            # value :: observation_batch_shape num_steps observation_event_shape
            batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                     self.batch_shape_tensor())
            log_init = tf.broadcast_to(
                self._log_init,
                tf.concat([batch_shape, [self._num_states]], axis=0))
            # log_init :: batch_shape num_states
            log_transition = self._log_trans

            # `observation_event_shape` is the shape of each sequence of observations
            # emitted by the model.
            observation_event_shape = observation_tensor_shape[
                -1 - self._underlying_event_rank:]
            working_obs = tf.broadcast_to(
                value, tf.concat([batch_shape, observation_event_shape],
                                 axis=0))
            # working_obs :: batch_shape observation_event_shape
            r = self._underlying_event_rank

            # Move index into sequence of observations to front so we can apply
            # tf.foldl
            working_obs = distribution_util.move_dimension(
                working_obs, -1 - r, 0)[..., tf.newaxis]
            # working_obs :: num_steps batch_shape underlying_event_shape
            observation_probs = (
                self._observation_distribution.log_prob(working_obs))

            def forward_step(log_prev_step, log_prob_observation):
                return _log_vector_matrix(
                    log_prev_step, log_transition) + log_prob_observation

            fwd_prob = tf.foldl(forward_step,
                                observation_probs,
                                initializer=log_init)
            # fwd_prob :: batch_shape num_states

            log_prob = tf.reduce_logsumexp(input_tensor=fwd_prob, axis=-1)
            # log_prob :: batch_shape

            return log_prob
Example #7
    def _sample_n(self, n, seed=None):
        distribution0 = self._get_distribution0()

        if self._num_steps is not None:
            num_steps = tf.convert_to_tensor(self._num_steps)
            num_steps_static = tf.get_static_value(num_steps)
        else:
            num_steps_static = tensorshape_util.num_elements(
                distribution0.event_shape)
            if num_steps_static is None:
                num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

        stateless_seed = samplers.sanitize_seed(seed, salt='Autoregressive')
        stateful_seed = None
        try:
            samples = distribution0.sample(n, seed=stateless_seed)
            is_stateful_sampler = False
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            msg = (
                'Falling back to stateful sampling for `distribution_fn(sample0)` of '
                'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                'This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(type(distribution0), str(e)))
            stateful_seed = SeedStream(seed, salt='Autoregressive')()
            samples = distribution0.sample(n, seed=stateful_seed)
            is_stateful_sampler = True

        seed = stateful_seed if is_stateful_sampler else stateless_seed

        if num_steps_static is not None:
            for _ in range(num_steps_static):
                # pylint: disable=not-callable
                samples = self.distribution_fn(samples).sample(seed=seed)
        else:
            # pylint: disable=not-callable
            samples = tf.foldl(
                lambda s, _: self.distribution_fn(s).sample(seed=seed),
                elems=tf.range(0, num_steps),
                initializer=samples)
        return samples
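
The stateless-first, stateful-fallback logic above boils down to a small pattern; the sketch below is illustrative only and assumes `tfp.random.sanitize_seed` and `tfp.util.SeedStream` are available.

import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Normal(0., 1.)  # Stand-in for `distribution_fn(samples)`.
seed = tfp.random.sanitize_seed(42, salt='Autoregressive')
try:
  x = dist.sample(seed=seed)  # Prefer a stateless (Tensor) seed.
except TypeError:
  # Legacy samplers only accept Python-int seeds; fall back to SeedStream.
  x = dist.sample(seed=tfp.util.SeedStream(42, salt='Autoregressive')())
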
  def _log_prob(self, value):
    # The argument `value` is a tensor of sequences of observations.
    # `observation_batch_shape` is the shape of that tensor with the
    # sequence part removed.
    # `observation_batch_shape` is then broadcast to the full batch shape
    # to give the `batch_shape` that defines the shape of the result.
    observation_tensor_shape = ps.shape(value)
    observation_distribution = self.observation_distribution
    underlying_event_rank = ps.size(
        observation_distribution.event_shape_tensor())
    observation_batch_shape = observation_tensor_shape[
        :-1 - underlying_event_rank]
    # value :: observation_batch_shape num_steps observation_event_shape
    batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                             self.batch_shape_tensor())
    num_states = self.transition_distribution.batch_shape_tensor()[-1]
    log_init = _extract_log_probs(num_states,
                                  self.initial_distribution)
    # log_init :: batch_shape num_states
    log_init = tf.broadcast_to(log_init,
                               ps.concat([batch_shape,
                                          [num_states]], axis=0))
    log_transition = _extract_log_probs(num_states,
                                        self.transition_distribution)

    # `observation_event_shape` is the shape of each sequence of observations
    # emitted by the model.
    observation_event_shape = observation_tensor_shape[
        -1 - underlying_event_rank:]
    working_obs = tf.broadcast_to(value,
                                  ps.concat([batch_shape,
                                             observation_event_shape],
                                            axis=0))
    # working_obs :: batch_shape observation_event_shape
    r = underlying_event_rank

    # Move index into sequence of observations to front so we can apply
    # tf.foldl
    if self._time_varying_observation_distribution:
      working_obs = tf.expand_dims(working_obs, -1 - r)
      # working_obs :: batch_shape num_steps 1 underlying_event_shape
      observation_probs = observation_distribution.log_prob(working_obs)
      # observation_probs :: batch_shape num_steps num_states
      observation_probs = distribution_util.move_dimension(
          observation_probs, -2, 0)
      # observation_probs :: num_steps batch_shape num_states
    else:
      working_obs = distribution_util.move_dimension(working_obs, -1 - r, 0)
      # working_obs :: num_steps batch_shape underlying_event_shape
      working_obs = tf.expand_dims(working_obs, -1 - r)
      # working_obs :: num_steps batch_shape 1 underlying_event_shape

      observation_probs = observation_distribution.log_prob(working_obs)
      # observation_probs :: num_steps batch_shape num_states

    def forward_step(log_prev_step, log_prob_observation):
      return _log_vector_matrix(log_prev_step,
                                log_transition) + log_prob_observation

    # TODO(davmre): Delete this warning after Dec 31, 2020.
    warnings.warn(
        'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
        'in which the transition model was applied prior to the initial step. '
        'This bug has been fixed. You may observe a slight change in behavior.')
    fwd_prob = tf.foldl(forward_step, observation_probs[1:],
                        initializer=log_init + observation_probs[0])
    # fwd_prob :: batch_shape num_states

    log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
    # log_prob :: batch_shape

    return log_prob
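
For context (not from the source), the forward algorithm above is what ultimately backs `tfd.HiddenMarkovModel.log_prob`; a minimal usage sketch with two latent states:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 3.], scale=[0.5, 0.5]),
    num_steps=5)

# Log-likelihood of one observed sequence of length 5 (scalar result).
lp = hmm.log_prob(tf.zeros([5]))
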