Example #1
0
 def testMatmulError(self):
   """Verifies that matmul rejects a rank-0 operand on either side.

   Each case pairs an empty-shape `ones` int32 array with a 2x3 int32 array,
   in both argument orders; every call must raise ValueError. The empty
   pattern passed to assertRaisesRegex accepts any error message.
   """
   for lhs_shape, rhs_shape in (([], [2, 3]), ([2, 3], [])):
     with self.assertRaisesRegex(ValueError, r''):
       # Operands are built inside the context, as in the per-case form, so
       # any ValueError from construction is handled identically.
       lhs = np_array_ops.ones(lhs_shape, np.int32)
       rhs = np_array_ops.ones(rhs_shape, np.int32)
       np_math_ops.matmul(lhs, rhs)
  def testSize(self):
    """Checks np_array_ops.size against np.size for a variety of inputs.

    Covers np_array_ops arrays, a plain Python int and constant_op tensors,
    asserts that passing axis=1 raises NotImplementedError, and finally calls
    size() inside a def_function.function whose input signature is a float64
    TensorSpec with shape=None.
    """

    def check_size(value, axis=None):
      # NumPy supplies the reference element count for the same input.
      reference_arr = np.array(value)
      self.assertEqual(
          np_array_ops.size(value, axis), np.size(reference_arr, axis))

    inputs = [
        np_array_ops.array([1]),
        np_array_ops.array([1, 2, 3, 4, 5]),
        np_array_ops.ones((2, 3, 2)),
        np_array_ops.ones((3, 2)),
        np_array_ops.zeros((5, 6, 7)),
        1,
        np_array_ops.ones((3, 2, 1)),
        constant_op.constant(5),
        constant_op.constant([1, 1, 1]),
    ]
    for value in inputs:
      check_size(value)

    # An explicit integer axis (1 here) must be rejected.
    self.assertRaises(
        NotImplementedError, np_array_ops.size, np.ones((2, 2)), 1)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(dtype=dtypes.float64, shape=None)])
    def traced_size(arr):
      return np_array_ops.size(np_array_ops.asarray(arr))

    result = traced_size(np_array_ops.ones((3, 2)))
    self.assertEqual(result.numpy(), 6)
Example #3
0
  def testOnes(self):
    """Compares np_array_ops.ones with np.ones across shapes and dtypes.

    First checks every shape in `self.all_shapes` with the default dtype,
    then every (shape, dtype) pair from `self.all_shapes` x `self.all_types`.
    The message passed to `self.match` identifies the failing combination.
    """
    for shape in self.all_shapes:
      self.match(
          np_array_ops.ones(shape),
          np.ones(shape),
          'shape: {}'.format(shape))

    for shape, dtype in itertools.product(self.all_shapes, self.all_types):
      self.match(
          np_array_ops.ones(shape, dtype),
          np.ones(shape, dtype),
          'shape: {}, dtype: {}'.format(shape, dtype))
Example #4
0
 def testIndexedSlices(self):
   """Converts an int64 IndexedSlices with np_array_ops.array(copy=False).

   Two rows of ones (placed at rows 1 and 9 of a [10, 3] dense shape) must
   convert to the same tensor that array_ops.scatter_nd produces when
   scattering those rows into shape [10, 3].
   """
   value_dtype = dtypes.int64
   row_ids = [1, 9]
   slices = indexed_slices.IndexedSlices(
       values=np_array_ops.ones([2, 3], dtype=value_dtype),
       indices=constant_op.constant(row_ids),
       dense_shape=[10, 3])
   converted = np_array_ops.array(slices, copy=False)
   # scatter_nd expects one index vector per scattered row: [[1], [9]].
   scatter_indices = [[row] for row in row_ids]
   expected = array_ops.scatter_nd(
       scatter_indices, array_ops.ones([2, 3], dtype=value_dtype), [10, 3])
   self.assertAllEqual(expected, converted)