Example #1
 def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
     with self.test_session():
         tensor = tf.constant([1, 2], name="my_tensor")
         desired_rank = 2
         with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
             with tf.control_dependencies([tf.assert_rank_at_least(tensor, desired_rank)]):
                 tf.identity(tensor).eval()
Example #2
def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args."""
  with tf.name_scope(name="check_" + name, values=[param]):
    assertions = []
    if param.shape.ndims is not None:
      if param.shape.ndims == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, param.shape.ndims))
    elif validate_args:
      assertions.append(
          tf.assert_rank_at_least(
              param,
              1,
              message=("Mixing params must be a (batch of) vector; "
                       "{}.rank is not at least one.".format(name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    if param.shape.with_rank_at_least(1)[-1] is not None:
      if param.shape[-1].value != 1:
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "{}.shape[-1]={} is not 1.".format(
                                      name, param.shape[-1].value))
    elif validate_args:
      assertions.append(
          tf.assert_equal(
              tf.shape(param)[-1],
              1,
              message=("Currently only bimixtures are supported; "
                       "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return control_flow_ops.with_dependencies(assertions, param)
    return param
Example #3
 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
   with self.test_session():
     tensor = tf.constant([1, 2], name="my_tensor")
     desired_rank = 1
     with tf.control_dependencies([tf.assert_rank_at_least(tensor,
                                                           desired_rank)]):
       tf.identity(tensor).eval()
Example #4
 def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
     with self.test_session():
         tensor = tf.placeholder(tf.float32, name="my_tensor")
         desired_rank = 2
         with tf.control_dependencies([tf.assert_rank_at_least(tensor, desired_rank)]):
             with self.assertRaisesOpError("my_tensor.*rank"):
                 tf.identity(tensor).eval(feed_dict={tensor: [1, 2]})
Example #5
 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
   with self.test_session():
     tensor = tf.placeholder(tf.float32, name="my_tensor")
     desired_rank = 1
     with tf.control_dependencies([tf.assert_rank_at_least(tensor,
                                                           desired_rank)]):
       tf.identity(tensor).eval(feed_dict={tensor: [1, 2]})
Example #6
 def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
     with self.test_session():
         tensor = tf.constant(1, name="my_tensor")
         desired_rank = 1
         with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
             with tf.control_dependencies(
                 [tf.assert_rank_at_least(tensor, desired_rank)]):
                 tf.identity(tensor).eval()
Example #7
 def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(
         self):
     with self.test_session():
         tensor = tf.constant(1, name="my_tensor")
         desired_rank = 0
         with tf.control_dependencies(
             [tf.assert_rank_at_least(tensor, desired_rank)]):
             tf.identity(tensor).eval()
Example #8
 def test_rank_one_tensor_doesnt_raise_if_rank_too_large_static_rank(
         self):
     with self.test_session():
         tensor = tf.constant([1, 2], name="my_tensor")
         desired_rank = 0
         with tf.control_dependencies(
             [tf.assert_rank_at_least(tensor, desired_rank)]):
             tf.identity(tensor).eval()
Example #9
 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(
         self):
     with self.test_session():
         tensor = tf.placeholder(tf.float32, name="my_tensor")
         desired_rank = 1
         with tf.control_dependencies(
             [tf.assert_rank_at_least(tensor, desired_rank)]):
             tf.identity(tensor).eval(feed_dict={tensor: [1, 2]})
Example #10
 def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
     with self.test_session():
         tensor = tf.placeholder(tf.float32, name="my_tensor")
         desired_rank = 1
         with tf.control_dependencies(
             [tf.assert_rank_at_least(tensor, desired_rank)]):
             with self.assertRaisesOpError("my_tensor.*rank"):
                 tf.identity(tensor).eval(feed_dict={tensor: 0})
Example #11
 def _forward(self, x):
   if self.validate_args:
     is_matrix = tf.assert_rank_at_least(x, 2)
     shape = tf.shape(x)
     is_square = tf.assert_equal(shape[-2], shape[-1])
     x = control_flow_ops.with_dependencies([is_matrix, is_square], x)
   # For safety, explicitly zero-out the upper triangular part.
   x = tf.matrix_band_part(x, -1, 0)
   return tf.matmul(x, x, adjoint_b=True)
Example #13
    def __init__(self,
                 logits,
                 n_experiments,
                 dtype=None,
                 group_event_ndims=0):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype(
            [(self._logits, 'Multinomial.logits')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        static_logits_shape = self._logits.get_shape()
        shape_err_msg = "logits should have rank >= 1."
        if static_logits_shape and (static_logits_shape.ndims < 1):
            raise ValueError(shape_err_msg)
        elif static_logits_shape and (
                static_logits_shape[-1].value is not None):
            self._n_categories = static_logits_shape[-1].value
        else:
            _assert_shape_op = tf.assert_rank_at_least(
                self._logits, 1, message=shape_err_msg)
            with tf.control_dependencies([_assert_shape_op]):
                self._logits = tf.identity(self._logits)
            self._n_categories = tf.shape(self._logits)[-1]

        sign_err_msg = "n_experiments must be positive"
        if isinstance(n_experiments, int):
            if n_experiments <= 0:
                raise ValueError(sign_err_msg)
            self._n_experiments = n_experiments
        else:
            try:
                n_experiments = tf.convert_to_tensor(n_experiments, tf.int32)
            except ValueError:
                raise TypeError('n_experiments must be int32')
            _assert_rank_op = tf.assert_rank(
                n_experiments, 0,
                message="n_experiments should be a scalar (0-D Tensor).")
            _assert_positive_op = tf.assert_greater(
                n_experiments, 0, message=sign_err_msg)
            with tf.control_dependencies([_assert_rank_op,
                                          _assert_positive_op]):
                self._n_experiments = tf.identity(n_experiments)

        super(Multinomial, self).__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            is_continuous=False,
            is_reparameterized=False,
            group_event_ndims=group_event_ndims)
Example #14
    def __init__(self,
                 samples,
                 event_ndims=0,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='Empirical'):
        """Initialize `Empirical` distributions.

    Args:
      samples: Numeric `Tensor` of shape `[B1, ..., Bk, S, E1, ..., En]`,
        `k, n >= 0`. Samples or batches of samples on which the distribution
        is based. The first `k` dimensions index into a batch of independent
        distributions. The length of the `S` dimension determines the number
        of samples in each multiset. The last `n` dimensions represent the
        samples for each distribution; `n` is specified by the `event_ndims`
        argument.
      event_ndims: Python `int32`, default `0`. Number of dimensions for each
        event. When `0` this distribution has scalar samples. When `1` this
        distribution has vector-like samples.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value `NaN` to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if the rank of `samples` < event_ndims + 1.
    """

        parameters = locals()
        with tf.name_scope(name, values=[samples]):
            self._samples = tf.convert_to_tensor(samples, name='samples')
            self._event_ndims = event_ndims
            self._samples_axis = (
                (self.samples.shape.ndims or tf.rank(self.samples)) -
                self._event_ndims - 1)
            with tf.control_dependencies(
                [tf.assert_rank_at_least(self._samples, event_ndims + 1)]):
                samples_shape = util.prefer_static_shape(self._samples)
                self._num_samples = samples_shape[self._samples_axis]

        super(Empirical, self).__init__(
            dtype=self._samples.dtype,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._samples],
            name=name)
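
A minimal shape sketch for the docstring above (hypothetical values, assuming the `Empirical` class defined in this example):

samples = tf.zeros([2, 5, 3])  # [B1, S, E1]
dist = Empirical(samples, event_ndims=1)
# Batch shape [2], S = 5 samples per distribution, length-3 vector events;
# the samples axis is rank - event_ndims - 1 = 3 - 1 - 1 = 1.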
Example #15
def kl_divergence_mv_gaussian_v2(sigma1, sigma2, mu1=None, mu2=None, mean_batch=True, name='kl_divergence_mv_gaussian'):
    # KL divergence between two multivariate Gaussian distributions:
    #   KL(N(mu1, sigma1) || N(mu2, sigma2))
    # sigma1 and sigma2 are Covariance objects.
    # mu1 and mu2 are tensors of [batch size, num features]; if None, they
    # are assumed to be zero.
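    # For reference, this is the standard closed form evaluated below
    # (k is the number of features):
    #   KL = 0.5 * ( tr(Sigma2^{-1} Sigma1)
    #                + (mu2 - mu1)^T Sigma2^{-1} (mu2 - mu1)
    #                - k + ln det(Sigma2) - ln det(Sigma1) )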
    with tf.name_scope(name):
        from mvg_distributions.covariance_representations import Covariance

        assert isinstance(sigma1, Covariance)
        assert isinstance(sigma2, Covariance)
        if mu1 is None:
            assert mu2 is None
        if mu2 is None:
            assert mu1 is None

        # This is equivalent to
        # tr_sig1_2 = tf.trace(tf.matmul(sigma2.precision, sigma1.covariance))
        # but it avoids doing the matmul for the off-diagonal elements
        tr_sig1_2 = tf.einsum('bij,bji->b', sigma2.precision, sigma1.covariance)

        k = tf.cast(tf.shape(sigma1.covariance)[1], sigma1.covariance.dtype)
        log_det = sigma2.log_det_covariance() - sigma1.log_det_covariance()

        if mu1 is not None:
            # Attach the rank checks as control dependencies; a bare
            # assert op is never executed in TF1 graph mode.
            rank_checks = [tf.assert_rank_at_least(mu1, 2),  # [batch, features]
                           tf.assert_rank_at_least(mu2, 2)]
            with tf.control_dependencies(rank_checks):
                sq_error = sigma2.x_precision_x(mu2 - mu1)

            kl_div = 0.5 * (tr_sig1_2 + sq_error - k + log_det)
        else:
            kl_div = 0.5 * (tr_sig1_2 - k + log_det)

        if mean_batch:
            kl_div = tf.reduce_mean(kl_div, axis=0)

        return kl_div
Example #16
 def _prob(self, x):
   if self.validate_args:
     is_vector_check = tf.assert_rank_at_least(x, 1)
     right_vec_space_check = tf.assert_equal(
         self.event_shape_tensor(),
         tf.gather(tf.shape(x),
                   tf.rank(x) - 1),
         message=
         "Argument 'x' not defined in the same space R^k as this distribution")
     with tf.control_dependencies([is_vector_check]):
       with tf.control_dependencies([right_vec_space_check]):
         x = tf.identity(x)
   return tf.cast(
       tf.reduce_all(tf.abs(x - self.loc) <= self._slack, axis=-1),
       dtype=self.dtype)
Example #17
 def _maybe_assert_valid_concentration(self, concentration, validate_args):
   """Checks the validity of the concentration parameter."""
   if not validate_args:
     return concentration
   return control_flow_ops.with_dependencies([
       tf.assert_positive(
           concentration,
           message="Concentration parameter must be positive."),
       tf.assert_rank_at_least(
           concentration, 1,
           message="Concentration parameter must have >=1 dimensions."),
       tf.assert_less(
           1, tf.shape(concentration)[-1],
           message="Concentration parameter must have event_size >= 2."),
   ], concentration)
Example #20
 def _assertions(self, x):
     if not self.validate_args:
         return []
     x_shape = tf.shape(x)
     is_matrix = tf.assert_rank_at_least(
         x, 2, message="Input must have rank at least 2.")
     is_square = tf.assert_equal(x_shape[-2],
                                 x_shape[-1],
                                 message="Input must be a square matrix.")
     diag_part_x = tf.matrix_diag_part(x)
     is_lower_triangular = tf.assert_equal(
         tf.matrix_band_part(x, 0, -1),  # Preserves triu, zeros rest.
         tf.matrix_diag(diag_part_x),
         message="Input must be lower triangular.")
     is_positive_diag = tf.assert_positive(
         diag_part_x,
         message="Input must have all positive diagonal entries.")
     return [is_matrix, is_square, is_lower_triangular, is_positive_diag]
Example #22
def assert_rank_at_least(tensor, k, name):
    """
    Whether the rank of `tensor` is at least k.

    :param tensor: A tensor to be checked.
    :param k: The least rank allowed.
    :param name: The name of `tensor` for error message.
    :return: The checked tensor.
    """
    static_shape = tensor.get_shape()
    shape_err_msg = '{} should have rank >= {}.'.format(name, k)
    if static_shape and (static_shape.ndims < k):
        raise ValueError(shape_err_msg)
    if not static_shape:
        _assert_shape_op = tf.assert_rank_at_least(
            tensor, k, message=shape_err_msg)
        with tf.control_dependencies([_assert_shape_op]):
            tensor = tf.identity(tensor)
    return tensor
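
A minimal usage sketch for the helper above (the placeholder is hypothetical; TF1 graph mode assumed):

logits = tf.placeholder(tf.float32, shape=None, name='logits')
# Raises ValueError at graph-construction time when the static rank is known
# and too small; otherwise attaches a runtime assertion to the tensor.
logits = assert_rank_at_least(logits, 1, name='logits')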
Example #23
 def __init__(self, tensor, fn, inputs):
   # Transforms are required to produce an output with a batch dimension. The
   # assertions below attempt to verify this. In the case of dense tensors the
   # check occurs statically if possible but falls back on a runtime check. In
   # the case of sparse tensors, the check happens at runtime.
   min_tensor_rank = 1
   if isinstance(tensor, tf.SparseTensor):
     with tf.control_dependencies(
         [tf.assert_greater_equal(tf.size(tensor.dense_shape),
                                  min_tensor_rank)]):
       tensor = tf.SparseTensor(indices=tf.identity(tensor.indices),
                                values=tensor.values,
                                dense_shape=tensor.dense_shape)
   else:
     with tf.control_dependencies(
         [tf.assert_rank_at_least(tensor, min_tensor_rank)]):
       tensor = tf.identity(tensor)
   super(_TransformedColumn, self).__init__(tensor)
   self._fn = fn
   self._inputs = inputs
Example #24
def assert_rank_at_least_one(tensor, name):
    """
    Whether the rank of `tensor` is at least one.

    :param tensor: A tensor to be checked.
    :param name: The name of `tensor` for error message.
    :return: (checked tensor, the last dimension of `tensor`).
    """
    static_shape = tensor.get_shape()
    shape_err_msg = name + " should have rank >= 1."
    if static_shape and (static_shape.ndims < 1):
        raise ValueError(shape_err_msg)
    elif static_shape and (static_shape[-1].value is not None):
        return tensor, static_shape[-1].value
    else:
        _assert_shape_op = tf.assert_rank_at_least(tensor,
                                                   1,
                                                   message=shape_err_msg)
        with tf.control_dependencies([_assert_shape_op]):
            tensor = tf.identity(tensor)
        return tensor, tf.shape(tensor)[-1]
Example #25
    def __init__(self,
                 alpha,
                 group_ndims=0,
                 check_numerics=False,
                 **kwargs):
        self._alpha = tf.convert_to_tensor(alpha)
        dtype = assert_same_float_dtype(
            [(self._alpha, 'Dirichlet.alpha')])

        static_alpha_shape = self._alpha.get_shape()
        shape_err_msg = "alpha should have rank >= 1."
        cat_err_msg = "n_categories (length of the last axis " \
                      "of alpha) should be at least 2."
        if static_alpha_shape and (static_alpha_shape.ndims < 1):
            raise ValueError(shape_err_msg)
        elif static_alpha_shape and (
                static_alpha_shape[-1].value is not None):
            self._n_categories = static_alpha_shape[-1].value
            if self._n_categories < 2:
                raise ValueError(cat_err_msg)
        else:
            _assert_shape_op = tf.assert_rank_at_least(
                self._alpha, 1, message=shape_err_msg)
            with tf.control_dependencies([_assert_shape_op]):
                self._alpha = tf.identity(self._alpha)
            self._n_categories = tf.shape(self._alpha)[-1]

            _assert_cat_op = tf.assert_greater_equal(
                self._n_categories, 2, message=cat_err_msg)
            with tf.control_dependencies([_assert_cat_op]):
                self._alpha = tf.identity(self._alpha)
        self._check_numerics = check_numerics

        super(Dirichlet, self).__init__(
            dtype=dtype,
            param_dtype=dtype,
            is_continuous=True,
            is_reparameterized=False,
            group_ndims=group_ndims,
            **kwargs)
Example #26
 def _assertions(self, x):
   if not self.validate_args:
     return []
   shape = tf.shape(x)
   is_matrix = tf.assert_rank_at_least(
       x, 2, message="Input must have rank at least 2.")
   is_square = tf.assert_equal(
       shape[-2], shape[-1], message="Input must be a square matrix.")
   above_diagonal = tf.matrix_band_part(
       tf.matrix_set_diag(x, tf.zeros(shape[:-1], dtype=tf.float32)), 0, -1)
   is_lower_triangular = tf.assert_equal(
       above_diagonal,
       tf.zeros_like(above_diagonal),
       message="Input must be lower triangular.")
   # A lower triangular matrix is nonsingular iff all its diagonal entries are
   # nonzero.
   diag_part = tf.matrix_diag_part(x)
   is_nonsingular = tf.assert_none_equal(
       diag_part,
       tf.zeros_like(diag_part),
       message="Input must have all diagonal entries nonzero.")
   return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
Example #27
def assert_rank_at_least(x, ndims, message=None, name=None):
    """
    Assert the rank of `x` is at least `ndims`.

    Args:
        x: A tensor.
        ndims (int or tf.Tensor): An integer, or a 0-d integer tensor.
        message: Message to display when assertion failed.

    Returns:
        tf.Operation or None: The TensorFlow assertion operation,
            or None if can be statically asserted.
    """
    x = tf.convert_to_tensor(x)
    if not is_tensor_object(ndims) and get_static_shape(x) is not None:
        ndims = int(ndims)
        x_ndims = len(get_static_shape(x))
        if x_ndims < ndims:
            raise _make_assertion_error('rank(x) >= ndims',
                                        '{!r} < {!r}'.format(x_ndims,
                                                             ndims), message)
    else:
        return tf.assert_rank_at_least(x, ndims, message=message, name=name)
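
A minimal usage sketch for the helper above (`is_tensor_object` and `get_static_shape` come from the snippet's own library; the placeholder is hypothetical):

x = tf.placeholder(tf.float32, shape=None)  # rank unknown statically
assert_op = assert_rank_at_least(x, 2, message='x must be a matrix')
if assert_op is not None:  # None means the rank was verified statically
    with tf.control_dependencies([assert_op]):
        x = tf.identity(x)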
Example #28
    def _detect_latent_axis(self, z, z_samples):
        if z_samples is None:
            latent_axis = None
        else:
            z_ndims = z.get_shape().ndims
            z_group_event_ndims = z.group_event_ndims
            if z_group_event_ndims is None:
                z_group_event_ndims = 0

            if z_ndims is None or is_dynamic_tensor_like(z_group_event_ndims):
                z_ndims_assertion = tf.assert_rank_at_least(
                    z,
                    z_group_event_ndims + 1,
                    message='`z_samples` is specified, but the log '
                    'lower-bounds for z is 0-dimensional')
                with tf.control_dependencies([z_ndims_assertion]):
                    z_ndims = tf.rank(z)
            elif z_ndims <= z_group_event_ndims:
                raise ValueError('`z_samples` is specified, but the log '
                                 'lower-bounds for z is 0-dimensional')

            latent_axis = z_group_event_ndims - z_ndims
        return latent_axis
Example #30
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
    """Returns list of assertions related to `lu_solve` assumptions."""
    assertions = _lu_reconstruct_assertions(lower_upper, perm, validate_args)

    message = 'Input `rhs` must have at least 2 dimensions.'
    if rhs.shape.ndims is not None:
        if rhs.shape.ndims < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(tf.assert_rank_at_least(rhs, rank=2,
                                                  message=message))

    message = '`lower_upper.shape[-1]` must equal `rhs.shape[-1]`.'
    if (tf.dimension_value(lower_upper.shape[-1]) is not None
            and tf.dimension_value(rhs.shape[-2]) is not None):
        if lower_upper.shape[-1] != rhs.shape[-2]:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            tf.assert_equal(tf.shape(lower_upper)[-1],
                            tf.shape(rhs)[-2],
                            message=message))

    return assertions
Example #31
def pinv(a, rcond=None, validate_args=False, name=None):
  """Compute the Moore-Penrose pseudo-inverse of a matrix.

  Calculate the [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its
  singular-value decomposition (SVD) and including all large singular values.

  The pseudo-inverse of a matrix `A` is defined as: "the matrix that 'solves'
  [the least-squares problem] `A @ x = b`," i.e., if `x_hat` is a solution, then
  `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if
  `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then
  `A_pinv = V @ inv(Sigma) @ U.T`. [(Strang, 1980)][1]

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
  default `rcond` is `1e-15`. Here the default is
  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs.  Singular values smaller
      (in modulus) than `rcond` * largest_singular_value (again, in modulus) are
      set to zero. Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: "pinv".

  Returns:
    a_pinv: The pseudo-inverse of input `a`. Has same shape as `a` except
      rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
               [ 0.37,  0.43, -0.33,  0.02],
               [ 0.21, -0.33,  0.81,  0.01],
               [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. "Linear Algebra and Its Applications, 2nd Ed." Academic Press,
       Inc., 1980, pp. 139-142.
  """
  with tf.name_scope(name, 'pinv', [a, rcond]):
    a = tf.convert_to_tensor(a, name='a')

    if not a.dtype.is_floating:
      raise TypeError('Input `a` must have `float`-like `dtype` '
                      '(saw {}).'.format(a.dtype.name))
    if a.shape.ndims is not None:
      if a.shape.ndims < 2:
        raise ValueError('Input `a` must have at least 2 dimensions '
                         '(saw: {}).'.format(a.shape.ndims))
    elif validate_args:
      assert_rank_at_least_2 = tf.assert_rank_at_least(
          a, rank=2,
          message='Input `a` must have at least 2 dimensions.')
      with tf.control_dependencies([assert_rank_at_least_2]):
        a = tf.identity(a)

    dtype = a.dtype.as_numpy_dtype

    if rcond is None:
      def get_dim_size(dim):
        if a.shape.ndims is not None and a.shape[dim].value is not None:
          return a.shape[dim].value
        return tf.shape(a)[dim]
      num_rows = get_dim_size(-2)
      num_cols = get_dim_size(-1)
      if isinstance(num_rows, int) and isinstance(num_cols, int):
        max_rows_cols = float(max(num_rows, num_cols))
      else:
        max_rows_cols = tf.cast(tf.maximum(num_rows, num_cols), dtype)
      rcond = 10. * max_rows_cols * np.finfo(dtype).eps

    rcond = tf.convert_to_tensor(rcond, dtype=dtype, name='rcond')

    # Calculate pseudo inverse via SVD.
    # Note: if a is symmetric then u == v. (We might observe additional
    # performance by explicitly setting `v = u` in such cases.)
    [
        singular_values,         # Sigma
        left_singular_vectors,   # U
        right_singular_vectors,  # V
    ] = tf.linalg.svd(a, full_matrices=False, compute_uv=True)

    # Saturate small singular values to inf. This has the effect of making
    # `1. / s = 0.` while not resulting in `NaN` gradients.
    cutoff = rcond * tf.reduce_max(singular_values, axis=-1)
    singular_values = tf.where(
        singular_values > cutoff[..., tf.newaxis],
        singular_values,
        tf.fill(tf.shape(singular_values), np.array(np.inf, dtype)))

    # Although `a == tf.matmul(u, s * v, transpose_b=True)` we swap
    # `u` and `v` here so that `tf.matmul(pinv(A), A) = tf.eye()`, i.e.,
    # a matrix inverse has "transposed" semantics.
    a_pinv = tf.matmul(
        right_singular_vectors / singular_values[..., tf.newaxis, :],
        left_singular_vectors,
        adjoint_b=True)

    if a.shape.ndims is not None:
      a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]]))

    return a_pinv
Example #32
    def _forward_log_det_jacobian(self, x):
        # Let Y be a symmetric, positive definite matrix and write:
        #   Y = X X.T
        # where X is lower-triangular.
        #
        # Observe that,
        #   dY[i,j]/dX[a,b]
        #   = d/dX[a,b] { X[i,:] X[j,:] }
        #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
        #
        # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
        # symmetric and X is lower-triangular, we need vectors of dimension:
        #   d = p (p + 1) / 2
        # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
        #   k = { i (i + 1) / 2 + j   i>=j
        #       { undef               i<j
        # and assume zero-based indexes. When k is undef, the element is dropped.
        # Example:
        #           j      k
        #        0 1 2 3  /
        #    0 [ 0 . . . ]
        # i  1 [ 1 2 . . ]
        #    2 [ 3 4 5 . ]
        #    3 [ 6 7 8 9 ]
        # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
        # slight abuse: k(i,j)=undef means the element is dropped.)
        #
        # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
        # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
        # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
        # (1) j<=i<a thus i,j!=a.
        # (2) i=a>j  thus i,j!=a.
        #
        # Since the Jacobian is lower-triangular, we need only compute the product
        # of diagonal elements:
        #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
        #   = X[j,j] + I[i=j] X[i,j]
        #   = 2 X[j,j].
        # Since there is a 2 X[j,j] term for every lower-triangular element of X we
        # conclude:
        #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
        diag = tf.matrix_diag_part(x)

        # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
        # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the
        # output is unchanged.
        diag = self._make_columnar(diag)

        if self.validate_args:
            is_matrix = tf.assert_rank_at_least(
                x, 2, message="Input must be a (batch of) matrix.")
            shape = tf.shape(x)
            is_square = tf.assert_equal(
                shape[-2],
                shape[-1],
                message="Input must be a (batch of) square matrix.")
            # Assuming lower-triangular means we only need check diag>0.
            is_positive_definite = tf.assert_positive(
                diag, message="Input must be positive definite.")
            x = control_flow_ops.with_dependencies(
                [is_matrix, is_square, is_positive_definite], x)

        # Create a vector equal to: [p, p-1, ..., 2, 1].
        if x.get_shape().ndims is None or x.get_shape()[-1].value is None:
            p_int = tf.shape(x)[-1]
            p_float = tf.cast(p_int, dtype=x.dtype)
        else:
            p_int = x.get_shape()[-1].value
            p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype)
        exponents = tf.linspace(p_float, 1., p_int)

        sum_weighted_log_diag = tf.squeeze(tf.matmul(
            tf.log(diag), exponents[..., tf.newaxis]),
                                           axis=-1)
        fldj = p_float * np.log(2.) + sum_weighted_log_diag

        return fldj
Example #33
    def __init__(self,
                 loc,
                 atol=None,
                 rtol=None,
                 is_vector=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="_BaseDeterministic"):
        """Initialize a batch of `_BaseDeterministic` distributions.

    The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
      = 0, otherwise.
    ```

    Args:
      loc: Numeric `Tensor`.  The point (or batch of points) on which this
        distribution is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      is_vector:  Python `bool`.  If `True`, this is for `VectorDeterministic`,
        else `Deterministic`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError:  If `loc` is a scalar.
    """
        parameters = dict(locals())
        with tf.name_scope(name, values=[loc, atol, rtol]) as name:
            dtype = dtype_util.common_dtype([loc, atol, rtol])
            loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
            if is_vector and validate_args:
                msg = "Argument loc must be at least rank 1."
                if loc.get_shape().ndims is not None:
                    if loc.get_shape().ndims < 1:
                        raise ValueError(msg)
                else:
                    loc = control_flow_ops.with_dependencies(
                        [tf.assert_rank_at_least(loc, 1, message=msg)], loc)
            self._loc = loc

            super(_BaseDeterministic, self).__init__(
                dtype=self._loc.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=[self._loc],
                name=name)

            self._atol = self._get_tol(atol)
            self._rtol = self._get_tol(rtol)
            # Avoid using the large broadcast with self.loc if possible.
            if rtol is None:
                self._slack = self.atol
            else:
                self._slack = self.atol + self.rtol * tf.abs(self.loc)
Example #34
  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               is_vector=False,
               validate_args=False,
               allow_nan_stats=True,
               name="_BaseDeterministic"):
    """Initialize a batch of `_BaseDeterministic` distributions.

    The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
      = 0, otherwise.
    ```

    Args:
      loc: Numeric `Tensor`.  The point (or batch of points) on which this
        distribution is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      is_vector:  Python `bool`.  If `True`, this is for `VectorDeterministic`,
        else `Deterministic`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError:  If `loc` is a scalar.
    """
    parameters = dict(locals())
    with tf.name_scope(name, values=[loc, atol, rtol]) as name:
      loc = tf.convert_to_tensor(loc, name="loc")
      if is_vector and validate_args:
        msg = "Argument loc must be at least rank 1."
        if loc.get_shape().ndims is not None:
          if loc.get_shape().ndims < 1:
            raise ValueError(msg)
        else:
          loc = control_flow_ops.with_dependencies(
              [tf.assert_rank_at_least(loc, 1, message=msg)], loc)
      self._loc = loc

      super(_BaseDeterministic, self).__init__(
          dtype=self._loc.dtype,
          reparameterization_type=tf.distributions.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[self._loc],
          name=name)

      self._atol = self._get_tol(atol)
      self._rtol = self._get_tol(rtol)
      # Avoid using the large broadcast with self.loc if possible.
      if rtol is None:
        self._slack = self.atol
      else:
        self._slack = self.atol + self.rtol * tf.abs(self.loc)
Example #35
  def _forward_log_det_jacobian(self, x):
    # Let Y be a symmetric, positive definite matrix and write:
    #   Y = X X.T
    # where X is lower-triangular.
    #
    # Observe that,
    #   dY[i,j]/dX[a,b]
    #   = d/dX[a,b] { X[i,:] X[j,:] }
    #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
    #
    # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
    # symmetric and X is lower-triangular, we need vectors of dimension:
    #   d = p (p + 1) / 2
    # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
    #   k = { i (i + 1) / 2 + j   i>=j
    #       { undef               i<j
    # and assume zero-based indexes. When k is undef, the element is dropped.
    # Example:
    #           j      k
    #        0 1 2 3  /
    #    0 [ 0 . . . ]
    # i  1 [ 1 2 . . ]
    #    2 [ 3 4 5 . ]
    #    3 [ 6 7 8 9 ]
    # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
    # slight abuse: k(i,j)=undef means the element is dropped.)
    #
    # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
    # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
    # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
    # (1) j<=i<a thus i,j!=a.
    # (2) i=a>j  thus i,j!=a.
    #
    # Since the Jacobian is lower-triangular, we need only compute the product
    # of diagonal elements:
    #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
    #   = X[j,j] + I[i=j] X[i,j]
    #   = 2 X[j,j].
    # Since there is a 2 X[j,j] term for every lower-triangular element of X we
    # conclude:
    #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
    diag = tf.matrix_diag_part(x)

    # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
    # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the
    # output is unchanged.
    diag = self._make_columnar(diag)

    if self.validate_args:
      is_matrix = tf.assert_rank_at_least(
          x, 2, message="Input must be a (batch of) matrix.")
      shape = tf.shape(x)
      is_square = tf.assert_equal(
          shape[-2],
          shape[-1],
          message="Input must be a (batch of) square matrix.")
      # Assuming lower-triangular means we only need check diag>0.
      is_positive_definite = tf.assert_positive(
          diag, message="Input must be positive definite.")
      x = control_flow_ops.with_dependencies(
          [is_matrix, is_square, is_positive_definite], x)

    # Create a vector equal to: [p, p-1, ..., 2, 1].
    if x.shape.ndims is None or x.shape[-1].value is None:
      p_int = tf.shape(x)[-1]
      p_float = tf.cast(p_int, dtype=x.dtype)
    else:
      p_int = x.shape[-1].value
      p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype)
    exponents = tf.linspace(p_float, 1., p_int)

    sum_weighted_log_diag = tf.squeeze(
        tf.matmul(tf.log(diag), exponents[..., tf.newaxis]), axis=-1)
    fldj = p_float * np.log(2.) + sum_weighted_log_diag

    # We finally need to undo adding an extra column in non-scalar cases
    # where there is a single matrix as input.
    if x.shape.ndims is not None:
      if x.shape.ndims == 2:
        fldj = tf.squeeze(fldj, axis=-1)
      return fldj

    shape = tf.shape(fldj)
    maybe_squeeze_shape = tf.concat([
        shape[:-1],
        distribution_util.pick_vector(
            tf.equal(tf.rank(x), 2),
            np.array([], dtype=np.int32), shape[-1:])], 0)
    return tf.reshape(fldj, maybe_squeeze_shape)
Example #36
import contextlib

@contextlib.contextmanager
def assert_rank_at_least(x, rank, message=None):
    # Without the decorator this generator could not be used in a `with` block.
    with tf.control_dependencies(
            [tf.assert_rank_at_least(x, rank, message=message)]):
        yield
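
A minimal usage sketch for the context manager above (TF1 graph mode; the placeholder is hypothetical):

x = tf.placeholder(tf.float32, name='x')
with assert_rank_at_least(x, 2, message='x must be at least a matrix'):
    y = tf.identity(x)  # y's evaluation now waits on the rank check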
Example #37
    def validate_samples_shape(self, x, name=None):
        """Validate the shape of samples against this distribution.

        The shape of samples should be considered as matching the
        distribution only if its rank is equal to or greater than the rank
        of ``batch_shape + value_shape``, and it is broadcastable against
        ``batch_shape + value_shape``.

        Parameters
        ----------
        x : tf.Tensor
            The samples tensor to be validated.

        name : str
            Optional name of this operation.

        Returns
        -------
        tf.Tensor
            The original tensor `x` if the validation could be done
            with the static shape, or a new tensor coupled with dynamic
            assertions if the static shape cannot do full validation.
        """
        # check the static shape
        static_shape = self.static_batch_shape.concatenate(
            self.static_value_shape)
        x_static_shape = x.get_shape()

        static_compatible = True
        if static_shape.ndims is not None and x_static_shape.ndims is not None:
            need_dynamic_check = False
            x_dims = x_static_shape.as_list()
            dims = static_shape.as_list()
            if len(x_dims) < len(dims):
                static_compatible = False
            else:
                for x_dim, dim in zip(reversed(x_dims), reversed(dims)):
                    if x_dim is None or dim is None:
                        need_dynamic_check = True
                    elif not (x_dim == dim or x_dim == 1 or dim == 1):
                        static_compatible = False
                        break
        else:
            need_dynamic_check = True

        if not static_compatible:
            raise ValueError('The shape of `x` (%r) is not compatible '
                             'with the shape of distribution samples '
                             '(%r).' % (x_static_shape, static_shape))

        # check the dynamic shape
        if need_dynamic_check:
            with tf.name_scope(name, default_name='validate_samples_shape'):
                dynamic_shape = tf.concat(
                    [self.dynamic_batch_shape, self.dynamic_value_shape],
                    axis=0)
                rank_assertion = tf.assert_rank_at_least(
                    x,
                    tf.size(dynamic_shape),
                    message='Too few dimensions for samples to match '
                    'distribution %r' % self.variable_scope.name)
                b_shape = tf.broadcast_dynamic_shape(tf.shape(x),
                                                     dynamic_shape)
                with tf.control_dependencies([rank_assertion, b_shape]):
                    x = tf.identity(x)
        return x
Example #38
 def entropyloss(self, prob):
     log_prob = tf.log(prob)
     log_log_clipped = tf.log(tf.log(tf.clip_by_value(prob, 1e-10, 1.0)))
     # Attach the rank checks as control dependencies; bare assert ops are
     # never executed in TF1 graph mode.
     checks = [tf.assert_rank_at_least(log_log_clipped, 1, message="clipping is computed wrongly, wrong rank"),
               tf.assert_rank_at_least(log_prob, 1, message="log(prob) is computed wrongly, wrong rank")]
     with tf.control_dependencies(checks):
         entropy = -tf.reduce_mean(tf.exp(tf.add(log_prob, log_log_clipped)), axis=1)
     return entropy
Example #39
def main(unused_argv):
    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    gamma = 0.95
    #RNN controller
    args = Parser().get_parser().parse_args()
    #Defining rnn
    val_accuracy = tf.placeholder(tf.float32)
    config = Config(args)
    net = Network(config)

    #Generate hyperparams
    A_t = tf.zeros((1, 1))
    # PPO implementation
    for i in range(FLAGS.train_epochs):
        outputs, prob, value = net.neural_search()
        hyperparams = net.gen_hyperparams(outputs)
        tf.assert_rank_at_least(tf.convert_to_tensor(prob),
                                1,
                                message="prob is the f*****g problem")
        c_1 = 1
        c_2 = 0.01
        if i > 0:
            #Policy ratio
            #We write it as tf.exp(tf.log(prob) - tf.log(old_prob)) instead of prob/old_prob
            #to improve numerical stability
            r = tf.exp(tf.log(prob) - tf.log(old_prob))
            #Enforcing the Bellman equation
            delta_t = eval_results["accuracy"] + gamma * value - old_value
            A_t = delta_t + gamma * A_t
            L_clip = net.Lclip(eval_results["accuracy"], A_t)
            L_vf = net.Lvf(delta_t)
            entropy_penalty = -tf.reduce_sum(
                tf.exp(
                    tf.add(tf.log(prob),
                           tf.log(tf.log(tf.clip_by_value(prob, 1e-10,
                                                          1.0))))))
            tf.assert_rank(L_clip,
                           0,
                           message="L_clip is computed wrongly, wrong rank")
            tf.assert_rank(L_vf,
                           0,
                           message="L_vf is computed wrongly, wrong rank")
            tf.assert_rank(
                entropy_penalty,
                0,
                message="entropy_penalty is computed wrongly, wrong rank")
            total_loss = L_clip - c_1 * L_vf + c_2 * entropy_penalty
            tf.summary.scalar('loss', total_loss)

        tf_config = tf.ConfigProto(allow_soft_placement=True,
                                   log_device_placement=True)
        tf_config.gpu_options.allow_growth = True
        sess = tf.Session(config=tf_config)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter('train', sess.graph)

        # Set up a RunConfig to only save checkpoints once per training cycle.
        #run_config = tf.estimator.RunConfig().replace(session_config=tf.ConfigProto(log_device_placement=True),save_checkpoints_secs=1e9)
        print(sess.run(hyperparams))
        print(sess.run(value))
        #tmp is a temporary file which stores the encoded activation function.
        # It is used by main.py to pass the activation function to the child
        # network, which reads from the file as the program runs. It also acts
        # as a cache file to store the final activation function found by the
        # algorithm.
        with open("tmp", "w") as f:
            f.write(' '.join(map(str, sess.run(hyperparams))))

        run_config = tf.estimator.RunConfig().replace(
            session_config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True))
        cifar_classifier = tf.estimator.Estimator(model_fn=cifar10_model_fn,
                                                  model_dir=FLAGS.model_dir,
                                                  config=run_config,
                                                  params={
                                                      'resnet_size':
                                                      FLAGS.resnet_size,
                                                      'data_format':
                                                      FLAGS.data_format,
                                                      'batch_size':
                                                      FLAGS.batch_size,
                                                  })

        for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
            tensors_to_log = {
                'learning_rate': 'learning_rate',
                'cross_entropy': 'cross_entropy',
                'train_accuracy': 'train_accuracy'
            }

            logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log,
                                                      every_n_iter=100)

            cifar_classifier.train(input_fn=lambda: input_fn(
                True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval))

            # Evaluate the model and print results
            eval_results = cifar_classifier.evaluate(input_fn=lambda: input_fn(
                False, FLAGS.data_dir, FLAGS.batch_size))
            print(eval_results)

            old_prob = tf.identity(prob)
            old_value = tf.identity(value)

            if i > 0:
                print("Training RNN")
                tr_cont_step = net.update(total_loss)
                sess.run(tf.global_variables_initializer())
                _ = sess.run(
                    tr_cont_step,
                    feed_dict={val_accuracy: eval_results["accuracy"]})
                print("RNN Trained")
    # Note: `!=` on TF1 tensors compares object identity, not values, so this
    # assert only verifies that A_t is a different tensor object.
    assert A_t != tf.zeros(
        (1, 1)), "Advantage function was not computed correctly"