def test_make_dummy_element_NamedTupleType(self):
    """Named and unnamed tuples map to a list of per-element zero arrays.

    Each tuple element should become a zero-filled ndarray with that
    element's dtype, where every `None` dimension is replaced by 0.
    """
    tensor1 = computation_types.TensorType(tf.float32, [None, 10, None, 10, 10])
    tensor2 = computation_types.TensorType(tf.int32, [10, None, 10])
    namedtuple = computation_types.NamedTupleType([('x', tensor1),
                                                   ('y', tensor2)])
    # Same element types but without names, so the unnamed path is covered.
    unnamedtuple = computation_types.NamedTupleType([tensor1, tensor2])
    correct_list = [
        np.zeros([0, 10, 0, 10, 10], np.float32),
        np.zeros([10, 0, 10], np.int32),
    ]
    for type_spec in (namedtuple, unnamedtuple):
        elem = graph_utils._make_dummy_element_for_type_spec(type_spec)
        self.assertEqual(len(elem), len(correct_list))
        for actual, expected in zip(elem, correct_list):
            # array_equal compares shapes as well as values; tolist() would
            # turn every zero-size array into [] and hide shape mismatches.
            self.assertTrue(np.array_equal(actual, expected))
            # array_equal ignores dtype, so check it explicitly.
            self.assertEqual(actual.dtype, expected.dtype)
def test_make_dummy_element_TensorType_fully_defined_shape(self):
    """A TensorType with no `None` dims yields zeros of exactly that shape.

    Given a distinct name so it is not silently shadowed by the other
    `test_make_dummy_element_TensorType` definition in this class.
    """
    type_spec = computation_types.TensorType(tf.float32, [1, 10, 1, 10, 10])
    elem = graph_utils._make_dummy_element_for_type_spec(type_spec)
    correct_elem = np.zeros([1, 10, 1, 10, 10], np.float32)
    # array_equal checks shape and values; dtype is asserted separately.
    self.assertTrue(np.array_equal(elem, correct_elem))
    self.assertEqual(elem.dtype, correct_elem.dtype)
def test_make_dummy_element_TensorType(self):
    """`None` dims of a TensorType are replaced by 0 in the dummy element.

    The result should be a zero-filled ndarray with the TensorType's dtype.
    """
    type_spec = computation_types.TensorType(tf.float32, [None, 10, None, 10, 10])
    elem = graph_utils._make_dummy_element_for_type_spec(type_spec)
    correct_elem = np.zeros([0, 10, 0, 10, 10], np.float32)
    # array_equal compares shape and values (important for zero-size arrays).
    self.assertTrue(np.array_equal(elem, correct_elem))
    # array_equal does not compare dtype, so pin it explicitly.
    self.assertEqual(elem.dtype, correct_elem.dtype)