Example #1
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
  batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
  if batch_shape_static.is_fully_defined():
    return np.int32(batch_shape_static.as_list()), batch_shape_static, []
  with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
    original_size = math_ops.reduce_prod(original_shape)
    implicit_dim = math_ops.equal(new_shape, -1)
    size_implicit_dim = (
        original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape)))
    new_ndims = array_ops.shape(new_shape)
    expanded_new_shape = array_ops.where(  # Assumes exactly one `-1`.
        implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape)
    validations = [] if not validate else [
        check_ops.assert_rank(
            original_shape, 1, message="Original shape must be a vector."),
        check_ops.assert_rank(
            new_shape, 1, message="New shape must be a vector."),
        check_ops.assert_less_equal(
            math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32),
            1,
            message="At most one dimension can be unknown."),
        check_ops.assert_positive(
            expanded_new_shape, message="Shape elements must be >=-1."),
        check_ops.assert_equal(
            math_ops.reduce_prod(expanded_new_shape),
            original_size,
            message="Shape sizes do not match."),
    ]
    return expanded_new_shape, batch_shape_static, validations
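The example above (and its reformatted twin in Example #2 below) resolves a single `-1` in a requested shape by noting that, with exactly one `-1` present, `reduce_prod(new_shape)` is the negative of the product of the known dimensions. A minimal NumPy-only sketch of that arithmetic (the function name is illustrative, not part of TensorFlow):

```python
import numpy as np

def resolve_reshape(original_shape, new_shape):
    # With exactly one -1 present, prod(new_shape) is the negative of the
    # product of the known dimensions, so negating it recovers that product.
    original_size = np.prod(original_shape)
    size_implicit_dim = original_size // max(1, -np.prod(new_shape))
    return [int(size_implicit_dim) if d == -1 else int(d) for d in new_shape]

assert resolve_reshape([2, 3, 4], [6, -1]) == [6, 4]
```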
Example #2
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
    batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
    if batch_shape_static.is_fully_defined():
        return np.int32(batch_shape_static.as_list()), batch_shape_static, []
    with ops.name_scope(name, "calculate_reshape",
                        [original_shape, new_shape]):
        original_size = math_ops.reduce_prod(original_shape)
        implicit_dim = math_ops.equal(new_shape, -1)
        size_implicit_dim = (
            original_size //
            math_ops.maximum(1, -math_ops.reduce_prod(new_shape)))
        new_ndims = array_ops.shape(new_shape)
        expanded_new_shape = array_ops.where_v2(  # Assumes exactly one `-1`.
            implicit_dim, array_ops.fill(new_ndims, size_implicit_dim),
            new_shape)
        validations = [] if not validate else [
            check_ops.assert_rank(
                original_shape, 1, message="Original shape must be a vector."),
            check_ops.assert_rank(
                new_shape, 1, message="New shape must be a vector."),
            check_ops.assert_less_equal(
                math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32),
                1,
                message="At most one dimension can be unknown."),
            check_ops.assert_positive(expanded_new_shape,
                                      message="Shape elements must be >=-1."),
            check_ops.assert_equal(math_ops.reduce_prod(expanded_new_shape),
                                   original_size,
                                   message="Shape sizes do not match."),
        ]
        return expanded_new_shape, batch_shape_static, validations
Example #3
  def _check_shapes_dynamic(self, operator, v, diag):
    """Return (v, diag) with Assert dependencies, which check shape."""
    checks = []
    with ops.op_scope([operator, v, diag], 'check_shapes'):
      s_v = array_ops.shape(v)
      r_op = operator.rank()
      r_v = array_ops.rank(v)
      if diag is not None:
        s_d = array_ops.shape(diag)
        r_d = array_ops.rank(diag)

      # Check tensor rank.
      checks.append(check_ops.assert_rank(v, r_op))
      if diag is not None:
        checks.append(check_ops.assert_rank(diag, r_op - 1))

      # Check batch shape
      checks.append(check_ops.assert_equal(
          operator.batch_shape(), array_ops.slice(s_v, [0], [r_v - 2])))
      if diag is not None:
        checks.append(check_ops.assert_equal(
            operator.batch_shape(), array_ops.slice(s_d, [0], [r_d - 1])))

      # Check event shape
      checks.append(check_ops.assert_equal(
          operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2)))
      if diag is not None:
        checks.append(check_ops.assert_equal(
            array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1)))

      v = control_flow_ops.with_dependencies(checks, v)
      if diag is not None:
        diag = control_flow_ops.with_dependencies(checks, diag)
      return v, diag
Example #4
    def _check_domain_range_possibly_add_asserts(self):
        """Static check of init arg `num_rows`, possibly add asserts."""
        # Possibly add asserts.
        if self._assert_proper_shapes:
            self._num_rows = control_flow_ops.with_dependencies([
                check_ops.assert_rank(
                    self._num_rows,
                    0,
                    message="Argument num_rows must be a 0-D Tensor."),
                check_ops.assert_non_negative(
                    self._num_rows,
                    message="Argument num_rows must be non-negative."),
            ], self._num_rows)
            self._num_columns = control_flow_ops.with_dependencies([
                check_ops.assert_rank(
                    self._num_columns,
                    0,
                    message="Argument num_columns must be a 0-D Tensor."),
                check_ops.assert_non_negative(
                    self._num_columns,
                    message="Argument num_columns must be non-negative."),
            ], self._num_columns)

        # Static checks.
        if not self._num_rows.dtype.is_integer:
            raise TypeError("Argument num_rows must be integer type.  Found:"
                            " %s" % self._num_rows)

        if not self._num_columns.dtype.is_integer:
            raise TypeError(
                "Argument num_columns must be integer type.  Found:"
                " %s" % self._num_columns)

        num_rows_static = self._num_rows_static
        num_columns_static = self._num_columns_static

        if num_rows_static is not None:
            if num_rows_static.ndim != 0:
                raise ValueError(
                    "Argument num_rows must be a 0-D Tensor.  Found:"
                    " %s" % num_rows_static)

            if num_rows_static < 0:
                raise ValueError(
                    "Argument num_rows must be non-negative.  Found:"
                    " %s" % num_rows_static)
        if num_columns_static is not None:
            if num_columns_static.ndim != 0:
                raise ValueError(
                    "Argument num_columns must be a 0-D Tensor.  Found:"
                    " %s" % num_columns_static)

            if num_columns_static < 0:
                raise ValueError(
                    "Argument num_columns must be non-negative.  Found:"
                    " %s" % num_columns_static)
    def _check_shapes_dynamic(self, operator, v, diag):
        """Return (v, diag) with Assert dependencies, which check shape."""
        checks = []
        with ops.name_scope("check_shapes", values=[operator, v, diag]):
            s_v = array_ops.shape(v)
            r_op = operator.rank()
            r_v = array_ops.rank(v)
            if diag is not None:
                s_d = array_ops.shape(diag)
                r_d = array_ops.rank(diag)

            # Check tensor rank.
            checks.append(
                check_ops.assert_rank(
                    v, r_op, message="v is not the same rank as operator."))
            if diag is not None:
                checks.append(
                    check_ops.assert_rank(
                        diag,
                        r_op - 1,
                        message="diag is not the same rank as operator."))

            # Check batch shape
            checks.append(
                check_ops.assert_equal(
                    operator.batch_shape(),
                    array_ops.strided_slice(s_v, [0], [r_v - 2]),
                    message="v does not have same batch shape as operator."))
            if diag is not None:
                checks.append(
                    check_ops.assert_equal(
                        operator.batch_shape(),
                        array_ops.strided_slice(s_d, [0], [r_d - 1]),
                        message=
                        "diag does not have same batch shape as operator."))

            # Check event shape
            checks.append(
                check_ops.assert_equal(
                    operator.vector_space_dimension(),
                    array_ops.gather(s_v, r_v - 2),
                    message="v does not have same event shape as operator."))
            if diag is not None:
                checks.append(
                    check_ops.assert_equal(
                        array_ops.gather(s_v, r_v - 1),
                        array_ops.gather(s_d, r_d - 1),
                        message="diag does not have same event shape as v."))

            v = control_flow_ops.with_dependencies(checks, v)
            if diag is not None:
                diag = control_flow_ops.with_dependencies(checks, diag)
            return v, diag
Example #6
  def _check_domain_range_possibly_add_asserts(self):
    """Static check of init arg `num_rows`, possibly add asserts."""
    # Possibly add asserts.
    if self._assert_proper_shapes:
      self._num_rows = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._num_rows,
              0,
              message="Argument num_rows must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_rows,
              message="Argument num_rows must be non-negative."),
      ], self._num_rows)
      self._num_columns = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._num_columns,
              0,
              message="Argument num_columns must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_columns,
              message="Argument num_columns must be non-negative."),
      ], self._num_columns)

    # Static checks.
    if not self._num_rows.dtype.is_integer:
      raise TypeError("Argument num_rows must be integer type.  Found:"
                      " %s" % self._num_rows)

    if not self._num_columns.dtype.is_integer:
      raise TypeError("Argument num_columns must be integer type.  Found:"
                      " %s" % self._num_columns)

    num_rows_static = self._num_rows_static
    num_columns_static = self._num_columns_static

    if num_rows_static is not None:
      if num_rows_static.ndim != 0:
        raise ValueError("Argument num_rows must be a 0-D Tensor.  Found:"
                         " %s" % num_rows_static)

      if num_rows_static < 0:
        raise ValueError("Argument num_rows must be non-negative.  Found:"
                         " %s" % num_rows_static)
    if num_columns_static is not None:
      if num_columns_static.ndim != 0:
        raise ValueError("Argument num_columns must be a 0-D Tensor.  Found:"
                         " %s" % num_columns_static)

      if num_columns_static < 0:
        raise ValueError("Argument num_columns must be non-negative.  Found:"
                         " %s" % num_columns_static)
Example #7
 def _sample_n(self, n, seed=None):
     n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
     if self.total_count.get_shape().ndims is not None:
         if self.total_count.get_shape().ndims != 0:
             raise NotImplementedError(
                 "Sample only supported for scalar number of draws.")
     elif self.validate_args:
         is_scalar = check_ops.assert_rank(
             n_draws,
             0,
             message="Sample only supported for scalar number of draws.")
         n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
     k = self.event_shape_tensor()[0]
     # Flatten batch dims so logits has shape [B, k],
     # where B = reduce_prod(self.batch_shape_tensor()).
     x = random_ops.multinomial(
         logits=array_ops.reshape(self.logits, [-1, k]),
         num_samples=n * n_draws,
         seed=seed)
     x = array_ops.reshape(x, shape=[-1, n, n_draws])
     x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k),
                             axis=-2)  # shape: [B, n, k]
     x = array_ops.transpose(x, perm=[1, 0, 2])
     final_shape = array_ops.concat(
         [[n], self.batch_shape_tensor(), [k]], 0)
     x = array_ops.reshape(x, final_shape)
     return math_ops.cast(x, self.dtype)
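The guard at the top of `_sample_n` requires a scalar draw count before sampling. A hedged TF 2.x sketch of the same checks-then-sample flow using only public ops (`tf.debugging.assert_rank` runs immediately in eager mode; the one-hot-and-sum step mirrors the count accumulation above):

```python
import tensorflow as tf

total_count = tf.constant(5)  # number of draws; must have rank 0
tf.debugging.assert_rank(
    total_count, 0, message="Sample only supported for scalar number of draws.")

# Draw categorical samples, then accumulate them into per-class counts
# with the same one-hot-and-sum trick used in the example above.
draws = tf.random.categorical(tf.math.log([[0.2, 0.3, 0.5]]), num_samples=5)
counts = tf.reduce_sum(tf.one_hot(draws, depth=3), axis=-2)  # shape [1, 3]
```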
Example #8
def _check_labels(labels, expected_labels_dimension):
    """Check labels type and shape."""
    with ops.name_scope(None, 'labels', (labels, )) as scope:
        labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
        if isinstance(labels, sparse_tensor.SparseTensor):
            raise ValueError('SparseTensor labels are not supported.')
        labels_shape = array_ops.shape(labels)
        err_msg = 'labels shape must be [batch_size, {}]'.format(
            expected_labels_dimension)
        assert_rank = check_ops.assert_rank(labels, 2, message=err_msg)
        with ops.control_dependencies([assert_rank]):
            static_shape = labels.shape
            if static_shape is not None:
                dim1 = static_shape[1]
                if (dim1 is not None) and (dim1 != expected_labels_dimension):
                    raise ValueError(
                        'Mismatched label shape. '
                        'Classifier configured with n_classes=%s.  Received %s. '
                        'Suggested Fix: check your n_classes argument to the estimator '
                        'and/or the shape of your label.' %
                        (expected_labels_dimension, dim1))
            assert_dimension = check_ops.assert_equal(
                expected_labels_dimension, labels_shape[1], message=err_msg)
            with ops.control_dependencies([assert_dimension]):
                return array_ops.identity(labels, name=scope)
Example #9
  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Check whether event_ndims is atleast min_event_ndims."""
    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
    event_ndims_ = tensor_util.constant_value(event_ndims)
    assertions = []

    if not event_ndims.dtype.is_integer:
      raise ValueError("Expected integer dtype, got dtype {}".format(
          event_ndims.dtype))

    if event_ndims_ is not None:
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar event_ndims, got shape {}".format(
            event_ndims.shape))
      if min_event_ndims > event_ndims_:
        raise ValueError("event_ndims ({}) must be larger than "
                         "min_event_ndims ({})".format(
                             event_ndims_, min_event_ndims))
    elif self.validate_args:
      assertions += [
          check_ops.assert_greater_equal(event_ndims, min_event_ndims)]

    if event_ndims.shape.is_fully_defined():
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar shape, got ndims {}".format(
            event_ndims.shape.ndims))

    elif self.validate_args:
      assertions += [
          check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
    return assertions
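`_check_valid_event_ndims` is the usual static-else-dynamic split: validate immediately when a constant value is available, otherwise return runtime assertions gated on `validate_args`. A minimal sketch of that idiom against the public TF 2.x API (the helper name is hypothetical):

```python
import tensorflow as tf

def assert_scalar_at_least(t, minimum, validate_args):
    """Validate statically when possible, else return runtime assertions."""
    static = tf.get_static_value(t)
    if static is not None:
        if static < minimum:
            raise ValueError("expected >= {}, got {}".format(minimum, static))
        return []  # nothing left to check at runtime
    if validate_args:
        return [tf.debugging.assert_greater_equal(t, minimum)]
    return []
```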
Example #10
 def _sample_n(self, n, seed=None):
   n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
   if self.n.get_shape().ndims is not None:
     if self.n.get_shape().ndims != 0:
       raise NotImplementedError(
           "Sample only supported for scalar number of draws.")
   elif self.validate_args:
     is_scalar = check_ops.assert_rank(
         n_draws, 0,
         message="Sample only supported for scalar number of draws.")
     n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
   k = self.event_shape()[0]
   unnormalized_logits = array_ops.reshape(
       math_ops.log(random_ops.random_gamma(
           shape=[n],
           alpha=self.alpha,
           dtype=self.dtype,
           seed=seed)),
       shape=[-1, k])
   draws = random_ops.multinomial(
       logits=unnormalized_logits,
       num_samples=n_draws,
       seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
   x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                           reduction_indices=-2)
   final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
   return array_ops.reshape(x, final_shape)
Example #11
def _check_labels(labels, expected_labels_dimension):
  """Check labels type and shape."""
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    labels_shape = array_ops.shape(labels)
    err_msg = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    assert_rank = check_ops.assert_rank(labels, 2, message=err_msg)
    with ops.control_dependencies([assert_rank]):
      static_shape = labels.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        if (dim1 is not None) and (dim1 != expected_labels_dimension):
          raise ValueError(
              'Mismatched label shape. '
              'Classifier configured with n_classes=%s.  Received %s. '
              'Suggested Fix: check your n_classes argument to the estimator '
              'and/or the shape of your label.' %
              (expected_labels_dimension, dim1))
      assert_dimension = check_ops.assert_equal(
          expected_labels_dimension, labels_shape[1], message=err_msg)
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(labels, name=scope)
Example #12
 def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
   tensor = constant_op.constant([1, 2], name="my_tensor")
   desired_rank = 2
   with self.assertRaisesRegexp(ValueError, "rank"):
     with ops.control_dependencies(
         [check_ops.assert_rank(tensor, desired_rank)]):
       self.evaluate(array_ops.identity(tensor))
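As the test name says, when the rank is statically known the failure surfaces at graph-construction time as a `ValueError` rather than as a runtime op failure. A small TF 2.x sketch; the exact exception type depends on the execution mode, hence the broad catch:

```python
import tensorflow as tf

x = tf.constant([1, 2])  # rank 1, known statically
try:
    tf.debugging.assert_rank(x, 2)
except (ValueError, tf.errors.InvalidArgumentError) as e:
    print("rank check failed:", e)
```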
Example #13
 def _sample_n(self, n, seed=None):
     n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
     if self.n.get_shape().ndims is not None:
         if self.n.get_shape().ndims != 0:
             raise NotImplementedError(
                 "Sample only supported for scalar number of draws.")
     elif self.validate_args:
         is_scalar = check_ops.assert_rank(
             n_draws,
             0,
             message="Sample only supported for scalar number of draws.")
         n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
     k = self.event_shape()[0]
     unnormalized_logits = array_ops.reshape(
         math_ops.log(random_ops.random_gamma(
             shape=[n], alpha=self.alpha, dtype=self.dtype, seed=seed)),
         shape=[-1, k])
     draws = random_ops.multinomial(
         logits=unnormalized_logits,
         num_samples=n_draws,
         seed=distribution_util.gen_new_seed(
             seed, salt="dirichlet_multinomial"))
     x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                             reduction_indices=-2)
     final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
     return array_ops.reshape(x, final_shape)
Example #14
 def _sample_n(self, n, seed=None):
   n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
   if self.total_count.get_shape().ndims is not None:
     if self.total_count.get_shape().ndims != 0:
       raise NotImplementedError(
           "Sample only supported for scalar number of draws.")
   elif self.validate_args:
     is_scalar = check_ops.assert_rank(
         n_draws, 0,
         message="Sample only supported for scalar number of draws.")
     n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
   k = self.event_shape_tensor()[0]
   # Flatten batch dims so logits has shape [B, k],
   # where B = reduce_prod(self.batch_shape_tensor()).
   x = random_ops.multinomial(
       logits=array_ops.reshape(self.logits, [-1, k]),
       num_samples=n * n_draws,
       seed=seed)
   x = array_ops.reshape(x, shape=[-1, n, n_draws])
   x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k),
                           axis=-2)  # shape: [B, n, k]
   x = array_ops.transpose(x, perm=[1, 0, 2])
   final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
   x = array_ops.reshape(x, final_shape)
   return math_ops.cast(x, self.dtype)
Example #15
    def _check_batch_shape_possibly_add_asserts(self):
        """Static check of init arg `batch_shape`, possibly add asserts."""
        if self._batch_shape_arg is None:
            return

        # Possibly add asserts
        if self._assert_proper_shapes:
            self._batch_shape_arg = control_flow_ops.with_dependencies([
                check_ops.assert_rank(
                    self._batch_shape_arg,
                    1,
                    message="Argument batch_shape must be a 1-D Tensor."),
                check_ops.assert_non_negative(
                    self._batch_shape_arg,
                    message="Argument batch_shape must be non-negative."),
            ], self._batch_shape_arg)

        # Static checks
        if not self._batch_shape_arg.dtype.is_integer:
            raise TypeError(
                "Argument batch_shape must be integer type.  Found:"
                " %s" % self._batch_shape_arg)

        if self._batch_shape_static is None:
            return  # Cannot do any other static checks.

        if self._batch_shape_static.ndim != 1:
            raise ValueError(
                "Argument batch_shape must be a 1-D Tensor.  Found:"
                " %s" % self._batch_shape_static)

        if np.any(self._batch_shape_static < 0):
            raise ValueError(
                "Argument batch_shape must be non-negative.  Found:"
                "%s" % self._batch_shape_static)
Example #16
  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Check whether event_ndims is atleast min_event_ndims."""
    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
    event_ndims_ = tensor_util.constant_value(event_ndims)
    assertions = []

    if not event_ndims.dtype.is_integer:
      raise ValueError("Expected integer dtype, got dtype {}".format(
          event_ndims.dtype))

    if event_ndims_ is not None:
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar event_ndims, got shape {}".format(
            event_ndims.shape))
      if min_event_ndims > event_ndims_:
        raise ValueError("event_ndims ({}) must be larger than "
                         "min_event_ndims ({})".format(
                             event_ndims_, min_event_ndims))
    elif self.validate_args:
      assertions += [
          check_ops.assert_greater_equal(event_ndims, min_event_ndims)]

    if event_ndims.shape.is_fully_defined():
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar shape, got ndims {}".format(
            event_ndims.shape.ndims))

    elif self.validate_args:
      assertions += [
          check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
    return assertions
Example #17
def _check_logits(logits, expected_logits_dimension):
    """Check logits type and shape."""
    with ops.name_scope(None, 'logits', (logits, )) as scope:
        logits = math_ops.to_float(logits)
        logits_shape = array_ops.shape(logits)
        assert_rank = check_ops.assert_rank(
            logits,
            2,
            data=[logits_shape],
            message='logits shape must be [batch_size, logits_dimension]')
        with ops.control_dependencies([assert_rank]):
            static_shape = logits.shape
            if static_shape is not None:
                dim1 = static_shape[1]
                if (dim1 is not None) and (dim1 != expected_logits_dimension):
                    raise ValueError(
                        'logits shape must be [batch_size, logits_dimension], got %s.'
                        % (static_shape, ))
            assert_dimension = check_ops.assert_equal(
                expected_logits_dimension,
                logits_shape[1],
                data=[logits_shape],
                message='logits shape must be [batch_size, logits_dimension]')
            with ops.control_dependencies([assert_dimension]):
                return array_ops.identity(logits, name=scope)
Example #18
  def _check_batch_shape_possibly_add_asserts(self):
    """Static check of init arg `batch_shape`, possibly add asserts."""
    if self._batch_shape_arg is None:
      return

    # Possibly add asserts
    if self._assert_proper_shapes:
      self._batch_shape_arg = control_flow_ops.with_dependencies(
          [
              check_ops.assert_rank(
                  self._batch_shape_arg,
                  1,
                  message="Argument batch_shape must be a 1-D Tensor."),
              check_ops.assert_non_negative(
                  self._batch_shape_arg,
                  message="Argument batch_shape must be non-negative."),
          ],
          self._batch_shape_arg)

    # Static checks
    if not self._batch_shape_arg.dtype.is_integer:
      raise TypeError("Argument batch_shape must be integer type.  Found:"
                      " %s" % self._batch_shape_arg)

    if self._batch_shape_static is None:
      return  # Cannot do any other static checks.

    if self._batch_shape_static.ndim != 1:
      raise ValueError("Argument batch_shape must be a 1-D Tensor.  Found:"
                       " %s" % self._batch_shape_static)

    if np.any(self._batch_shape_static < 0):
      raise ValueError("Argument batch_shape must be non-negative.  Found:"
                       "%s" % self._batch_shape_static)
Example #19
 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
   with self.test_session():
     tensor = constant_op.constant([1, 2], name="my_tensor")
     desired_rank = 1
     with ops.control_dependencies(
         [check_ops.assert_rank(tensor, desired_rank)]):
       array_ops.identity(tensor).eval()
Example #20
def _check_and_reshape_dense_labels(labels, expected_labels_dimension):
    """Checks dense labels type and shape and reshapes to 2D Tensor."""
    with ops.name_scope(None, 'labels', (labels, )) as scope:
        labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
        if isinstance(labels, sparse_tensor.SparseTensor):
            raise ValueError(
                'SparseTensor labels are not supported. '
                'labels must be a Tensor of shape [batch_size, %s]. '
                'Suggested Fix (1): Check the label feature in your data. '
                'Each example must contain %s value(s). If not, your choice of label '
                'was probably incorrect. '
                'Suggested Fix (2): In your input_fn, use '
                'tf.sparse_tensor_to_dense() to turn labels into a Tensor.'
                '' % (expected_labels_dimension, expected_labels_dimension))
        labels = _maybe_expand_dim(labels)
        labels_shape = array_ops.shape(labels)
        err_msg = 'labels shape must be [batch_size, {}]'.format(
            expected_labels_dimension)
        assert_rank = check_ops.assert_rank(labels, 2, message=err_msg)
        with ops.control_dependencies([assert_rank]):
            static_shape = labels.shape
            if static_shape is not None:
                dim1 = static_shape[1]
                if (dim1 is not None) and (dim1 != expected_labels_dimension):
                    raise ValueError(
                        'Mismatched label shape. '
                        'Classifier configured with n_classes=%s.  Received %s. '
                        'Suggested Fix: check your n_classes argument to the estimator '
                        'and/or the shape of your label.' %
                        (expected_labels_dimension, dim1))
            assert_dimension = check_ops.assert_equal(
                expected_labels_dimension, labels_shape[1], message=err_msg)
            with ops.control_dependencies([assert_dimension]):
                return array_ops.identity(labels, name=scope)
Example #21
 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
   with self.test_session():
     tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
     desired_rank = 1
     with ops.control_dependencies(
         [check_ops.assert_rank(tensor, desired_rank)]):
       array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})
Example #22
  def from_row_limits(cls, row_limits, validate=True, preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `row_limits`.

    Equivalent to: `from_row_splits(values, concat([0, row_limits], axis=0))`.

    Args:
      row_limits: A 1-D integer tensor with shape `[nrows]`.  Must be sorted in
        ascending order.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If row_limits has an unspecified type, use this one. If
        preferred_dtype is None, defaults to dtypes.int64.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromRowLimits", [row_limits]):
      row_limits = cls._convert_row_partition(row_limits, "row_limits",
                                              preferred_dtype)
      row_limits.shape.assert_has_rank(1)

      if validate:
        msg = "Arguments to from_row_limits do not form a valid RaggedTensor"
        checks = [
            check_ops.assert_rank(row_limits, 1, message=msg),
            check_ops.assert_non_negative(row_limits[:1], message=msg),
            _assert_monotonic_increasing(row_limits, message=msg),
        ]
        row_limits = control_flow_ops.with_dependencies(checks, row_limits)

      zero = array_ops.zeros([1], row_limits.dtype)
      row_splits = array_ops.concat([zero, row_limits], axis=0)
      return cls(row_splits=row_splits, internal=_row_partition_factory_key)
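`RowPartition` is an internal class, but the same limits-based construction, including the `validate` switch that installs these assertions, is exposed on the public ragged API. A hedged usage sketch assuming TF 2.x:

```python
import tensorflow as tf

values = tf.constant([3, 1, 4, 1, 5, 9])
# row_limits must be non-negative and sorted ascending;
# the rows are values[0:2], values[2:5], and values[5:6].
rt = tf.RaggedTensor.from_row_limits(values, row_limits=[2, 5, 6], validate=True)
print(rt)  # <tf.RaggedTensor [[3, 1], [4, 1, 5], [9]]>
```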
Example #23
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions."""
    assertions = []

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2:
        raise ValueError(message)
    elif validate_args:
        assertions.append(
            check_ops.assert_rank_at_least_v2(lower_upper,
                                              rank=2,
                                              message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    if lower_upper.shape.rank is not None and perm.shape.rank is not None:
        if lower_upper.shape.rank != perm.shape.rank + 1:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            check_ops.assert_rank(lower_upper,
                                  rank=array_ops.rank(perm) + 1,
                                  message=message))

    message = '`lower_upper` must be square.'
    if lower_upper.shape[:-2].is_fully_defined():
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = array_ops.split(array_ops.shape(lower_upper)[-2:],
                               num_or_size_splits=2)
        assertions.append(check_ops.assert_equal(m, n, message=message))

    return assertions
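These assertions guard `lu_reconstruct`, which recombines a packed LU factorization and its permutation into the original matrix. A hedged round-trip sketch with the public linalg API, assuming TF 2.x:

```python
import tensorflow as tf

x = tf.constant([[4., 3.], [6., 3.]])
lu, perm = tf.linalg.lu(x)
# validate_args=True attaches rank/shape assertions like the ones above.
x_again = tf.linalg.lu_reconstruct(lu, perm, validate_args=True)
print(tf.reduce_max(tf.abs(x - x_again)))  # ~0.0
```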
Example #24
def concatenate_context_input(context_input, sequence_input):
  """Replicates `context_input` across all timesteps of `sequence_input`.

  Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
  This value is appended to `sequence_input` on dimension 2 and the result is
  returned.

  Args:
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  seq_rank_check = check_ops.assert_rank(
      sequence_input,
      3,
      message='sequence_input must have rank 3',
      data=[array_ops.shape(sequence_input)])
  seq_type_check = check_ops.assert_type(
      sequence_input,
      dtypes.float32,
      message='sequence_input must have dtype float32; got {}.'.format(
          sequence_input.dtype))
  ctx_rank_check = check_ops.assert_rank(
      context_input,
      2,
      message='context_input must have rank 2',
      data=[array_ops.shape(context_input)])
  ctx_type_check = check_ops.assert_type(
      context_input,
      dtypes.float32,
      message='context_input must have dtype float32; got {}.'.format(
          context_input.dtype))
  with ops.control_dependencies(
      [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
    padded_length = array_ops.shape(sequence_input)[1]
    tiled_context_input = array_ops.tile(
        array_ops.expand_dims(context_input, 1),
        array_ops.concat([[1], [padded_length], [1]], 0))
  return array_ops.concat([sequence_input, tiled_context_input], 2)
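Stripped of the rank and dtype assertions, the body is an expand-tile-concat. A self-contained sketch with public TF 2.x ops and made-up dimensions:

```python
import tensorflow as tf

batch, padded_length, d0, d1 = 2, 5, 3, 4
sequence_input = tf.random.normal([batch, padded_length, d0])
context_input = tf.random.normal([batch, d1])

# Repeat each example's context vector once per timestep, then append it.
tiled = tf.tile(tf.expand_dims(context_input, 1), [1, padded_length, 1])
combined = tf.concat([sequence_input, tiled], axis=2)
print(combined.shape)  # (2, 5, 7)
```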
Example #25
def _concatenate_context_input(sequence_input, context_input):
    """Replicates `context_input` accross all timesteps of `sequence_input`.

  Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
  This value is appended to `sequence_input` on dimension 2 and the result is
  returned.

  Args:
    sequence_input: a `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.
    context_input: a `Tensor` of dtype `float32` and shape `[batch_size, d1]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: if `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
    seq_rank_check = check_ops.assert_rank(
        sequence_input,
        3,
        message='sequence_input must have rank 3',
        data=[array_ops.shape(sequence_input)])
    seq_type_check = check_ops.assert_type(
        sequence_input,
        dtypes.float32,
        message='sequence_input must have dtype float32; got {}.'.format(
            sequence_input.dtype))
    ctx_rank_check = check_ops.assert_rank(
        context_input,
        2,
        message='context_input must have rank 2',
        data=[array_ops.shape(context_input)])
    ctx_type_check = check_ops.assert_type(
        context_input,
        dtypes.float32,
        message='context_input must have dtype float32; got {}.'.format(
            context_input.dtype))
    with ops.control_dependencies(
        [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
        padded_length = array_ops.shape(sequence_input)[1]
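        # Note: this snippet targets the pre-1.0 TensorFlow concat signature,
        # `concat(axis, values)`; current releases take `concat(values, axis)`
        # as in Example #24 above.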
        tiled_context_input = array_ops.tile(
            array_ops.expand_dims(context_input, 1),
            array_ops.concat(0, [[1], [padded_length], [1]]))
    return array_ops.concat(2, [sequence_input, tiled_context_input])
Example #26
def validate_init_args(
    distribution,
    batch_shape,
    validate_args,
    batch_shape_static):
  """Helper to __init__ which makes or raises assertions."""
  with ops.name_scope(name="validate_init_args",
                      values=[batch_shape] + distribution._graph_parents):  # pylint: disable=protected-access
    runtime_assertions = []

    if batch_shape.shape.ndims is not None:
      if batch_shape.shape.ndims != 1:
        raise ValueError("`batch_shape` must be a vector "
                         "(saw rank: {}).".format(
                             batch_shape.shape.ndims))
    elif validate_args:
      runtime_assertions += [
          check_ops.assert_rank(
              batch_shape,
              1,
              message="`batch_shape` must be a vector.",
              name="assert_batch_shape_is_vector"),
      ]

    batch_size_static = np.prod(batch_shape_static)
    dist_batch_size_static = (
        None if not distribution.batch_shape.is_fully_defined()
        else np.prod(distribution.batch_shape).value)

    if batch_size_static is not None and dist_batch_size_static is not None:
      if batch_size_static != dist_batch_size_static:
        raise ValueError("`batch_shape` size ({}) must match "
                         "`distribution.batch_shape` size ({}).".format(
                             batch_size_static,
                             dist_batch_size_static))
    elif validate_args:
      runtime_assertions += [
          check_ops.assert_equal(
              math_ops.reduce_prod(batch_shape),
              math_ops.reduce_prod(distribution.batch_shape_tensor()),
              message=("`batch_shape` size must match "
                       "`distributions.batch_shape` size."),
              name="assert_batch_size"),
      ]

    if batch_shape_static is not None:
      if np.any(batch_shape_static < 1):
        raise ValueError("`batch_shape` elements must be positive "
                         "(i.e., larger than zero).")
    elif validate_args:
      runtime_assertions += [
          check_ops.assert_positive(
              batch_shape,
              message=("`batch_shape` elements must be positive "
                       "(i.e., larger than zero)."),
              name="assert_batch_shape_positive")
      ]

    return runtime_assertions
Example #27
def validate_init_args(
    distribution,
    batch_shape,
    validate_args,
    batch_shape_static):
  """Helper to __init__ which makes or raises assertions."""
  with ops.name_scope(name="validate_init_args",
                      values=[batch_shape] + distribution._graph_parents):  # pylint: disable=protected-access
    runtime_assertions = []

    if batch_shape.shape.ndims is not None:
      if batch_shape.shape.ndims != 1:
        raise ValueError("`batch_shape` must be a vector "
                         "(saw rank: {}).".format(
                             batch_shape.shape.ndims))
    elif validate_args:
      runtime_assertions += [
          check_ops.assert_rank(
              batch_shape,
              1,
              message="`batch_shape` must be a vector.",
              name="assert_batch_shape_is_vector"),
      ]

    batch_size_static = np.prod(batch_shape_static)
    dist_batch_size_static = (
        None if not distribution.batch_shape.is_fully_defined()
        else np.prod(distribution.batch_shape).value)

    if batch_size_static is not None and dist_batch_size_static is not None:
      if batch_size_static != dist_batch_size_static:
        raise ValueError("`batch_shape` size ({}) must match "
                         "`distribution.batch_shape` size ({}).".format(
                             batch_size_static,
                             dist_batch_size_static))
    elif validate_args:
      runtime_assertions += [
          check_ops.assert_equal(
              math_ops.reduce_prod(batch_shape),
              math_ops.reduce_prod(distribution.batch_shape_tensor()),
              message=("`batch_shape` size must match "
                       "`distributions.batch_shape` size."),
              name="assert_batch_size"),
      ]

    if batch_shape_static is not None:
      if np.any(batch_shape_static < 1):
        raise ValueError("`batch_shape` elements must be positive "
                         "(i.e., larger than zero).")
    elif validate_args:
      runtime_assertions += [
          check_ops.assert_positive(
              batch_shape,
              message=("`batch_shape` elements must be positive "
                       "(i.e., larger than zero)."),
              name="assert_batch_shape_positive")
      ]

    return runtime_assertions
Example #28
 def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self):
   with self.test_session():
     tensor = constant_op.constant([1, 2], name="my_tensor")
     desired_rank = 0
     with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
       with ops.control_dependencies(
           [check_ops.assert_rank(tensor, desired_rank)]):
         array_ops.identity(tensor).eval()
Example #29
 def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
   with self.test_session():
     tensor = constant_op.constant([1, 2], name="my_tensor")
     desired_rank = 2
     with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
       with ops.control_dependencies(
           [check_ops.assert_rank(tensor, desired_rank)]):
         array_ops.identity(tensor).eval()
Example #30
 def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
   with self.test_session():
     tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
     desired_rank = 2
     with ops.control_dependencies(
         [check_ops.assert_rank(tensor, desired_rank)]):
       with self.assertRaisesOpError("my_tensor.*rank"):
         array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})
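The dynamic-rank tests above use TF1 placeholders and `feed_dict`. A hedged TF 2.x equivalent: tracing a `tf.function` with an unknown-rank signature forces the check to be deferred to run time.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def identity_checked(x):
    # Rank is unknown at trace time, so assert_rank becomes a runtime op.
    with tf.control_dependencies([tf.debugging.assert_rank(x, 2)]):
        return tf.identity(x)

identity_checked(tf.zeros([2, 3]))         # passes
# identity_checked(tf.constant([1., 2.]))  # raises InvalidArgumentError
```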
Example #31
    def from_row_splits(cls,
                        row_splits,
                        name=None,
                        validate=True,
                        preferred_dtype=None):
        """Creates a `RowPartition` with rows partitioned by `row_splits`.

    A `RaggedTensor` constructed with this corresponds with the python list
    defined by:

    ```python
    result = [values[row_splits[i]:row_splits[i + 1]]
              for i in range(len(row_splits) - 1)]
    ```

    Args:
      row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
        empty, and must be sorted in ascending order.  `row_splits[0]` must be
        zero.
      name: A name prefix for the RaggedTensor (optional).
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If row_splits has an unspecified type, use this one. If
        preferred_dtype is None, defaults to dtypes.int64.

    Returns:
      A `RowPartition`.

    Raises:
      ValueError: If `row_splits` is an empty list.

    """
        if not isinstance(validate, bool):
            raise TypeError("validate must have type bool")
        if isinstance(row_splits, (list, tuple)) and not row_splits:
            raise ValueError("row_splits tensor may not be empty.")
        if isinstance(row_splits, tensor_spec.TensorSpec):
            return cls(row_splits=row_splits, internal=True)

        with ops.name_scope(name, "RowPartitionFromRowSplits", [row_splits]):
            row_splits = cls._convert_row_partition(row_splits, "row_splits",
                                                    preferred_dtype)
            row_splits.shape.assert_has_rank(1)

            if validate:
                msg = "Arguments to from_row_splits do not form a valid RaggedTensor:"
                checks = [
                    check_ops.assert_rank(row_splits,
                                          1,
                                          message=(msg + "rank")),
                    _assert_zero(row_splits[0], message=(msg + "zero")),
                    _assert_monotonic_increasing(row_splits,
                                                 message=(msg + "monotonic")),
                ]
                row_splits = control_flow_ops.with_dependencies(
                    checks, row_splits)

            return cls(row_splits=row_splits, internal=True)
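The splits-based factory is mirrored on the public ragged API with the same three checks (rank 1, zero first element, monotonic splits). A hedged usage sketch assuming TF 2.x:

```python
import tensorflow as tf

values = tf.constant([3, 1, 4, 1, 5, 9])
# row_splits[0] must be 0 and the vector must be non-decreasing.
rt = tf.RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 6])
print(rt)  # <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9]]>
```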
Example #32
 def test_raises_if_rank_is_not_scalar_dynamic(self):
   with self.test_session():
     tensor = constant_op.constant(
         [1, 2], dtype=dtypes.float32, name="my_tensor")
     rank_tensor = array_ops.placeholder(dtypes.int32, name="rank_tensor")
     with self.assertRaisesOpError("Rank must be a scalar"):
       with ops.control_dependencies(
           [check_ops.assert_rank(tensor, rank_tensor)]):
         array_ops.identity(tensor).eval(feed_dict={rank_tensor: [1, 2]})
Example #33
    def from_row_splits(cls, row_splits, validate=True, preferred_dtype=None):
        """Creates a `RowPartition` with rows partitioned by `row_splits`.

    This `RowPartition` divides a sequence `values` into rows by indicating
    where each row begins and ends:

    ```python
    partitioned_rows = []
    for i in range(len(row_splits) - 1):
      row_start = row_splits[i]
      row_end = row_splits[i + 1]
      partitioned_rows.append(values[row_start:row_end])
    ```

    Args:
      row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
        empty, and must be sorted in ascending order.  `row_splits[0]` must be
        zero.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If row_splits has an unspecified type, use this one. If
        preferred_dtype is None, defaults to dtypes.int64.

    Returns:
      A `RowPartition`.

    Raises:
      ValueError: If `row_splits` is an empty list.
    """
        if not isinstance(validate, bool):
            raise TypeError("validate must have type bool")
        if isinstance(row_splits, (list, tuple)) and not row_splits:
            raise ValueError("row_splits tensor may not be empty.")
        if isinstance(row_splits, tensor_spec.TensorSpec):
            return cls(row_splits=row_splits,
                       internal=_row_partition_factory_key)

        with ops.name_scope(None, "RowPartitionFromRowSplits", [row_splits]):
            row_splits = cls._convert_row_partition(row_splits, "row_splits",
                                                    preferred_dtype)
            row_splits.shape.assert_has_rank(1)

            if validate:
                msg = "Arguments to from_row_splits do not form a valid RaggedTensor:"
                checks = [
                    check_ops.assert_rank(row_splits,
                                          1,
                                          message=(msg + "rank")),
                    _assert_zero(row_splits[0], message=(msg + "zero")),
                    _assert_monotonic_increasing(row_splits,
                                                 message=(msg + "monotonic")),
                ]
                row_splits = control_flow_ops.with_dependencies(
                    checks, row_splits)

            return cls(row_splits=row_splits,
                       internal=_row_partition_factory_key)
Example #34
 def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
   tensor = constant_op.constant(1, name="my_tensor")
   desired_rank = 1
   with self.assertRaisesRegexp(ValueError,
                                "fail.*must have rank 1"):
     with ops.control_dependencies(
         [check_ops.assert_rank(
             tensor, desired_rank, message="fail")]):
       self.evaluate(array_ops.identity(tensor))
Example #35
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                       validate_args, name):
        """Helper to __init__ which ensures override batch/event_shape are valid."""
        if override_shape is None:
            override_shape = []

        override_shape = ops.convert_to_tensor(override_shape,
                                               dtype=dtypes.int32,
                                               name=name)

        if not override_shape.dtype.is_integer:
            raise TypeError("shape override must be an integer")

        override_is_scalar = _is_scalar_from_shape(override_shape)
        if tensor_util.constant_value(override_is_scalar):
            return self._empty

        dynamic_assertions = []

        if override_shape.get_shape().ndims is not None:
            if override_shape.get_shape().ndims != 1:
                raise ValueError("shape override must be a vector")
        elif validate_args:
            dynamic_assertions += [
                check_ops.assert_rank(
                    override_shape,
                    1,
                    message="shape override must be a vector")
            ]

        if tensor_util.constant_value(override_shape) is not None:
            if any(s <= 0 for s in tensor_util.constant_value(override_shape)):
                raise ValueError("shape override must have positive elements")
        elif validate_args:
            dynamic_assertions += [
                check_ops.assert_positive(
                    override_shape,
                    message="shape override must have positive elements")
            ]

        is_both_nonscalar = _logical_and(_logical_not(base_is_scalar),
                                         _logical_not(override_is_scalar))
        if tensor_util.constant_value(is_both_nonscalar) is not None:
            if tensor_util.constant_value(is_both_nonscalar):
                raise ValueError("base distribution not scalar")
        elif validate_args:
            dynamic_assertions += [
                check_ops.assert_equal(is_both_nonscalar,
                                       False,
                                       message="base distribution not scalar")
            ]

        if not dynamic_assertions:
            return override_shape
        return control_flow_ops.with_dependencies(dynamic_assertions,
                                                  override_shape)
Example #36
 def test_raises_if_rank_is_not_integer_dynamic(self):
   with self.test_session():
     tensor = constant_op.constant(
         [1, 2], dtype=dtypes.float32, name="my_tensor")
     rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
     with self.assertRaisesRegexp(TypeError,
                                  "must be of type <dtype: 'int32'>"):
       with ops.control_dependencies(
           [check_ops.assert_rank(tensor, rank_tensor)]):
         array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5})
Example #37
 def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
   with self.test_session():
     tensor = constant_op.constant(1, name="my_tensor")
     desired_rank = 1
     with self.assertRaisesRegexp(ValueError,
                                  "fail.*my_tensor.*must have rank 1"):
       with ops.control_dependencies(
           [check_ops.assert_rank(
               tensor, desired_rank, message="fail")]):
         array_ops.identity(tensor).eval()
Example #38
  def _check_shapes_dynamic(self, operator, v, diag):
    """Return (v, diag) with Assert dependencies, which check shape."""
    checks = []
    with ops.name_scope("check_shapes", values=[operator, v, diag]):
      s_v = array_ops.shape(v)
      r_op = operator.rank()
      r_v = array_ops.rank(v)
      if diag is not None:
        s_d = array_ops.shape(diag)
        r_d = array_ops.rank(diag)

      # Check tensor rank.
      checks.append(check_ops.assert_rank(
          v, r_op, message="v is not the same rank as operator."))
      if diag is not None:
        checks.append(check_ops.assert_rank(
            diag, r_op - 1, message="diag is not the same rank as operator."))

      # Check batch shape
      checks.append(check_ops.assert_equal(
          operator.batch_shape(), array_ops.strided_slice(s_v, [0], [r_v - 2]),
          message="v does not have same batch shape as operator."))
      if diag is not None:
        checks.append(check_ops.assert_equal(
            operator.batch_shape(), array_ops.strided_slice(
                s_d, [0], [r_d - 1]),
            message="diag does not have same batch shape as operator."))

      # Check event shape
      checks.append(check_ops.assert_equal(
          operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2),
          message="v does not have same event shape as operator."))
      if diag is not None:
        checks.append(check_ops.assert_equal(
            array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1),
            message="diag does not have same event shape as v."))

      v = control_flow_ops.with_dependencies(checks, v)
      if diag is not None:
        diag = control_flow_ops.with_dependencies(checks, diag)
      return v, diag
Example #39
    def from_row_lengths(cls,
                         row_lengths,
                         name=None,
                         validate=True,
                         preferred_dtype=None):
        """Creates a `RowPartition` with rows partitioned by `row_lengths`.

    A `RaggedTensor` constructed with this corresponds with the python list
     defined by:

    ```python
    result = [[values.pop(0) for i in range(length)]
              for length in row_lengths]
    ```

    Args:
      row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
        nonnegative.
      name: A name prefix for the RowPartition (optional).
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: If row_lengths has an unspecified type, use this one. If
        preferred_dtype is None, defaults to dtypes.int64.

    Returns:
      A `RowPartition`.
    """
        if not isinstance(validate, bool):
            raise TypeError("validate must have type bool")
        with ops.name_scope(name, "RowPartitionFromRowLengths", [row_lengths]):
            row_lengths = cls._convert_row_partition(row_lengths,
                                                     "row_lengths",
                                                     preferred_dtype)
            row_lengths.shape.assert_has_rank(1)

            if validate:
                msg = "Arguments to from_row_lengths do not form a valid RowPartition"
                checks = [
                    check_ops.assert_rank(row_lengths, 1, message=msg),
                    check_ops.assert_non_negative(row_lengths, message=msg),
                ]
                row_lengths = control_flow_ops.with_dependencies(
                    checks, row_lengths)

            row_limits = math_ops.cumsum(row_lengths)
            row_splits = array_ops.concat([[0], row_limits], axis=0)
            return cls(row_splits=row_splits,
                       cached_row_lengths=row_lengths,
                       internal=True)
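The public mirror again: `tf.RaggedTensor.from_row_lengths` applies the same rank-1 and non-negativity checks, then builds `row_splits` from the cumulative sum. A hedged sketch assuming TF 2.x:

```python
import tensorflow as tf

values = tf.constant([3, 1, 4, 1, 5, 9])
rt = tf.RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 2])
print(rt)  # <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9]]>
```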
Example #40
 def _assert_non_negative_int32_scalar(self, x):
     """Helper which ensures that input is a non-negative, int32, scalar."""
     x = ops.convert_to_tensor(x, name="x")
     if x.dtype.base_dtype != dtypes.int32.base_dtype:
         raise TypeError("%s.dtype=%s is not %s" %
                         (x.name, x.dtype, dtypes.int32))
     x_value_static = tensor_util.constant_value(x)
     if x.get_shape().ndims is not None and x_value_static is not None:
         if x.get_shape().ndims != 0:
             raise ValueError("%s.ndims=%d is not 0 (scalar)" %
                              (x.name, x.get_shape().ndims))
         if x_value_static < 0:
             raise ValueError("%s.value=%d cannot be negative" %
                              (x.name, x_value_static))
         return x
     if self.validate_args:
         x = control_flow_ops.with_dependencies(
             [check_ops.assert_rank(x, 0),
              check_ops.assert_non_negative(x)], x)
     return x
Example #41
  def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                     validate_args, name):
    """Helper to __init__ which ensures override batch/event_shape are valid."""
    if override_shape is None:
      override_shape = []

    override_shape = ops.convert_to_tensor(override_shape, dtype=dtypes.int32,
                                           name=name)

    if not override_shape.dtype.is_integer:
      raise TypeError("shape override must be an integer")

    override_is_scalar = _is_scalar_from_shape(override_shape)
    if tensor_util.constant_value(override_is_scalar):
      return self._empty

    dynamic_assertions = []

    if override_shape.get_shape().ndims is not None:
      if override_shape.get_shape().ndims != 1:
        raise ValueError("shape override must be a vector")
    elif validate_args:
      dynamic_assertions += [check_ops.assert_rank(
          override_shape, 1,
          message="shape override must be a vector")]

    if tensor_util.constant_value(override_shape) is not None:
      if any(s <= 0 for s in tensor_util.constant_value(override_shape)):
        raise ValueError("shape override must have positive elements")
    elif validate_args:
      dynamic_assertions += [check_ops.assert_positive(
          override_shape,
          message="shape override must have positive elements")]

    is_both_nonscalar = _logical_and(_logical_not(base_is_scalar),
                                     _logical_not(override_is_scalar))
    if tensor_util.constant_value(is_both_nonscalar) is not None:
      if tensor_util.constant_value(is_both_nonscalar):
        raise ValueError("base distribution not scalar")
    elif validate_args:
      dynamic_assertions += [check_ops.assert_equal(
          is_both_nonscalar, False,
          message="base distribution not scalar")]

    if not dynamic_assertions:
      return override_shape
    return control_flow_ops.with_dependencies(
        dynamic_assertions, override_shape)
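This helper follows a recurring pattern: check statically when the value is known at graph-construction time, otherwise queue a runtime assertion. A minimal standalone sketch of that pattern, assuming TF2-style public APIs (`maybe_assert_positive` is an illustrative name, not part of the snippet):

```python
import tensorflow as tf

def maybe_assert_positive(shape, validate_args):
  """Static check when possible, runtime assertion otherwise (illustrative)."""
  shape = tf.convert_to_tensor(shape, dtype=tf.int32)
  static = tf.get_static_value(shape)  # None when not statically known
  if static is not None:
    if any(s <= 0 for s in static):
      raise ValueError("shape override must have positive elements")
    return shape  # verified at graph-construction time; no runtime op needed
  if validate_args:
    assert_op = tf.debugging.assert_positive(
        shape, message="shape override must have positive elements")
    with tf.control_dependencies([assert_op]):
      shape = tf.identity(shape)
  return shape
```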
Example No. 42
    def from_row_lengths(cls,
                         row_lengths,
                         validate=True,
                         preferred_dtype=None):
        """Creates a `RowPartition` with rows partitioned by `row_lengths`.

        This `RowPartition` divides a sequence `values` into rows by indicating
        the length of each row:

        ```python
        partitioned_rows = [[values.pop(0) for _ in range(length)]
                            for length in row_lengths]
        ```

        Args:
          row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
            nonnegative.
          validate: If true, then use assertions to check that the arguments form
            a valid `RowPartition`.
          preferred_dtype: If `row_lengths` has an unspecified type, use this one.
            If `preferred_dtype` is None, defaults to `dtypes.int64`.

        Returns:
          A `RowPartition`.
        """
        if not isinstance(validate, bool):
            raise TypeError("validate must have type bool")
        with ops.name_scope(None, "RowPartitionFromRowLengths", [row_lengths]):
            row_lengths = cls._convert_row_partition(row_lengths,
                                                     "row_lengths",
                                                     preferred_dtype)
            row_lengths.shape.assert_has_rank(1)

            if validate:
                msg = "Arguments to from_row_lengths do not form a valid RowPartition"
                checks = [
                    check_ops.assert_rank(row_lengths, 1, message=msg),
                    check_ops.assert_non_negative(row_lengths, message=msg),
                ]
                row_lengths = control_flow_ops.with_dependencies(
                    checks, row_lengths)

            row_limits = math_ops.cumsum(row_lengths)
            row_splits = array_ops.concat([[0], row_limits], axis=0)
            return cls(row_splits=row_splits,
                       row_lengths=row_lengths,
                       internal=_row_partition_factory_key)
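A hedged usage sketch; `RowPartition` lives in a private TensorFlow module, so the import path below is an assumption about internal layout:

```python
from tensorflow.python.ops.ragged.row_partition import RowPartition

rp = RowPartition.from_row_lengths([2, 0, 3])
print(rp.row_splits())  # [0 2 2 5]
print(rp.nrows())       # 3
```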
Example No. 43
    def from_row_starts(cls,
                        row_starts,
                        nvals,
                        name=None,
                        validate=True,
                        preferred_dtype=None):
        """Creates a `RowPartition` with rows partitioned by `row_starts`.

        Equivalent to: `from_row_splits(concat([row_starts, [nvals]]))`.

        Args:
          row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
            nonnegative and sorted in ascending order.  If `nrows>0`, then
            `row_starts[0]` must be zero.
          nvals: A scalar tensor indicating the number of values.
          name: A name prefix for the RowPartition (optional).
          validate: If true, then use assertions to check that the arguments form
            a valid `RowPartition`.
          preferred_dtype: If `row_starts` has an unspecified type, use this one.
            If `preferred_dtype` is None, defaults to `dtypes.int64`.

        Returns:
          A `RowPartition`.
        """
        if not isinstance(validate, bool):
            raise TypeError("validate must have type bool")
        with ops.name_scope(name, "RowPartitionFromRowStarts", [row_starts]):
            row_starts = cls._convert_row_partition(row_starts, "row_starts",
                                                    preferred_dtype)
            row_starts.shape.assert_has_rank(1)
            nvals = math_ops.cast(nvals, row_starts.dtype)
            if validate:
                msg = "Arguments to from_row_starts do not form a valid RaggedTensor"
                checks = [
                    check_ops.assert_rank(row_starts, 1, message=msg),
                    _assert_zero(row_starts[:1], message=msg),
                    _assert_monotonic_increasing(row_starts, message=msg),
                    check_ops.assert_less_equal(row_starts[-1:],
                                                nvals,
                                                message=msg),
                ]
                row_starts = control_flow_ops.with_dependencies(
                    checks, row_starts)

            row_splits = array_ops.concat([row_starts, [nvals]], axis=0)
            return cls(row_splits=row_splits, internal=True)
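The same partition as `from_row_lengths([2, 0, 3])`, derived from row starts in plain NumPy (values are illustrative):

```python
import numpy as np

row_starts = np.array([0, 2, 2], dtype=np.int64)
nvals = 5
row_splits = np.concatenate([row_starts, [nvals]])  # [0, 2, 2, 5]
# Same partition as from_row_lengths([2, 0, 3]).
```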
Example No. 44
def dict_to_state_tuple(input_dict, cell):
    """Reconstructs nested `state` from a dict containing state `Tensor`s.

    Args:
      input_dict: A dict of `Tensor`s.
      cell: An instance of `RNNCell`.
    Returns:
      If `input_dict` does not contain keys 'STATE_PREFIX_i' for `0 <= i < n`
      where `n` is the number of nested entries in `cell.state_size`, this
      function returns `None`. Otherwise, returns a `Tensor` if `cell.state_size`
      is an `int` or a nested tuple of `Tensor`s if `cell.state_size` is a nested
      tuple.
    Raises:
      ValueError: State is partially specified. The `input_dict` must contain
        values for all state components or none at all.
    """
    flat_state_sizes = nest.flatten(cell.state_size)
    state_tensors = []
    with ops.name_scope('dict_to_state_tuple'):
        for i, state_size in enumerate(flat_state_sizes):
            state_name = _get_state_name(i)
            state_tensor = input_dict.get(state_name)
            if state_tensor is not None:
                rank_check = check_ops.assert_rank(
                    state_tensor, 2, name='check_state_{}_rank'.format(i))
                shape_check = check_ops.assert_equal(
                    array_ops.shape(state_tensor)[1],
                    state_size,
                    name='check_state_{}_shape'.format(i))
                with ops.control_dependencies([rank_check, shape_check]):
                    state_tensor = array_ops.identity(state_tensor,
                                                      name=state_name)
                state_tensors.append(state_tensor)
        if not state_tensors:
            return None
        elif len(state_tensors) == len(flat_state_sizes):
            dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
            return nest.pack_sequence_as(dummy_state, state_tensors)
        else:
            raise ValueError(
                'RNN state was partially specified. '
                'Expected zero or {} state Tensors; got {}'.format(
                    len(flat_state_sizes), len(state_tensors)))
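For concreteness, a dict that would satisfy the checks, assuming `_get_state_name(i)` yields the `'STATE_PREFIX_i'` keys described in the docstring (TF1-style API, matching the snippet):

```python
import tensorflow.compat.v1 as tf

batch_size, state_size = 8, 4
# A cell with state_size == (4, 4) would look up these two keys:
input_dict = {
    'STATE_PREFIX_0': tf.zeros([batch_size, state_size]),  # rank 2, dim1 == 4
    'STATE_PREFIX_1': tf.zeros([batch_size, state_size]),
}
```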
Example No. 45
 def _assert_non_negative_int32_scalar(self, x):
   """Helper which ensures that input is a non-negative, int32, scalar."""
   x = ops.convert_to_tensor(x, name="x")
   if x.dtype.base_dtype != dtypes.int32.base_dtype:
     raise TypeError("%s.dtype=%s is not %s" % (x.name, x.dtype, dtypes.int32))
   x_value_static = tensor_util.constant_value(x)
   if x.get_shape().ndims is not None and x_value_static is not None:
     if x.get_shape().ndims != 0:
       raise ValueError("%s.ndims=%d is not 0 (scalar)" %
                        (x.name, x.get_shape().ndims))
     if x_value_static < 0:
       raise ValueError("%s.value=%d cannot be negative" %
                        (x.name, x_value_static))
     return x
   if self.validate_args:
     x = control_flow_ops.with_dependencies([
         check_ops.assert_rank(x, 0),
         check_ops.assert_non_negative(x)], x)
   return x
Example No. 46
def dict_to_state_tuple(input_dict, cell):
  """Reconstructs nested `state` from a dict containing state `Tensor`s.

  Args:
    input_dict: A dict of `Tensor`s.
    cell: An instance of `RNNCell`.
  Returns:
    If `input_dict` does not contain keys 'STATE_PREFIX_i' for `0 <= i < n`
    where `n` is the number of nested entries in `cell.state_size`, this
    function returns `None`. Otherwise, returns a `Tensor` if `cell.state_size`
    is an `int` or a nested tuple of `Tensor`s if `cell.state_size` is a nested
    tuple.
  Raises:
    ValueError: State is partially specified. The `input_dict` must contain
      values for all state components or none at all.
  """
  flat_state_sizes = nest.flatten(cell.state_size)
  state_tensors = []
  with ops.name_scope('dict_to_state_tuple'):
    for i, state_size in enumerate(flat_state_sizes):
      state_name = _get_state_name(i)
      state_tensor = input_dict.get(state_name)
      if state_tensor is not None:
        rank_check = check_ops.assert_rank(
            state_tensor, 2, name='check_state_{}_rank'.format(i))
        shape_check = check_ops.assert_equal(
            array_ops.shape(state_tensor)[1],
            state_size,
            name='check_state_{}_shape'.format(i))
        with ops.control_dependencies([rank_check, shape_check]):
          state_tensor = array_ops.identity(state_tensor, name=state_name)
        state_tensors.append(state_tensor)
    if not state_tensors:
      return None
    elif len(state_tensors) == len(flat_state_sizes):
      dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
      return nest.pack_sequence_as(dummy_state, state_tensors)
    else:
      raise ValueError(
          'RNN state was partially specified. '
          'Expected zero or {} state Tensors; got {}'.format(
              len(flat_state_sizes), len(state_tensors)))
Example No. 47
def _check_logits(logits, expected_logits_dimension):
  """Check logits type and shape."""
  with ops.name_scope(None, 'logits', (logits,)) as scope:
    logits = math_ops.to_float(logits)
    logits_shape = array_ops.shape(logits)
    assert_rank = check_ops.assert_rank(
        logits, 2, data=[logits_shape],
        message='logits shape must be [batch_size, logits_dimension]')
    with ops.control_dependencies([assert_rank]):
      static_shape = logits.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        if (dim1 is not None) and (dim1 != expected_logits_dimension):
          raise ValueError(
              'logits shape must be [batch_size, logits_dimension], got %s.' %
              (static_shape,))
      assert_dimension = check_ops.assert_equal(
          expected_logits_dimension, logits_shape[1], data=[logits_shape],
          message='logits shape must be [batch_size, logits_dimension]')
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(logits, name=scope)
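A sketch of the check in use (TF1-style API, matching the snippet; the failing call is left commented out):

```python
import tensorflow.compat.v1 as tf

logits = tf.ones([32, 5])  # [batch_size, logits_dimension]
checked = _check_logits(logits, expected_logits_dimension=5)  # passes
# _check_logits(tf.ones([32, 3]), expected_logits_dimension=5)
# -> ValueError: logits shape must be [batch_size, logits_dimension], got (32, 3).
```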
Example No. 48
def _check_and_reshape_dense_labels(labels, expected_labels_dimension):
  """Checks dense labels type and shape and reshapes to 2D Tensor."""
  if labels is None:
    raise ValueError(
        'You must provide a labels Tensor. Given: None. '
        'Suggested troubleshooting steps: Check that your data contain '
        'your label feature. Check that your input_fn properly parses and '
        'returns labels.')
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError(
          'SparseTensor labels are not supported. '
          'labels must be a Tensor of shape [batch_size, %s]. '
          'Suggested Fix (1): Check the label feature in your data. '
          'Each example must contain %s value(s). If not, your choice of label '
          'was probably incorrect. '
          'Suggested Fix (2): In your input_fn, use '
          'tf.sparse_tensor_to_dense() to turn labels into a Tensor.'
          '' % (expected_labels_dimension, expected_labels_dimension))
    labels = _maybe_expand_dim(labels)
    labels_shape = array_ops.shape(labels)
    err_msg = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    assert_rank = check_ops.assert_rank(labels, 2, message=err_msg)
    with ops.control_dependencies([assert_rank]):
      static_shape = labels.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        if (dim1 is not None) and (dim1 != expected_labels_dimension):
          raise ValueError(
              'Mismatched label shape. '
              'Classifier configured with n_classes=%s.  Received %s. '
              'Suggested Fix: check your n_classes argument to the estimator '
              'and/or the shape of your label.' %
              (expected_labels_dimension, dim1))
      assert_dimension = check_ops.assert_equal(
          expected_labels_dimension, labels_shape[1], message=err_msg)
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(labels, name=scope)
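Sketching the docstring's Suggested Fix (2): densify sparse labels before they reach this check (values are illustrative):

```python
import tensorflow.compat.v1 as tf

sparse_labels = tf.SparseTensor(
    indices=[[0, 0], [1, 0]], values=[1, 0], dense_shape=[2, 1])
dense_labels = tf.sparse_tensor_to_dense(sparse_labels, default_value=0)
# dense_labels is now an ordinary [batch_size, 1] Tensor and passes the
# rank/shape assertions above.
```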
Example No. 49
def _check_labels(labels, expected_labels_dimension):
  """Check labels type and shape."""
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    labels_shape = array_ops.shape(labels)
    err_msg = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    assert_rank = check_ops.assert_rank(labels, 2, message=err_msg)
    with ops.control_dependencies([assert_rank]):
      static_shape = labels.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        if (dim1 is not None) and (dim1 != expected_labels_dimension):
          raise ValueError(
              'labels shape must be [batch_size, labels_dimension], got %s.' %
              (static_shape,))
      assert_dimension = check_ops.assert_equal(
          expected_labels_dimension, labels_shape[1], message=err_msg)
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(labels, name=scope)
Example No. 50
  def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                     validate_args):
    """Helper to __init__ which ensures override batch/event_shape are valid."""
    if not override_shape.dtype.is_integer:
      raise TypeError("shape override must be an integer")

    if override_shape.get_shape().ndims is not None:
      if override_shape.get_shape().ndims != 1:
        raise ValueError("shape override must be a vector")
    elif validate_args:
      is_vector = check_ops.assert_rank(
          override_shape, 1,
          message="shape override must be a vector")
      override_shape = control_flow_ops.with_dependencies(
          [is_vector], override_shape)

    if override_shape.get_shape().is_fully_defined():
      if any(s <= 0 for s in override_shape.get_shape().as_list()):
        raise ValueError("shape override must have positive elements")
    elif validate_args:
      is_positive = check_ops.assert_positive(
          override_shape,
          message="shape override must have positive elements")
      override_shape = control_flow_ops.with_dependencies(
          [is_positive], override_shape)

    if tensor_util.constant_value(base_is_scalar) is not None:
      if not tensor_util.constant_value(base_is_scalar):
        raise ValueError("shape override requires scalar distribution.")
    elif validate_args:
      is_scalar = check_ops.assert_equal(
          base_is_scalar, True,
          message="shape override requires scalar distribution.")
      override_shape = control_flow_ops.with_dependencies(
          [is_scalar], override_shape)

    return override_shape
Example No. 51
def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keep_dims=False,
               validate_args=False,
               name=None):
  """Compute the `q`-th percentile of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
  way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'lower' interpolation
  x = [1., 2., 3., 4.]
  percentile(x, q=30., interpolation='lower')
  ==> 1.0

  # Get 100th percentile (maximum).  By default, this is computed over every dim
  x = [[1., 2.],
       [3., 4.]]
  percentile(x, q=100.)
  ==> 4.0

  # Treat the leading dim as indexing samples, and find the 100th quantile (max)
  # over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```

  Compare to `numpy.percentile`.

  Args:
    x:  Floating point `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`,
      `x` must have statically known number of dimensions.
    q:  Scalar `Tensor` in `[0, 100]`. The percentile.
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values.
      The axis that hold independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation : {"lower", "higher", "nearest"}.  Default: "nearest"
      This optional parameter specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
    keep_dims:  Python `bool`. If `True`, the last dimension is kept with size 1.
      If `False`, the last dimension is removed from the output shape.
    validate_args:  Whether to add runtime checks of argument validity.
      If False, and arguments are incorrect, correct behavior is not guaranteed.
    name:  A Python string name to give this `Op`.  Default is "percentile"

  Returns:
    A `(N - len(axis))` dimensional `Tensor` of same dtype as `x`, or, if
      `axis` is `None`, a scalar.

  Raises:
    ValueError:  If argument 'interpolation' is not an allowed value.
  """
  name = name or "percentile"
  allowed_interpolations = {"lower", "higher", "nearest"}

  if interpolation is None:
    interpolation = "nearest"
  else:
    if interpolation not in allowed_interpolations:
      raise ValueError("Argument 'interpolation' must be in %s.  Found %s" %
                       (allowed_interpolations, interpolation))

  with ops.name_scope(name, [x, q]):
    x = ops.convert_to_tensor(x, name="x")
    # Double is needed here and below, else we get the wrong index if the array
    # is huge along axis.
    q = math_ops.to_double(q, name="q")
    _get_static_ndims(q, expect_ndims=0)

    if validate_args:
      q = control_flow_ops.with_dependencies([
          check_ops.assert_rank(q, 0),
          check_ops.assert_greater_equal(q, math_ops.to_double(0.)),
          check_ops.assert_less_equal(q, math_ops.to_double(100.))
      ], q)

    if axis is None:
      y = array_ops.reshape(x, [-1])
    else:
      axis = ops.convert_to_tensor(axis, name="axis")
      check_ops.assert_integer(axis)
      axis_ndims = _get_static_ndims(
          axis, expect_static=True, expect_ndims_no_more_than=1)
      axis_const = tensor_util.constant_value(axis)
      if axis_const is None:
        raise ValueError(
            "Expected argument 'axis' to be statically available.  Found: %s" %
            axis)
      axis = axis_const
      if axis_ndims == 0:
        axis = [axis]
      axis = [int(a) for a in axis]
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative(axis, x_ndims)
      y = _move_dims_to_flat_end(x, axis, x_ndims)

    frac_at_q_or_above = 1. - q / 100.
    d = math_ops.to_double(array_ops.shape(y)[-1])

    if interpolation == "lower":
      index = math_ops.ceil((d - 1) * frac_at_q_or_above)
    elif interpolation == "higher":
      index = math_ops.floor((d - 1) * frac_at_q_or_above)
    elif interpolation == "nearest":
      index = math_ops.round((d - 1) * frac_at_q_or_above)

    # If d is gigantic, then we would have d == d - 1, even in double... So
    # let's use max/min to avoid out of bounds errors.
    d = array_ops.shape(y)[-1]
    # d - 1 will be distinct from d in int32.
    index = clip_ops.clip_by_value(math_ops.to_int32(index), 0, d - 1)

    # Sort everything, not just the top 'k' entries, which allows multiple calls
    # to sort only once (under the hood) and use CSE.
    sorted_y = _sort_tensor(y)

    # result.shape = B
    result = sorted_y[..., index]
    result.set_shape(y.get_shape()[:-1])

    if keep_dims:
      if axis is None:
        # ones_vec = [1, 1,..., 1], total length = len(S) + len(B).
        ones_vec = array_ops.ones(
            shape=[_get_best_effort_ndims(x)], dtype=dtypes.int32)
        result *= array_ops.ones(ones_vec, dtype=x.dtype)
      else:
        result = _insert_back_keep_dims(result, axis)

    return result
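Worked index arithmetic for the docstring's `q=30` example, assuming `_sort_tensor` sorts in descending order, which the `frac_at_q_or_above` formulation implies:

```python
import numpy as np

x = np.array([1., 2., 3., 4.])
q = 30.
frac_at_q_or_above = 1. - q / 100.                     # 0.7
index = int(round((x.size - 1) * frac_at_q_or_above))  # round(2.1) == 2
sorted_desc = np.sort(x)[::-1]                         # [4., 3., 2., 1.]
print(sorted_desc[index])                              # 2.0, matching the docstring
```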
Example No. 52
 def test_raises_if_rank_is_not_integer_static(self):
   with self.test_session():
     tensor = constant_op.constant([1, 2], name="my_tensor")
     with self.assertRaisesRegexp(TypeError,
                                  "must be of type <dtype: 'int32'>"):
       check_ops.assert_rank(tensor, .5)
Example No. 53
def renyi_alpha(step,
                decay_time,
                alpha_min,
                alpha_max=0.99999,
                name='renyi_alpha'):
  r"""Exponentially decaying `Tensor` appropriate for Renyi ratios.

  When minimizing the Renyi divergence for `0 <= alpha < 1` (or maximizing the
  Renyi equivalent of elbo) in high dimensions, it is not uncommon to experience
  `NaN` and `inf` values when `alpha` is far from `1`.

  For that reason, it is often desirable to start the optimization with `alpha`
  very close to 1, and reduce it to a final `alpha_min` according to some
  schedule.  The user may even want to optimize using `elbo_ratio` for
  some fixed time before switching to Renyi based methods.

  This `Op` returns an `alpha` decaying exponentially with step:

  ```
  s(step) = (exp{step / decay_time} - 1) / (e - 1)
  t(s) = max(0, min(s, 1)),  (smooth growth from 0 to 1)
  alpha(t) = (1 - t) alpha_max + t alpha_min
  ```

  Args:
    step:  Non-negative scalar `Tensor`.  Typically the global step or an
      offset version thereof.
    decay_time:  Positive scalar `Tensor`.
    alpha_min:  `float` or `double` `Tensor`.
      The minimal, final value of `alpha`, achieved when `step >= decay_time`
    alpha_max:  `Tensor` of same `dtype` as `alpha_min`.
      The maximal, beginning value of `alpha`, achieved when `step == 0`
    name:  A name to give this `Op`.

  Returns:
    alpha:  A `Tensor` of same `dtype` as `alpha_min`.
  """
  with ops.name_scope(name, values=[step, decay_time, alpha_min, alpha_max]):
    alpha_min = ops.convert_to_tensor(alpha_min, name='alpha_min')
    dtype = alpha_min.dtype

    alpha_max = ops.convert_to_tensor(alpha_max, dtype=dtype, name='alpha_max')
    decay_time = math_ops.cast(decay_time, dtype)
    step = math_ops.cast(step, dtype)

    check_scalars = [
        check_ops.assert_rank(step, 0, message='step must be scalar'),
        check_ops.assert_rank(
            decay_time, 0, message='decay_time must be scalar'),
        check_ops.assert_rank(alpha_min, 0, message='alpha_min must be scalar'),
        check_ops.assert_rank(alpha_max, 0, message='alpha_max must be scalar'),
    ]
    check_sign = [
        check_ops.assert_non_negative(
            step, message='step must be non-negative'),
        check_ops.assert_positive(
            decay_time, message='decay_time must be positive'),
    ]

    with ops.control_dependencies(check_scalars + check_sign):
      theta = (math_ops.exp(step / decay_time) - 1.) / (math.e - 1.)
      theta = math_ops.minimum(math_ops.maximum(theta, 0.), 1.)
      return alpha_max * (1. - theta) + alpha_min * theta
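A NumPy transcription of the schedule, handy for sanity-checking the endpoints (illustrative, not part of the library):

```python
import numpy as np

def renyi_alpha_np(step, decay_time, alpha_min, alpha_max=0.99999):
  # Mirrors the formula above: theta grows smoothly from 0 to 1.
  theta = (np.exp(step / decay_time) - 1.) / (np.e - 1.)
  theta = np.clip(theta, 0., 1.)
  return alpha_max * (1. - theta) + alpha_min * theta

print(renyi_alpha_np(step=0., decay_time=100., alpha_min=0.5))    # 0.99999 (alpha_max)
print(renyi_alpha_np(step=100., decay_time=100., alpha_min=0.5))  # 0.5 (alpha_min)
```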
Example No. 54
 def test_raises_if_rank_is_not_scalar_static(self):
   with self.test_session():
     tensor = constant_op.constant([1, 2], name="my_tensor")
     with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"):
       check_ops.assert_rank(tensor, np.array([], dtype=np.int32))