コード例 #1
0
ファイル: linalg.py プロジェクト: HackerShohag/SuggestBot-bn
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions.

    Each requirement is checked statically when the relevant shape information
    is known at graph-construction time, raising `ValueError` immediately on a
    violation. When it is not known and `validate_args` is `True`, a runtime
    assertion op is appended to the returned list instead.

    Args:
      lower_upper: `Tensor` whose trailing two dims hold the (square) packed
        LU matrix; it must have rank >= 2.
      perm: `Tensor` of permutation indices whose rank must equal
        `rank(lower_upper) - 1`.
      validate_args: Python `bool`; whether to emit runtime assertions for
        checks that cannot be resolved statically.

    Returns:
      assertions: Python `list` of assertion ops (possibly empty).

    Raises:
      ValueError: if a statically known shape violates one of the checks.
    """
    assertions = []
    lower_upper_rank = tensorshape_util.rank(lower_upper.shape)

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if lower_upper_rank is not None:
        if lower_upper_rank < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(lower_upper,
                                             rank=2,
                                             message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    perm_rank = tensorshape_util.rank(perm.shape)
    if lower_upper_rank is not None and perm_rank is not None:
        if lower_upper_rank != perm_rank + 1:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(lower_upper,
                                    rank=tf.rank(perm) + 1,
                                    message=message))

    message = '`lower_upper` must be square.'
    # Gate on the trailing two (matrix) dims, not the batch dims: only these
    # are compared below. If either is statically unknown, comparing it
    # (None) with the other would raise spuriously or silently skip the check,
    # so defer to a runtime assertion instead.
    if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
コード例 #2
0
 def _inverse_log_det_jacobian(self, y):
   """Sums the per-block inverse log-det-Jacobians of `y`.

   `y` is split along its last axis into `self.block_sizes` chunks, one per
   entry of `self.bijectors`. Each bijector contributes the ILDJ of its own
   chunk (with `event_ndims=1`), and the contributions are added together.
   """
   y_blocks = tf.split(y, self.block_sizes, axis=-1, num=len(self.bijectors))
   per_block_terms = []
   for bijector, y_block in zip(self.bijectors, y_blocks):
     per_block_terms.append(
         bijector.inverse_log_det_jacobian(y_block, event_ndims=1))
   return sum(per_block_terms)
コード例 #3
0
 def _forward_log_det_jacobian(self, x):
   """Sums the per-block forward log-det-Jacobians of `x`.

   `x` is split along its last axis into `self.block_sizes` chunks, one per
   entry of `self.bijectors`. Each bijector contributes the FLDJ of its own
   chunk (with `event_ndims=1`), and the contributions are added together.
   """
   x_blocks = tf.split(x, self.block_sizes, axis=-1, num=len(self.bijectors))
   per_block_terms = []
   for bijector, x_block in zip(self.bijectors, x_blocks):
     per_block_terms.append(
         bijector.forward_log_det_jacobian(x_block, event_ndims=1))
   return sum(per_block_terms)
コード例 #4
0
    def _log_prob(self, x):
        """Evaluates the wrapped distribution's log-density at a flat input.

        The last axis of `x` is cut into one chunk per flattened component of
        the wrapped distribution's event shape. Each chunk has
        `max(1, prod(component_event_shape))` entries. Every chunk is cast to
        the matching component dtype and, unless that component's event rank
        is statically 1, reshaped so that its trailing axis becomes the
        component's event shape. The rebuilt structure is passed to
        `self._distribution.log_prob`.

        Args:
          x: `Tensor` whose last axis concatenates all event components.

        Returns:
          The result of `self._distribution.log_prob` on the rebuilt input.

        Raises:
          ValueError: if `x` is statically known to be rank 0.
        """
        rank_message = 'Input must have at least one dimension.'
        rank_checks = []
        if tensorshape_util.rank(x.shape) is None:
            # Rank unknown until runtime: defer the check to an assertion op.
            if self.validate_args:
                rank_checks.append(
                    assert_util.assert_rank_at_least(
                        x, 1, message=rank_message))
        elif tensorshape_util.rank(x.shape) == 0:
            raise ValueError(rank_message)

        with tf.control_dependencies(rank_checks):
            event_shape_parts = self._distribution.event_shape_tensor()
            chunk_sizes = [
                ps.maximum(1, ps.reduce_prod(part_shape))
                for part_shape in tf.nest.flatten(event_shape_parts)
            ]
            chunks = tf.split(x, chunk_sizes, axis=-1)
            structured = tf.nest.pack_sequence_as(event_shape_parts, chunks)

            def _unflatten(chunk, dtype, event_shape):
                """Casts `chunk` to `dtype`; restores `event_shape` if needed."""
                chunk = tf.cast(chunk, dtype)
                if tf.get_static_value(ps.rank_from_shape(event_shape)) == 1:
                    # Statically rank-1 event: the split chunk is used as-is.
                    return chunk
                target = ps.concat([ps.shape(chunk)[:-1], event_shape],
                                   axis=-1)
                return tf.reshape(chunk, ps.cast(target, tf.int32))

            static_event_shapes = tf.nest.flatten(
                self._distribution.event_shape)
            if all(tensorshape_util.is_fully_defined(s)
                   for s in static_event_shapes):
                target_shapes = self._distribution.event_shape
            else:
                # Some static event shape is incomplete; use runtime shapes.
                target_shapes = self._distribution.event_shape_tensor()
            structured = tf.nest.map_structure(
                _unflatten, structured, self._distribution.dtype,
                target_shapes)

            return self._distribution.log_prob(structured)
コード例 #5
0
 def _inverse(self, y):
   """Applies each bijector's inverse to its block of `y`; re-concatenates.

   `y` is split along its last axis into `self.block_sizes` chunks, one per
   entry of `self.bijectors`. The static shape of `y` is copied onto the
   result (the code assumes the per-block bijectors preserve block sizes).
   """
   y_blocks = tf.split(y, self.block_sizes, axis=-1, num=len(self.bijectors))
   x_blocks = []
   for bijector, y_block in zip(self.bijectors, y_blocks):
     x_blocks.append(bijector.inverse(y_block))
   result = tf.concat(x_blocks, axis=-1)
   tensorshape_util.set_shape(result, y.shape)
   return result
コード例 #6
0
 def _forward(self, x):
   """Applies each bijector's forward map to its block of `x`; re-concatenates.

   `x` is split along its last axis into `self.block_sizes` chunks, one per
   entry of `self.bijectors`. The static shape of `x` is copied onto the
   result (the code assumes the per-block bijectors preserve block sizes).
   """
   x_blocks = tf.split(x, self.block_sizes, axis=-1, num=len(self.bijectors))
   y_blocks = []
   for bijector, x_block in zip(self.bijectors, x_blocks):
     y_blocks.append(bijector.forward(x_block))
   result = tf.concat(y_blocks, axis=-1)
   tensorshape_util.set_shape(result, x.shape)
   return result
コード例 #7
0
def _replace_event_shape_in_shape_tensor(input_shape, event_shape_in,
                                         event_shape_out, validate_args):
    """Replaces the rightmost dims in a `Tensor` representing a shape.

    Args:
      input_shape: a rank-1 `Tensor` of integers
      event_shape_in: the event shape expected to be present in rightmost dims
        of `input_shape`.
      event_shape_out: the event shape with which to replace `event_shape_in`
        in the rightmost dims of `input_shape`.
      validate_args: Python `bool` indicating whether arguments should
        be checked for correctness.

    Returns:
      output_shape: A rank-1 integer `Tensor` with the same contents as
        `input_shape` except for the event dims, which are replaced with
        `event_shape_out`.
      output_tensorshape: the static (possibly only partially known) shape of
        the output, as computed by `_replace_event_shape_in_tensorshape`.
    """
    # NOTE(review): `is_validated` presumably reports whether the static
    # computation already proved compatibility; confirm in
    # `_replace_event_shape_in_tensorshape`.
    output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
        tensorshape_util.constant_value_as_shape(input_shape), event_shape_in,
        event_shape_out)

    # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
    # correctly supports control_dependencies.
    # Materialized as a tuple: a bare `map` object is a one-shot iterator, so
    # any second `tf.control_dependencies` use would silently see no deps.
    validation_dependencies = (
        tuple(map(tf.identity, (event_shape_in, event_shape_out)))
        if validate_args else ())

    if (tensorshape_util.is_fully_defined(output_tensorshape)
            and (is_validated or not validate_args)):
        # Fully static result: emit it directly as a constant shape tensor.
        with tf.control_dependencies(validation_dependencies):
            output_shape = tf.convert_to_tensor(
                tensorshape_util.as_list(output_tensorshape),
                name='output_shape',
                dtype_hint=tf.int32)
        return output_shape, output_tensorshape

    with tf.control_dependencies(validation_dependencies):
        # Number of event dims: static when known, otherwise a runtime size.
        static_num_event_dims = tensorshape_util.num_elements(
            event_shape_in.shape)
        event_shape_in_ndims = (
            tf.size(event_shape_in) if static_num_event_dims is None
            else static_num_event_dims)
        input_non_event_shape, input_event_shape = tf.split(
            input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

    additional_assertions = []
    if validate_args and not is_validated:
        # Check that `input_event_shape` and `event_shape_in` are compatible in
        # the sense that they have equal entries in any position that isn't a
        # `-1` in `event_shape_in`. Note that our validations at construction
        # time ensure there is at most one such entry in `event_shape_in`.
        mask = event_shape_in >= 0
        explicit_input_event_shape = tf.boolean_mask(input_event_shape,
                                                     mask=mask)
        explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
        additional_assertions.append(
            assert_util.assert_equal(
                explicit_input_event_shape,
                explicit_event_shape_in,
                message='Input `event_shape` does not match `event_shape_in`.')
        )
        # We don't explicitly additionally verify
        # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
        # already makes this assertion.

    with tf.control_dependencies(additional_assertions):
        output_shape = tf.concat([input_non_event_shape, event_shape_out],
                                 axis=0,
                                 name='output_shape')

    return output_shape, output_tensorshape