Example #1
def _random_gamma_bwd(shape, log_space, aux, g):
    """The gradient of the gamma samples."""
    samples, concentration, rate, log_rate = aux
    dsamples, dimpl = g
    # Ignore any gradient contributions that come from the implementation enum.
    del dimpl
    partial_concentration, partial_rate, partial_log_rate = _compute_partials(
        samples, concentration, rate, log_rate, log_space)

    # These will need to be shifted by the extra dimensions added from
    # `sample_shape`.
    rate_shape = _shape_or_scalar(rate, log_rate)
    reduce_dims = tf.range(
        tf.size(shape) -
        tf.maximum(tf.rank(concentration), tf.size(rate_shape)))
    grad_concentration = tf.math.reduce_sum(dsamples * partial_concentration,
                                            axis=reduce_dims)
    grad_log_rate = None
    grad_rate = None
    if rate is not None:
        grad_rate = tf.math.reduce_sum(dsamples * partial_rate,
                                       axis=reduce_dims)
    elif log_rate is not None:
        grad_log_rate = tf.math.reduce_sum(dsamples * partial_log_rate,
                                           axis=reduce_dims)

    rate_tensorshape = _tensorshape_or_scalar(rate, log_rate)
    if (tensorshape_util.is_fully_defined(concentration.shape)
            and tensorshape_util.is_fully_defined(rate_tensorshape)
            and concentration.shape == rate_tensorshape):
        return grad_concentration, grad_rate, grad_log_rate, None  # seed=None

    ax_conc, ax_rate = tf.raw_ops.BroadcastGradientArgs(
        s0=tf.shape(concentration), s1=rate_shape)
    grad_concentration = tf.reshape(
        tf.math.reduce_sum(grad_concentration, axis=ax_conc),
        tf.shape(concentration))
    if grad_rate is not None:
        grad_rate = tf.reshape(tf.math.reduce_sum(grad_rate, axis=ax_rate),
                               rate_shape)
    if grad_log_rate is not None:
        grad_log_rate = tf.reshape(
            tf.math.reduce_sum(grad_log_rate, axis=ax_rate), rate_shape)

    return grad_concentration, grad_rate, grad_log_rate, None  # seed=None
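The final reshape relies on the reduction axes returned by `BroadcastGradientArgs`. A minimal standalone sketch of that bookkeeping, with made-up shapes (plain TensorFlow, not part of the original function):

```python
import tensorflow as tf

# Which axes must be summed to undo broadcasting of shapes [3, 1] and [4]?
ax0, ax1 = tf.raw_ops.BroadcastGradientArgs(s0=[3, 1], s1=[4])
print(ax0.numpy(), ax1.numpy())  # [1] [0]: reduce axis 1 for s0, axis 0 for s1
```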
Example #2
    def _parameter_control_dependencies(self, is_init):
        assertions = []
        sample_shape = None  # Memoize concretization.

        # Check valid shape.
        ndims_ = tensorshape_util.rank(self.sample_shape.shape)
        if is_init != (ndims_ is None):
            msg = 'Argument `sample_shape` must be either a scalar or a vector.'
            if ndims_ is not None:
                if ndims_ > 1:
                    raise ValueError(msg)
            elif self.validate_args:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                assertions.append(
                    assert_util.assert_less(tf.rank(sample_shape),
                                            2,
                                            message=msg))

        # Check valid dtype.
        if is_init:  # No xor check because `dtype` cannot change.
            dtype_ = self.sample_shape.dtype
            if dtype_ is None:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                dtype_ = sample_shape.dtype
            if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
                raise TypeError(
                    'Argument `sample_shape` must be integer type; '
                    'saw {}.'.format(dtype_util.name(dtype_)))

        # Check valid "value".
        if is_init != tensor_util.is_ref(self.sample_shape):
            sample_shape_ = tf.get_static_value(self.sample_shape)
            msg = 'Argument `sample_shape` must have non-negative values.'
            if sample_shape_ is not None:
                if np.any(np.array(sample_shape_) < 0):
                    raise ValueError('{} Saw: {}'.format(msg, sample_shape_))
            elif self.validate_args:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                assertions.append(
                    assert_util.assert_greater(sample_shape, -1, message=msg))

        return assertions
Example #3
 def __call__(self, cache, new_items):
     datawise_matches = []
     for key in self.keys:
         cache_vals = cache.data[key]
         new_items_vals = new_items[key]
         if cache_vals.dtype.is_floating:
             raise NotImplementedError(
                 'Floating datatypes are not yet implemented.')
         cache_vals = tf.expand_dims(cache_vals, axis=0)
         new_items_vals = tf.expand_dims(new_items_vals, axis=1)
         elementwise = cache_vals == new_items_vals
         datawise = tf.reduce_all(elementwise,
                                  axis=range(2, tf.rank(elementwise)))
         datawise_matches.append(datawise)
     all_keys_datawise = tf.stack(datawise_matches, axis=2)
     all_keys_match = tf.reduce_all(all_keys_datawise, axis=2)
     in_cache = tf.reduce_any(all_keys_match, axis=1)
     return tf.logical_not(in_cache)
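The `__call__` above filters out items already present in the cache by broadcasting each new item against every cached item. A minimal standalone sketch of that comparison trick with made-up integer data (plain TensorFlow; `list(...)` is used for the reduction axes so it also works with a static rank):

```python
import tensorflow as tf

cache_vals = tf.constant([[1, 2], [3, 4]])          # 2 cached items
new_items = tf.constant([[3, 4], [5, 6], [1, 2]])   # 3 candidate items
elementwise = tf.expand_dims(cache_vals, 0) == tf.expand_dims(new_items, 1)
datawise = tf.reduce_all(elementwise, axis=list(range(2, elementwise.shape.rank)))
in_cache = tf.reduce_any(datawise, axis=1)
print(tf.logical_not(in_cache).numpy())  # [False  True False]: only [5, 6] is new
```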
Example #4
 def _sample_shape(self, x):
   """Computes graph and static `sample_shape`."""
   x_ndims = (tf.rank(x) if x.shape.ndims is None else x.shape.ndims)
   event_ndims = (
       tf.size(input=self.event_shape_tensor())
       if self.event_shape.ndims is None else self.event_shape.ndims)
   batch_ndims = (
       tf.size(input=self._batch_shape_unexpanded)
       if self.batch_shape.ndims is None else self.batch_shape.ndims)
   sample_ndims = x_ndims - batch_ndims - event_ndims
   if isinstance(sample_ndims, int):
     static_sample_shape = x.shape[:sample_ndims]
   else:
     static_sample_shape = tf.TensorShape(None)
   if static_sample_shape.is_fully_defined():
     sample_shape = np.int32(static_sample_shape.as_list())
   else:
     sample_shape = tf.shape(input=x)[:sample_ndims]
   return sample_shape, static_sample_shape
Example #5
            def _do_update(x_update_diff_norm_sq, x_update,
                           hess_matmul_x_update):  # pylint: disable=missing-docstring
                hessian_column_with_l2 = sparse_or_dense_matvecmul(
                    hessian_unregularized_loss_outer,
                    hessian_unregularized_loss_middle *
                    _sparse_or_dense_matmul_onehot(
                        hessian_unregularized_loss_outer, coord),
                    adjoint_a=True)

                if l2_regularizer is not None:
                    hessian_column_with_l2 += _one_hot_like(
                        hessian_column_with_l2,
                        coord,
                        on_value=2. * l2_regularizer)

                # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
                # order to conform to `hess_matmul_x_update`.
                n = tf.rank(hessian_column_with_l2)
                perm = tf.roll(tf.range(n), shift=1, axis=0)
                hessian_column_with_l2 = tf.transpose(a=hessian_column_with_l2,
                                                      perm=perm)

                # Update the entire batch at `coord` even if `delta` may be 0 at some
                # batch coordinates. In those cases, adding `delta` is a no-op.
                x_update = tf.tensor_scatter_nd_add(x_update, [[coord]],
                                                    [delta])

                with tf.control_dependencies([x_update]):
                    x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
                    hess_matmul_x_update_ = (hess_matmul_x_update +
                                             delta * hessian_column_with_l2)

                    # Hint that loop vars retain the same shape.
                    x_update_diff_norm_sq_.set_shape(
                        x_update_diff_norm_sq_.shape.merge_with(
                            x_update_diff_norm_sq.shape))
                    hess_matmul_x_update_.set_shape(
                        hess_matmul_x_update_.shape.merge_with(
                            hess_matmul_x_update.shape))

                    return [
                        x_update_diff_norm_sq_, x_update, hess_matmul_x_update_
                    ]
Example #6
def _expand_right(a, n, pos):
  """Insert multiple dimensions of size 1 at position `pos` in tensor's shape.

  Equivalent to performing `expand_dims(..., pos)` `n` times.

  Args:
    a: tensor into which extra dimensions will be inserted.
    n: number of inserted dimensions.
    pos: choice of dimension for insertion. Must be negative.

  Returns:
    Tensor with inserted dimensions.
  """

  axis = tf.rank(a) + pos + 1
  return tf.reshape(a, tf.concat([
      tf.shape(a)[:axis],
      tf.ones([n], dtype=tf.int32),
      tf.shape(a)[axis:]], axis=0))
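A minimal usage sketch of `_expand_right` with made-up shapes (plain TensorFlow):

```python
import tensorflow as tf

a = tf.zeros([2, 3])
print(_expand_right(a, n=2, pos=-1).shape)  # (2, 3, 1, 1): two trailing unit dims
print(_expand_right(a, n=1, pos=-2).shape)  # (2, 1, 3): unit dim before the last axis
```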
Example #7
def _maybe_check_valid_shape(shape, validate_args):
    """Check that a shape Tensor is int-type and otherwise sane."""
    if not dtype_util.is_integer(shape.dtype):
        raise TypeError('`{}` dtype (`{}`) should be `int`-like.'.format(
            shape, dtype_util.name(shape.dtype)))

    assertions = []

    message = '`{}` rank should be <= 1.'
    if tensorshape_util.rank(shape.shape) is not None:
        if tensorshape_util.rank(shape.shape) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.rank(shape),
                                    2,
                                    message=message.format(shape)))

    shape_ = tf.get_static_value(shape)

    message = '`{}` elements must have at most one `-1`.'
    if shape_ is not None:
        if sum(shape_ == -1) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.reduce_sum(
                tf.cast(tf.equal(shape, -1), tf.int32)),
                                    2,
                                    message=message.format(shape)))

    message = '`{}` elements must be either positive integers or `-1`.'
    if shape_ is not None:
        if np.any(shape_ < -1):
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_greater(shape,
                                       -2,
                                       message=message.format(shape)))

    return assertions
Example #8
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
    a = array_ops.asarray(a).data

    if offset == 0:
        a_shape = a.shape
        if a_shape.rank is not None:
            rank = len(a_shape)
            if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1
                                                       or axis2 == rank - 1):
                return utils.tensor_to_ndarray(tf.linalg.trace(a))

    a_rank = tf.rank(a)
    if axis1 < 0:
        axis1 += a_rank
    if axis2 < 0:
        axis2 += a_rank

    minaxis = tf.minimum(axis1, axis2)
    maxaxis = tf.maximum(axis1, axis2)

    # Move axes of interest to the end.
    range_rank = tf.range(a_rank)
    perm = tf.concat([
        range_rank[0:minaxis], range_rank[minaxis + 1:maxaxis],
        range_rank[maxaxis + 1:], [axis1, axis2]
    ],
                     axis=0)
    a = tf.transpose(a, perm)

    a_shape = tf.shape(a)

    # All zeros since diag_part doesn't handle all possible k (aka offset).
    # Written this way since cond will run shape inference on both branches,
    # and diag_part shape inference will fail when offset is out of bounds.
    a, offset = utils.cond(
        utils.logical_or(
            utils.less_equal(offset, -1 * utils.getitem(a_shape, -2)),
            utils.greater_equal(offset, utils.getitem(a_shape, -1)),
        ), lambda: (tf.zeros_like(a), 0), lambda: (a, offset))

    a = utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset))
    return array_ops.sum(a, -1, dtype)
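A small plain-TensorFlow sketch of the core idea, with made-up data: once the two trace axes are the trailing axes, `tf.linalg.diag_part` with `k=offset` plus a sum over the last axis reproduces `np.trace`:

```python
import numpy as np
import tensorflow as tf

a = np.arange(24.).reshape(2, 3, 4)
diag = tf.linalg.diag_part(tf.constant(a), k=1)  # superdiagonal of each 3x4 matrix
print(tf.reduce_sum(diag, axis=-1).numpy())      # [18. 54.], same as np.trace(a, 1, 1, 2)
```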
Example #9
 def _sample_shape(self, x, event_shape, event_shape_tensor):
     """Computes graph and static `sample_shape`."""
     x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None else
                tensorshape_util.rank(x.shape))
     event_ndims = (tf.size(event_shape_tensor)
                    if tensorshape_util.rank(event_shape) is None else
                    tensorshape_util.rank(event_shape))
     batch_ndims = (tf.size(self._batch_shape_unexpanded)
                    if tensorshape_util.rank(self.batch_shape) is None else
                    tensorshape_util.rank(self.batch_shape))
     sample_ndims = x_ndims - batch_ndims - event_ndims
     if isinstance(sample_ndims, int):
         static_sample_shape = x.shape[:sample_ndims]
     else:
         static_sample_shape = tf.TensorShape(None)
     if tensorshape_util.is_fully_defined(static_sample_shape):
         sample_shape = np.int32(static_sample_shape)
     else:
         sample_shape = tf.shape(x)[:sample_ndims]
     return sample_shape, static_sample_shape
Example #10
    def train_loss(self, env_step, rewards, next_env_step, policy, gamma):
        values = self._get_value(env_step)
        discounts = gamma * next_env_step.discount
        target_values = self._get_target_value(next_env_step, policy)
        #target_values = tf.reduce_min(target_values, axis=-1, keepdims=True)

        if self._num_qvalues is not None and tf.rank(discounts) == 1:
            discounts = discounts[:, None]
        td_targets = rewards + discounts * tf.stop_gradient(target_values)

        policy_ratio = 1.0
        if not self._solve_for_state_action_value:
            tfagents_step = dataset_lib.convert_to_tfagents_timestep(env_step)
            policy_log_probabilities = policy.distribution(
                tfagents_step).action.log_prob(env_step.action)
            policy_ratio = tf.exp(policy_log_probabilities -
                                  env_step.get_log_probability())

        td_errors = policy_ratio * td_targets - values
        return tf.square(td_errors)
Example #11
 def test_shape_changing_bijector(self):
     num_tril_nonzero = lambda num_rows: num_rows * (num_rows + 1) // 2
     num_tril_rows = lambda nnz: (  # pylint: disable=g-long-lambda
         np.sqrt(0.25 + 2. * nnz) - 0.5).astype(np.int32)
     pad_eye = tfb.Inline(
         forward_fn=lambda x: tf.concat(
             [  # pylint: disable=g-long-lambda
                 tfb.FillScaleTriL().inverse(
                     tf.eye(num_tril_rows(
                         tf.compat.dimension_value(x.shape[-1])),
                             batch_shape=tf.shape(x)[:-2]))[..., tf.newaxis, :],
                 x,
             ],
             axis=tf.rank(x) - 2),
         inverse_fn=lambda y: y[..., 1:, :],
         inverse_log_det_jacobian_fn=lambda y, event_ndims: 0.,
         forward_event_shape_fn=lambda in_shape: in_shape + tf.one_hot(  # pylint: disable=g-long-lambda
             tf.size(in_shape) - 2,
             depth=tf.size(in_shape),
             dtype=tf.int32),
         inverse_event_shape_fn=lambda out_shape: out_shape - tf.one_hot(  # pylint: disable=g-long-lambda
             tf.size(out_shape) - 2,
             depth=tf.size(out_shape),
             dtype=tf.int32),
         forward_min_event_ndims=2,
         inverse_min_event_ndims=2,
         is_constant_jacobian=True,
         name='PadEyeBijector')
     scale_tril = tfp.util.TransformedVariable(
         tf.eye(3, batch_shape=[5, 1, 4]),
         bijector=tfb.Chain([tfb.FillScaleTriL(), pad_eye]))
     self.assertAllEqual((5, 1, 4, 3, 3), scale_tril.shape)
     self.assertAllEqual((5, 1, 4 - 1, num_tril_nonzero(3)),
                         scale_tril.pretransformed_input.shape)
     self.evaluate([v.initializer for v in scale_tril.trainable_variables])
     shape_, scale_tril_ = self.evaluate(
         [tf.shape(scale_tril),
          tf.convert_to_tensor(scale_tril)])
     self.assertAllEqual((5, 1, 4, 3, 3), shape_)
     self.assertAllEqual((5, 1, 4, 3, 3), scale_tril_.shape)
Example #12
    def _forward_log_det_jacobian(self, x, **kwargs):
        x = tf.convert_to_tensor(value=x, name="x")

        fldj = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype))

        if not self.bijectors:
            return fldj

        event_ndims = self._maybe_get_static_event_ndims(
            self.forward_min_event_ndims)

        if _use_static_shape(x, event_ndims):
            event_shape = x.shape[tensorshape_util.rank(x.shape) -
                                  event_ndims:]
        else:
            event_shape = tf.shape(input=x)[tf.rank(x) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for b in reversed(self.bijectors):
            fldj += b.forward_log_det_jacobian(x,
                                               event_ndims=event_ndims,
                                               **kwargs.get(b.name, {}))
            if _use_static_shape(x, event_ndims):
                event_shape = b.forward_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = b.forward_event_shape_tensor(event_shape)
                event_shape_ = distribution_util.maybe_get_static_value(
                    event_shape)
                event_ndims = tf.size(input=event_shape)
                event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

                if event_ndims_ is not None and event_shape_ is not None:
                    event_ndims = event_ndims_
                    event_shape = event_shape_

            x = b.forward(x, **kwargs.get(b.name, {}))

        return fldj
Example #13
    def policy_fn(observation,
                  probability_table=probability_table,
                  obs_to_index_fn=obs_to_index_fn,
                  return_distribution=return_distribution,
                  dtype=tf.int32):
        state = obs_to_index_fn(observation)
        distribution = tf.gather(probability_table, state)
        batched = tf.rank(distribution) > 1
        if not batched:
            distributions = distribution[None, :]
        else:
            distributions = distribution

        batch_size = tf.shape(distributions)[0]

        actions = tf.random.categorical(tf.math.log(1e-8 + distributions),
                                        1,
                                        dtype=dtype)
        actions = tf.squeeze(actions, -1)
        probs = tf.gather_nd(
            distributions,
            tf.stack([tf.range(batch_size, dtype=dtype), actions], -1))

        if not batched:
            action = actions[0]
            log_prob = tf.math.log(1e-8 + probs[0])
        else:
            action = actions
            log_prob = tf.math.log(1e-8 + probs)

        if return_distribution:
            policy_info = {'distribution': distribution}
            return (tfp.distributions.Categorical(probs=distribution,
                                                  dtype=dtype), policy_info)
        else:
            policy_info = {
                'log_probability': log_prob,
                'distribution': distribution
            }
            return action, policy_info
Example #14
    def _inverse_log_det_jacobian(self, y, **kwargs):
        y = tf.convert_to_tensor(value=y, name="y")
        ildj = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype))

        if not self.bijectors:
            return ildj

        event_ndims = self._maybe_get_static_event_ndims(
            self.inverse_min_event_ndims)

        if _use_static_shape(y, event_ndims):
            event_shape = y.shape[tensorshape_util.rank(y.shape) -
                                  event_ndims:]
        else:
            event_shape = tf.shape(input=y)[tf.rank(y) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for b in self.bijectors:
            ildj += b.inverse_log_det_jacobian(y,
                                               event_ndims=event_ndims,
                                               **kwargs.get(b.name, {}))

            if _use_static_shape(y, event_ndims):
                event_shape = b.inverse_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = b.inverse_event_shape_tensor(event_shape)
                event_shape_ = distribution_util.maybe_get_static_value(
                    event_shape)
                event_ndims = tf.size(input=event_shape)
                event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

                if event_ndims_ is not None and event_shape_ is not None:
                    event_ndims = event_ndims_
                    event_shape = event_shape_

            y = b.inverse(y, **kwargs.get(b.name, {}))
        return ildj
Example #15
def outer_multiply(x, y):
    """Performs an outer multiplication of two tensors.

  Given two `Tensor`s, `S` and `T` of shape `s` and `t` respectively, the outer
  product `P` is a `Tensor` of shape `s + t` whose components are given by:

  ```none
  P_{i1,...ik, j1, ... , jm} = S_{i1...ik} T_{j1, ... jm}
  ```

  Args:
    x: A `Tensor` of any shape and numeric dtype.
    y: A `Tensor` of any shape and the same dtype as `x`.

  Returns:
    outer_product: A `Tensor` of shape Shape[x] + Shape[y] and the same dtype
      as `x`.
  """
    x_shape = tf.shape(x)
    padded_shape = tf.concat(
        [x_shape, tf.ones(tf.rank(y), dtype=x_shape.dtype)], axis=0)
    return tf.reshape(x, padded_shape) * y
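A minimal usage sketch of `outer_multiply` with made-up values (plain TensorFlow):

```python
import tensorflow as tf

x = tf.constant([1., 2.])
y = tf.constant([10., 20., 30.])
print(outer_multiply(x, y).numpy())  # shape (2, 3) == x.shape + y.shape
# [[10. 20. 30.]
#  [20. 40. 60.]]
```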
Example #16
File: math.py, Project: jaeyounkim/trax
def argsort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  # TODO(nareshmodi): make string tensors also work.
  if kind not in ('quicksort', 'stable'):
    raise ValueError("Only 'quicksort' and 'stable' arguments are supported.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")
  stable = (kind == 'stable')

  a = array_creation.asarray(a).data

  def _argsort(a, axis, stable):
    if axis is None:
      a = tf.reshape(a, [-1])
      axis = 0

    return tf.argsort(a, axis, stable=stable)

  tf_ans = tf.cond(
      tf.rank(a) == 0, lambda: tf.constant([0]),
      lambda: _argsort(a, axis, stable))

  return array_creation.asarray(tf_ans, dtype=np.intp)
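A standalone sketch of the scalar guard above using plain `tf.argsort` (made-up inputs, not the trax implementation): rank-0 inputs return `[0]`, matching `np.argsort`:

```python
import tensorflow as tf

def simple_argsort(a, axis=-1, stable=False):
  a = tf.convert_to_tensor(a)
  if a.shape.rank == 0:  # static analogue of the `tf.rank(a) == 0` cond above
    return tf.constant([0])
  return tf.argsort(a, axis=axis, stable=stable)

print(simple_argsort(5.).numpy())            # [0]
print(simple_argsort([3., 1., 2.]).numpy())  # [1 2 0]
```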
Example #17
    def test_poisson_switchover_graphical_model(self):
        # Build a pretend dataset.
        seed = test_util.test_seed_stream(salt='poisson')
        n = [43, 31]
        count_data = tf.cast(tf.concat([
            tfd.Poisson(rate=15.).sample(n[0], seed=seed()),
            tfd.Poisson(rate=25.).sample(n[1], seed=seed()),
        ],
                                       axis=0),
                             dtype=tf.float32)
        count_data = self.evaluate(count_data)
        n = np.sum(n)

        # Make model.
        gather = lambda tau, lambda_: tf.gather(  # pylint: disable=g-long-lambda
            lambda_,
            indices=tf.cast(tau[..., tf.newaxis] < tf.linspace(0., 1., n),
                            dtype=tf.int32),
            # TODO(b/139204153): Remove static value hack after bug closed.
            batch_dims=int(tf.get_static_value(tf.rank(tau))))

        alpha = tf.math.reciprocal(tf.reduce_mean(count_data))

        joint = tfd.JointDistributionSequential(
            [
                tfd.Sample(tfd.Exponential(rate=alpha), sample_shape=[2]),
                tfd.Uniform(),
                lambda tau, lambda_: tfd.Independent(  # pylint: disable=g-long-lambda
                    tfd.Poisson(rate=gather(tau, lambda_)),
                    reinterpreted_batch_ndims=1),
            ],
            validate_args=True)

        # Verify model correctly "compiles".
        batch_shape = [3, 4]
        self.assertEqual(
            batch_shape,
            joint.log_prob(
                joint.sample(batch_shape, seed=test_util.test_seed())).shape)
Example #18
def _sample_bates(total_count, low, high, n, seed=None):
    """Vectorized production of `Bates` samples.

  Args:
    total_count: (Batches of) counts of `Uniform`s to take means of.  Should
      have integer dtype and already be broadcasted to the batch shape.
    low: (Batches of) lower bounds of the `Uniform` variables to sample.  Should
      be the same floating dtype as `high` and broadcastable to the batch shape.
    high: (Batches of) upper bounds of the `Uniform` variables to sample. Should
      be the same floating dtype as `low` and broadcastable to the batch shape.
    n: `int32` number of samples to generate.
    seed: Random seed to pass to `Uniform` sampler.

  Returns:
    samples: Samples of (batches of) the `Bates` variable.  Will have same dtype
      as `low` and `high`. If the batch shape is `[B1,..., Bn]`, `samples` has
      shape `[n, B1,..., Bn]`.
  """

    # 1. Sample Uniform(0, 1)s, flattening the batch dimension into axis 0.
    uniform_sample_shape = ps.concat([[ps.reduce_sum(total_count)], [n]],
                                     axis=0)
    uniform_samples = samplers.uniform(uniform_sample_shape,
                                       minval=0.,
                                       maxval=1.,
                                       dtype=low.dtype,
                                       seed=seed)
    # 2. Produce segment means.
    segment_lengths = tf.reshape(total_count, [-1])
    segment_ids = tf.repeat(tf.range(tf.size(segment_lengths)),
                            segment_lengths)
    flatmeans = tf.math.segment_mean(uniform_samples, segment_ids)
    # 3. Reshape and transpose segment means back to the original shape.
    outshape = tf.concat([tf.shape(total_count), [n]], axis=0)
    tmeans = tf.reshape(flatmeans, outshape)
    axes = tf.range(tf.rank(tmeans))
    means = tf.transpose(tmeans, tf.roll(axes, shift=1, axis=0))
    # 4. Shift/scale from (0, 1) to (low, high).
    return low + (high - low) * means
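A small plain-TensorFlow sketch of steps 2-3 above with made-up counts: per-segment means of the flattened uniforms, then the sample axis moved to the front:

```python
import tensorflow as tf

# total_count = [2, 1]: the first two rows belong to batch member 0, the last to 1.
uniform_samples = tf.constant([[1., 2.], [3., 4.], [5., 6.]])   # (sum(counts), n=2)
segment_ids = tf.constant([0, 0, 1])
flatmeans = tf.math.segment_mean(uniform_samples, segment_ids)  # shape (batch, n)
print(tf.transpose(flatmeans).numpy())                          # shape (n, batch)
```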
Example #19
def assert_rank_at_most(x,
                        rank,
                        data=None,
                        summarize=None,
                        message=None,
                        name=None):
    """Assert `x` has rank equal to `rank` or smaller.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_rank_at_most(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor`.
    rank:  Scalar `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_most".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or lower.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
    with tf.name_scope(name or 'assert_rank_at_most'):
        return tf1.assert_less_equal(tf.rank(x),
                                     rank,
                                     data=data,
                                     summarize=summarize,
                                     message=message)
Example #20
            def sample_next_angle(seed, angle, angle_min, angle_max,
                                  current_state_parts, current_log_likelihood):
                """Slice sample a new angle, and rotate init_state by that amount."""
                angle_seed, next_seed = samplers.split_seed(seed)
                chain_not_done = current_log_likelihood < threshold
                # Box in on angle. Only update angles for which we haven't generated a
                # point that beats the threshold.
                angle_min = tf.where((angle < 0) & chain_not_done, angle,
                                     angle_min)
                angle_max = tf.where((angle >= 0) & chain_not_done, angle,
                                     angle_max)
                new_angle = samplers.uniform(
                    shape=tf.shape(current_log_likelihood),
                    minval=angle_min,
                    maxval=angle_max,
                    seed=angle_seed,
                    dtype=angle.dtype.base_dtype)
                angle = tf.where(chain_not_done, new_angle, angle)
                next_state_parts = _rotate_on_ellipse(init_state_parts,
                                                      normal_samples, angle)

                new_state_parts = []
                broadcasted_chain_not_done = _right_pad_with_ones(
                    chain_not_done, tf.rank(next_state_parts[0]))
                for n_state, c_state in zip(next_state_parts,
                                            current_state_parts):
                    new_state_part = tf.where(broadcasted_chain_not_done,
                                              n_state, c_state)
                    new_state_parts.append(new_state_part)

                return (
                    next_seed,
                    angle,
                    angle_min,
                    angle_max,
                    new_state_parts,
                    self.log_likelihood_fn(*new_state_parts)  # pylint: disable=not-callable
                )
Example #21
    def reward_fn(env_step, valid_steps, qvalues=self._point_qvalues):
      """Computes average initial Q-values of episodes."""
      # env_step is an episode, and we just want the first step.
      if tf.rank(valid_steps) == 1:
        first_step = tf.nest.map_structure(lambda t: t[0, ...], env_step)
      else:
        first_step = tf.nest.map_structure(lambda t: t[:, 0, ...], env_step)

      if self._solve_for_state_action_value:
        indices = self._get_index(first_step.observation[:, None],
                                  np.arange(self._num_actions)[None, :])
        initial_qvalues = tf.cast(tf.gather(qvalues, indices), tf.float32)

        tfagents_first_step = dataset_lib.convert_to_tfagents_timestep(
            first_step)
        initial_target_probs = target_policy.distribution(
            tfagents_first_step).action.probs_parameter()
        value = tf.reduce_sum(initial_qvalues * initial_target_probs, axis=-1)
      else:
        indices = self._get_index(first_step.observation, first_step.action)
        value = tf.cast(tf.gather(qvalues, indices), tf.float32)

      return value
Example #22
  def _maybe_warn_increased_dof(self,
                                component_name,
                                component_ldj,
                                increased_dof):
    """Warns or raises when `increased_dof` is True."""
    # Short-circuit when the component LDJ is statically zero.
    if (tf.get_static_value(tf.rank(component_ldj)) == 0
        and tf.get_static_value(component_ldj) == 0):
      return

    # Short-circuit when increased_dof is statically False.
    increased_dof_ = tf.get_static_value(increased_dof)
    if increased_dof_ is False:  # pylint: disable=g-bool-id-comparison
      return

    error_message = (
        'Nested component "{}" in composition "{}" operates on inputs '
        'with increased degrees of freedom. This may result in an '
        'incorrect log_det_jacobian.'
        ).format(component_name, self.name)

    # When validate_args is True, we raise on increased DoF.
    if self._validate_args:
      if increased_dof_:
        raise ValueError(error_message)
      return assert_util.assert_equal(False, increased_dof, error_message)

    if (not tf.executing_eagerly() and
        control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph())):
      return  # No StringFormat or Print ops in XLA.

    # Otherwise, we print a warning and continue.
    return ps.cond(
        pred=increased_dof,
        false_fn=tf.no_op,
        true_fn=lambda: tf.print(  # pylint: disable=g-long-lambda
            'WARNING: ' + error_message, output_stream=sys.stderr))
Example #23
    def _integrator_conserves_energy(self, x, independent_chain_ndims, seed):
        event_dims = tf.range(independent_chain_ndims, tf.rank(x))

        target_fn = lambda x: self._log_gamma_log_prob(x, event_dims)

        m = tf.random.normal(tf.shape(x), seed=seed)
        log_prob_0 = target_fn(x)
        old_energy = -log_prob_0 + 0.5 * tf.reduce_sum(m**2., axis=event_dims)

        event_size = np.prod(self.evaluate(x).shape[independent_chain_ndims:])

        integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
            target_fn, step_sizes=[0.09 / event_size], num_steps=1000)

        [[new_m], [_], log_prob_1, [_]] = integrator([m], [x])

        new_energy = -log_prob_1 + 0.5 * tf.reduce_sum(new_m**2.,
                                                       axis=event_dims)

        old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
        tf1.logging.vlog(
            1, 'average energy relative change: {}'.format(
                (1. - new_energy_ / old_energy_).mean()))
        self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
Example #24
def _backward_matmul_one_part(dcovx,
                              kernel_fn,
                              kernel_args,
                              x1,
                              x2,
                              x,
                              part_size,
                              part_index,
                              remainder_part_size=None):
    """Applies a single chunk of backprop split along the axis defined by `x1`."""
    # Assume `cov = kernel.matrix(x1, x2)` has shape (A,B), and `x` has shape
    # (B,C). Then `cov @ x` had shape (A,C), as will `dcovx`. Below, we refer to
    # the "part" size as P.
    dcovx_ax = tf.rank(dcovx) - 2
    dcovx_ax_selector = tf.equal(tf.range(tf.rank(dcovx)), dcovx_ax)
    begin_dcovx = tf.where(dcovx_ax_selector, part_size * part_index, 0)
    size_dcovx = tf.where(
        dcovx_ax_selector,
        part_size if remainder_part_size is None else remainder_part_size,
        tf.shape(dcovx))
    dcovxpart = _slice(dcovx, begin_dcovx, size_dcovx, dcovx_ax)  # PxC
    dcovpart = tf.matmul(dcovxpart, x, transpose_b=True)  # PxC @ (BxC).T = PxB
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        # Here begins the "recomputed" part of the gradient.
        tape.watch((x1, x2, kernel_args))
        kernel = kernel_fn(*kernel_args)
        x1_ax = tf.rank(x1) - kernel.feature_ndims - 1
        x1_ax_selector = tf.equal(tf.range(tf.rank(x1)), x1_ax)
        begin_x1 = tf.where(x1_ax_selector, part_size * part_index, 0)
        size_x1 = tf.where(
            x1_ax_selector,
            part_size if remainder_part_size is None else remainder_part_size,
            tf.shape(x1))
        covpart = kernel.matrix(_slice(x1, begin_x1, size_x1, x1_ax),
                                x2)  # PxB
    dx1part, dx2part, dkernel_args = tape.gradient(covpart,
                                                   (x1, x2, kernel_args),
                                                   output_gradients=dcovpart)
    dxpart = tf.matmul(covpart, dcovxpart, transpose_a=True)  # (PxB).T @ PxC
    dxpart = tf.reduce_sum(dxpart, axis=tf.range(tf.rank(dxpart) - tf.rank(x)))
    return dx1part, dx2part, dxpart, dkernel_args
Example #25
def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
    """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred`.
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed,
    `sample_weight` could be extended by one dimension.
    If `sample_weight` is None, (y_pred, y_true) is returned.
  """
    y_pred_shape = y_pred.shape
    y_pred_rank = y_pred_shape.ndims
    if y_true is not None:

        # If sparse matrix is provided as `y_true`, the last dimension in `y_pred`
        # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
        # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
        # In this case, we should not try to remove squeezable dimension.
        y_true_shape = y_true.shape
        y_true_rank = y_true_shape.ndims
        if (y_true_rank is not None) and (y_pred_rank is not None):
            # Use static rank for `y_true` and `y_pred`.
            if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
                y_true, y_pred = remove_squeezable_dimensions(y_true, y_pred)
        else:
            # Use dynamic rank.
            rank_diff = tf.rank(y_pred) - tf.rank(y_true)
            squeeze_dims = lambda: remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
                y_true, y_pred)
            is_last_dim_1 = tf.equal(1, tf.shape(y_pred)[-1])
            maybe_squeeze_dims = lambda: tf.cond(  # pylint: disable=g-long-lambda
                is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
            y_true, y_pred = tf.cond(tf.equal(1, rank_diff),
                                     maybe_squeeze_dims, squeeze_dims)

    if sample_weight is None:
        return y_pred, y_true

    weights_shape = sample_weight.shape
    weights_rank = weights_shape.ndims
    if weights_rank == 0:  # If weights is scalar, do nothing.
        return y_pred, y_true, sample_weight

    if (y_pred_rank is not None) and (weights_rank is not None):
        # Use static rank.
        if weights_rank - y_pred_rank == 1:
            sample_weight = tf.squeeze(sample_weight, [-1])
        elif y_pred_rank - weights_rank == 1:
            sample_weight = tf.expand_dims(sample_weight, [-1])
        return y_pred, y_true, sample_weight

    # Use dynamic rank.
    weights_rank_tensor = tf.rank(sample_weight)
    rank_diff = weights_rank_tensor - tf.rank(y_pred)
    maybe_squeeze_weights = lambda: tf.squeeze(sample_weight, [-1])

    def _maybe_expand_weights():
        expand_weights = lambda: tf.expand_dims(sample_weight, [-1])
        return tf.cond(tf.equal(rank_diff, -1), expand_weights,
                       lambda: sample_weight)

    def _maybe_adjust_weights():
        return tf.cond(tf.equal(rank_diff, 1), maybe_squeeze_weights,
                       _maybe_expand_weights)

    # squeeze or expand last dim of `sample_weight` if its rank differs by 1
    # from the new rank of `y_pred`.
    sample_weight = tf.cond(tf.equal(weights_rank_tensor, 0),
                            lambda: sample_weight, _maybe_adjust_weights)
    return y_pred, y_true, sample_weight
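A minimal usage sketch of the weight-handling path (assumes the function above and its `remove_squeezable_dimensions` dependency are in scope; shapes are made up):

```python
import tensorflow as tf

y_pred = tf.zeros([4, 1])
weights = tf.ones([4])
y_pred_out, _, weights_out = squeeze_or_expand_dimensions(y_pred, sample_weight=weights)
print(weights_out.shape)  # (4, 1): the weight vector gained a trailing unit dim
```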
Example #26
def _upsample(x, up_sz, f, direction, shift):
    """Upsample by a factor of 2 using transposed reflecting boundary conditions.

  This function undecimates `x` along the axis specified by `direction` and then
  convolves it with filter `f`, thereby upsampling it to have a size of `up_sz`.
  This function is a bit awkward, as it's written to be the transpose of
  _downsample(), which uses reflecting boundary conditions. As such, this
  function approximates *the transpose of reflecting boundary conditions*, which
  is not the same as reflecting boundary conditions.
  TODO(barron): Write out the true transpose of reflecting boundary conditions.

  Args:
    x: The input tensor (numpy or TF), of size (num_channels, width, height).
    up_sz: A tuple of ints of size (upsampled_width, upsampled_height). Care
      should be taken by the caller to match the upsampled_width/height with the
      input width/height along the axis that isn't being upsampled.
    f: The input filter, which must be an odd-length 1D numpy array.
    direction: The spatial direction in [0, 1] along which `x` will be convolved
      with `f` after being undecimated. Because `x` has a batch/channels
      dimension, `direction` == 0 corresponds to downsampling along axis 1 in
      `x`, and `direction` == 1 corresponds to downsampling along axis 2 in `x`.
    shift: A shift amount in [0, 1] by which `x` will be shifted along the axis
      specified by `direction` after undecimating.

  Returns:
    `x` undecimated and convolved with `f` along the spatial dimension
    `direction` with transposed reflection boundary conditions with an offset of
    `shift`, to match size `up_sz`.
  """
    _check_resample_inputs(x, f, direction, shift)
    assert_ops = [tf.Assert(tf.equal(tf.rank(f), 1), [tf.rank(f)])]
    with tf.control_dependencies(assert_ops):
        # Undecimate `x` by a factor of 2 along `direction`, by stacking it with
        # and tensor of all zeros along the right axis and then reshaping it such
        # that the zeros are interleaved.
        if direction == 0:
            sz_ex = tf.shape(x) * [1, 2, 1]
        elif direction == 1:
            sz_ex = tf.shape(x) * [1, 1, 2]
        if shift == 0:
            x_and_zeros = [x, tf.zeros_like(x)]
        elif shift == 1:
            x_and_zeros = [tf.zeros_like(x), x]
        x_undecimated = tf.reshape(tf.stack(x_and_zeros, direction + 2), sz_ex)
        # Ensure that `x_undecimated` has a size of `up_sz`, by slicing and padding
        # as needed.
        x_undecimated = x_undecimated[:, 0:up_sz[0], 0:up_sz[1]]
        x_undecimated = tf.pad(
            x_undecimated, [[0, 0], [0, up_sz[0] - tf.shape(x_undecimated)[1]],
                            [0, up_sz[1] - tf.shape(x_undecimated)[2]]])

        # Pad `x_undecimated` with reflection boundary conditions.
        x_padded = pad_reflecting(x_undecimated,
                                  len(f) // 2, (len(f) - 1) // 2,
                                  direction + 1)
        # Convolve x_undecimated with a flipped version of f.
        f_ex = tf.expand_dims(f[::-1], 1 - direction)
        y = tf.nn.conv2d(x_padded[:, :, :, tf.newaxis],
                         tf.cast(f_ex, x.dtype)[:, :, tf.newaxis, tf.newaxis],
                         [1, 1, 1, 1], 'VALID')[:, :, :, 0]
        return y
Example #27
def _rank(input, name=None):  # pylint: disable=redefined-builtin,unused-argument
  if not hasattr(input, 'shape'):
    input = (tf.convert_to_tensor(input) if tf.get_static_value(input) is None
             else np.array(input))
  ndims_ = tensorshape_util.rank(getattr(input, 'shape', None))
  return tf.rank(input) if ndims_ is None else np.int32(ndims_)
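A minimal usage sketch (assumes the TFP-internal `tensorshape_util` import used above): with a statically known shape the helper returns a NumPy integer instead of emitting a `tf.rank` op:

```python
import tensorflow as tf

r = _rank(tf.zeros([2, 3]))
print(r, type(r))  # 2 <class 'numpy.int32'>
```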
Example #28
def _replicate(n, tensor):
    """Replicate the input tensor n times along a new (major) dimension."""
    # TODO(axch) Does this already exist somewhere?  Should it get contributed?
    multiples = tf.concat([[n], tf.ones([tf.rank(tensor)], dtype=n.dtype)],
                          axis=0)
    return tf.tile(tensor[tf.newaxis], multiples)
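A minimal usage sketch (plain TensorFlow; `n` is passed as a tensor because the helper reads `n.dtype`):

```python
import tensorflow as tf

t = tf.reshape(tf.range(6), [2, 3])
print(_replicate(tf.constant(4), t).shape)  # (4, 2, 3)
```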
Example #29
    def _observation_log_probs(self, observations, mask):
        """Compute and shape tensor of log probs associated with observations.."""

        # Let E be the underlying event shape
        #     M the number of steps in the HMM
        #     N the number of states of the HMM
        #
        # Then the incoming observations have shape
        #
        # observations : batch_o [M] E
        #
        # and the mask (if present) has shape
        #
        # mask : batch_m [M]
        #
        # Let this HMM distribution have batch shape batch_d
        # We need to broadcast all three of these batch shapes together
        # into the shape batch.
        #
        # We need to move the step dimension to the first dimension to make
        # them suitable for folding or scanning over.
        #
        # When we call `log_prob` for our observations we need to
        # do this for each state the observation could correspond to.
        # We do this by expanding the dimensions by 1 so we end up with:
        #
        # observations : [M] batch [1] [E]
        #
        # After calling `log_prob` we get
        #
        # observation_log_probs : [M] batch [N]
        #
        # We wish to use `mask` to select from this so we also
        # reshape and broadcast it up to shape
        #
        # mask : [M] batch [N]

        observation_tensor_shape = tf.shape(observations)
        observation_batch_shape = observation_tensor_shape[
            :-1 - self._underlying_event_rank]
        observation_event_shape = observation_tensor_shape[
            -1 - self._underlying_event_rank:]

        if mask is not None:
            mask_tensor_shape = tf.shape(mask)
            mask_batch_shape = mask_tensor_shape[:-1]

        batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())

        if mask is not None:
            batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                                     mask_batch_shape)
        observations = tf.broadcast_to(
            observations,
            tf.concat([batch_shape, observation_event_shape], axis=0))
        observation_rank = tf.rank(observations)
        underlying_event_rank = self._underlying_event_rank
        observations = distribution_util.move_dimension(
            observations, observation_rank - underlying_event_rank - 1, 0)
        observations = tf.expand_dims(observations,
                                      observation_rank - underlying_event_rank)
        observation_log_probs = self._observation_distribution.log_prob(
            observations)

        if mask is not None:
            mask = tf.broadcast_to(
                mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
            mask = distribution_util.move_dimension(mask, -1, 0)
            observation_log_probs = tf.where(
                mask[..., tf.newaxis], tf.zeros_like(observation_log_probs),
                observation_log_probs)

        return observation_log_probs
Example #30
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in Markov chain. A python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape.`
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """

        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) != 0):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape) != 0):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (tensorshape_util.dims(transition_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (tensorshape_util.dims(observation_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (tensorshape_util.dims(transition_distribution.batch_shape)
                     is not None and tensorshape_util.as_list(
                         transition_distribution.batch_shape)[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (tensorshape_util.dims(
                        observation_distribution.batch_shape) is not None
                     and tensorshape_util.as_list(
                         observation_distribution.batch_shape)[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters