Example #1
0
 def test_make_dummy_element_NamedTupleType(self):
   """Checks dummy elements built for named and unnamed tuple types.

   Each tuple member should map positionally to one zero-filled array, with
   every unknown (`None`) dimension materialized as size 1.
   """
   tensor1 = computation_types.TensorType(tf.float32, [None, 10, None, 10, 10])
   tensor2 = computation_types.TensorType(tf.int32, [10, None, 10])
   named_type = computation_types.NamedTupleType([('x', tensor1),
                                                  ('y', tensor2)])
   # The unnamed variant must really omit the element names; otherwise it is
   # just a second copy of `named_type` and the unnamed path goes untested.
   unnamed_type = computation_types.NamedTupleType([tensor1, tensor2])
   correct_list = [
       np.zeros([1, 10, 1, 10, 10], np.float32),
       np.zeros([10, 1, 10], np.int32)
   ]
   for type_spec in (named_type, unnamed_type):
     elem = graph_utils._make_dummy_element_for_type_spec(type_spec)
     self.assertEqual(len(elem), len(correct_list))
     # Comparing via tolist() checks both values and (non-empty) shapes.
     for actual, expected in zip(elem, correct_list):
       self.assertEqual(actual.tolist(), expected.tolist())
Example #2
0
 def test_make_dummy_element_TensorType(self):
   """Checks that `None` dims of a TensorType become size 1 in the dummy."""
   shape = [None, 10, None, 10, 10]
   spec = computation_types.TensorType(tf.float32, shape)
   dummy = graph_utils._make_dummy_element_for_type_spec(spec)
   expected = np.zeros([1, 10, 1, 10, 10], np.float32)
   # Nested-list equality compares the values and, implicitly, the shape.
   self.assertEqual(dummy.tolist(), expected.tolist())
Example #3
0
 def test_make_dummy_element_TensorType(self):
   """Checks that `None` dims of a TensorType become size 0 in the dummy."""
   spec = computation_types.TensorType(tf.float32, [None, 10, None, 10, 10])
   dummy = graph_utils._make_dummy_element_for_type_spec(spec)
   expected = np.zeros([0, 10, 0, 10, 10], np.float32)
   # np.array_equal requires matching shapes as well as matching values.
   arrays_match = np.array_equal(dummy, expected)
   self.assertTrue(arrays_match)