コード例 #1
0
 def testRuntimeError(self, rt_inputs, axis, error, message,
                      ragged_ranks=None):
   """Checks that evaluating ragged.concat on `rt_inputs` raises `error`.

   Each input is wrapped in `placeholder_with_default(..., shape=None)`, so
   its static shape is unknown when the graph is built. Presumably this is
   so the failure surfaces only at evaluation time.

   Args:
     rt_inputs: The values to concatenate.
     axis: The axis passed to `ragged.concat`.
     error: The exception type expected when the result is evaluated.
     message: A regex that the exception message must match.
     ragged_ranks: Not read by this method (presumably accepted so shared
       parameterized test cases can supply it -- confirm).
   """
   rt_inputs = [
       array_ops.placeholder_with_default(rt, shape=None) for rt in rt_inputs
   ]
   concatenated = ragged.concat(rt_inputs, axis)
   # `test_session()` is deprecated in TensorFlow; `cached_session()` is the
   # documented replacement for a reusable per-test session.
   with self.cached_session():
     self.assertRaisesRegexp(error, message, concatenated.eval)
コード例 #2
0
 def testRuntimeError(self, rt_inputs, axis, error, message,
                      ragged_ranks=None):
   """Checks that evaluating ragged.concat on `rt_inputs` raises `error`.

   Each input is wrapped in `placeholder_with_default(..., shape=None)`, so
   its static shape is unknown when the graph is built. Presumably this is
   so the failure surfaces only at evaluation time.

   Args:
     rt_inputs: The values to concatenate.
     axis: The axis passed to `ragged.concat`.
     error: The exception type expected when the result is evaluated.
     message: A regex that the exception message must match.
     ragged_ranks: Not read by this method (presumably accepted so shared
       parameterized test cases can supply it -- confirm).
   """
   rt_inputs = [
       array_ops.placeholder_with_default(rt, shape=None) for rt in rt_inputs
   ]
   concatenated = ragged.concat(rt_inputs, axis)
   # `test_session()` is deprecated in TensorFlow; `cached_session()` is the
   # documented replacement for a reusable per-test session.
   with self.cached_session():
     self.assertRaisesRegexp(error, message, concatenated.eval)
コード例 #3
0
    def testSingleTensorInput(self):
        """Checks ragged_concat when handed one tensor rather than a list.

        The normal call passes a list of values as `rt_inputs`.  As with
        tf.concat, a single value may be passed instead, in which case that
        tensor is simply returned.  This test covers that code path.
        """
        nested_values = [[1, 2], [3, 4]]
        lone_input = ragged.constant(nested_values)
        result = ragged.concat(lone_input, 0)
        self.assertRaggedEqual(result, nested_values)
コード例 #4
0
  def testSingleTensorInput(self):
    """Tests ragged_concat with a single tensor input.

    Usually, we pass a list of values in for rt_inputs.  However, you can
    also pass in a single value (as with tf.concat), in which case it simply
    returns that tensor.  This test exercises that path.
    """
    rt_inputs = ragged.constant([[1, 2], [3, 4]])
    concatenated = ragged.concat(rt_inputs, 0)
    # `test_session()` is deprecated in TensorFlow; `cached_session()` is the
    # documented replacement for a reusable per-test session.
    with self.cached_session():
      self.assertEqual(concatenated.eval().tolist(), [[1, 2], [3, 4]])
コード例 #5
0
 def testRuntimeError(self, rt_inputs, axis, error, message,
                      ragged_ranks=None):
     """Checks that evaluating ragged.concat on `rt_inputs` raises `error`.

     In eager mode the method returns immediately without checking anything.
     Otherwise every input is wrapped in
     `placeholder_with_default(..., shape=None)` so that its static shape is
     unknown. The concatenation is built outside the assertion context, so
     only its evaluation is expected to raise.

     Args:
       rt_inputs: The values to concatenate.
       axis: The axis passed to `ragged.concat`.
       error: The exception type expected from evaluation.
       message: A regex that the exception message must match.
       ragged_ranks: Not read by this method (presumably supplied by shared
         parameterized test cases -- confirm).
     """
     if not context.executing_eagerly():
         shapeless_inputs = [
             array_ops.placeholder_with_default(value, shape=None)
             for value in rt_inputs
         ]
         # Built here, outside the context manager: an error raised while
         # constructing the op must not count as the expected failure.
         result = ragged.concat(shapeless_inputs, axis)
         with self.assertRaisesRegexp(error, message):
             self.evaluate(result)
コード例 #6
0
 def testRaggedConcat(self, descr, rt_inputs, axis, expected,
                      ragged_ranks=None, expected_ragged_rank=None,
                      expected_shape=None):
     """Checks the output of ragged.concat against expected values.

     Args:
       descr: Not read by this method (presumably a label for the
         parameterized case -- confirm).
       rt_inputs: Values to concatenate; converted together with
         `ragged_ranks` by `self._rt_inputs_to_tensors`.
       axis: Axis passed to `ragged.concat`.
       expected: Expected result, compared with `assertRaggedEqual`.
       ragged_ranks: Forwarded to `self._rt_inputs_to_tensors`.
       expected_ragged_rank: If not None, the required `ragged_rank` of the
         result.
       expected_shape: If not None, the required `shape.as_list()` of the
         result.
     """
     tensors = self._rt_inputs_to_tensors(rt_inputs, ragged_ranks)
     result = ragged.concat(tensors, axis)
     # Optional static checks run first, in this order, before the value check.
     if expected_ragged_rank is not None:
         self.assertEqual(result.ragged_rank, expected_ragged_rank)
     if expected_shape is not None:
         self.assertEqual(result.shape.as_list(), expected_shape)
     self.assertRaggedEqual(result, expected)
コード例 #7
0
 def testRaggedConcat(self,
                      descr,
                      rt_inputs,
                      axis,
                      expected,
                      ragged_ranks=None,
                      expected_ragged_rank=None,
                      expected_shape=None):
   """Checks the output of ragged.concat against expected values.

   Args:
     descr: Not read by this method (presumably a label for the
       parameterized case -- confirm).
     rt_inputs: Values to concatenate; converted together with
       `ragged_ranks` by `self._rt_inputs_to_tensors`.
     axis: Axis passed to `ragged.concat`.
     expected: Expected result, compared with `eval().tolist()` of the output.
     ragged_ranks: Forwarded to `self._rt_inputs_to_tensors`.
     expected_ragged_rank: If not None, the required `ragged_rank` of the
       result.
     expected_shape: If not None, the required `shape.as_list()` of the
       result.
   """
   rt_inputs = self._rt_inputs_to_tensors(rt_inputs, ragged_ranks)
   concatenated = ragged.concat(rt_inputs, axis)
   if expected_ragged_rank is not None:
     self.assertEqual(concatenated.ragged_rank, expected_ragged_rank)
   if expected_shape is not None:
     self.assertEqual(concatenated.shape.as_list(), expected_shape)
   # `test_session()` is deprecated in TensorFlow; `cached_session()` is the
   # documented replacement for a reusable per-test session.
   with self.cached_session():
     self.assertEqual(concatenated.eval().tolist(), expected)