def test_append_to_list_structure_with_too_many_unnamed_elements(self):
  """Appending more unnamed values than the type spec has slots is rejected.

  The spec has two unnamed `tf.int32` elements, and the target structure has
  two matching empty lists. The value supplies three elements, and the test
  expects the append call to raise `TypeError`.
  """
  two_int_spec = computation_types.to_type([tf.int32, tf.int32])
  target = ([], [])
  oversized_value = [10, 20, 30]  # Three elements against a two-slot spec.
  with self.assertRaises(TypeError):
    tensorflow_utils.append_to_list_structure_for_element_type_spec(
        target, oversized_value, two_int_spec)
def test_append_to_list_structure_with_too_many_element_keys(self):
  """A value carrying a key not present in the named type spec is rejected.

  The spec names only `a` and `b`. The dict value additionally carries `c`,
  and the test expects the append call to raise `TypeError`.
  """
  named_spec = computation_types.to_type([('a', tf.int32), ('b', tf.int32)])
  target = collections.OrderedDict()
  target['a'] = []
  target['b'] = []
  value_with_extra_key = dict(a=10, b=20, c=30)
  with self.assertRaises(TypeError):
    tensorflow_utils.append_to_list_structure_for_element_type_spec(
        target, value_with_extra_key, named_spec)
def test_append_to_list_structure_for_element_type_spec_w_tuple_dict(self):
  """Appends positional and named values into a nested list structure.

  The type spec is an unnamed pair whose second element is a named tuple
  (`a`, `b`). Two values are appended: one gives the inner part as a dict,
  the other as a list. The test expects each leaf list to collect the
  appended values in order.

  NOTE(review): this chunk contains a later method with this same name. If
  both live in the same class, Python keeps only the later definition and
  this test never runs. Confirm which version is intended and remove or
  rename the other.
  """
  type_spec = computation_types.to_type(
      [tf.int32, [('a', tf.bool), ('b', tf.float32)]])
  structure = tuple([[], collections.OrderedDict([('a', []), ('b', [])])])
  # The named inner element may be supplied by name (dict) or by position.
  for value in [[10, {'a': 20, 'b': 30}], (40, [50, 60])]:
    tensorflow_utils.append_to_list_structure_for_element_type_spec(
        structure, value, type_spec)
  # Compare the structure itself rather than str(structure): the repr of
  # OrderedDict changed to dict-literal form in Python 3.12, which broke the
  # previous hard-coded string.
  expected = (
      [10, 40],
      collections.OrderedDict([('a', [20, 50]), ('b', [30, 60])]),
  )
  self.assertIsInstance(structure, tuple)
  self.assertIsInstance(structure[0], list)
  # OrderedDict == OrderedDict is order-sensitive, matching the old repr check.
  self.assertIsInstance(structure[1], collections.OrderedDict)
  self.assertEqual(structure, expected)
def _test_list_structure(self, type_spec, elements, expected_output_str):
  """Builds a list structure from `elements` and checks its string form.

  Starts from `make_empty_list_structure_for_element_type_spec(type_spec)`,
  appends every item of `elements`, then passes the result through
  `replace_empty_leaf_lists_with_numpy_arrays`. Finally it asserts that the
  result's `str()` equals `expected_output_str` once all spaces are removed
  from both.

  NOTE(review): this chunk also contains a later `_test_list_structure`. If
  both live in the same class, only the later definition is used and this
  one is dead code. Confirm and remove one of them.
  """
  accumulator = (
      tensorflow_utils.make_empty_list_structure_for_element_type_spec(
          type_spec))
  for element in elements:
    tensorflow_utils.append_to_list_structure_for_element_type_spec(
        accumulator, element, type_spec)
  finalized = tensorflow_utils.replace_empty_leaf_lists_with_numpy_arrays(
      accumulator, type_spec)

  def _without_spaces(text):
    # Spaces are ignored so callers need not match repr spacing exactly.
    return text.replace(' ', '')

  actual_text = _without_spaces(str(finalized))
  expected_text = _without_spaces(expected_output_str)
  self.assertEqual(actual_text, expected_text)
def _test_list_structure(self, type_spec, elements, expected_output_str):
  """Builds a list structure from `elements`, converts it, checks its string.

  Starts from `make_empty_list_structure_for_element_type_spec(type_spec)`,
  appends every item of `elements`, then passes the result through
  `to_tensor_slices_from_list_structure_for_element_type_spec`. Finally it
  asserts that the converted value's `str()` equals `expected_output_str`
  once all spaces are removed from both.
  """
  accumulator = (
      tensorflow_utils.make_empty_list_structure_for_element_type_spec(
          type_spec))
  for element in elements:
    tensorflow_utils.append_to_list_structure_for_element_type_spec(
        accumulator, element, type_spec)
  to_slices = (
      tensorflow_utils.to_tensor_slices_from_list_structure_for_element_type_spec)
  converted = to_slices(accumulator, type_spec)

  def _without_spaces(text):
    # Spaces are ignored so callers need not match repr spacing exactly.
    return text.replace(' ', '')

  actual_text = _without_spaces(str(converted))
  expected_text = _without_spaces(expected_output_str)
  self.assertEqual(actual_text, expected_text)
def test_append_to_list_structure_for_element_type_spec_w_tuple_dict(self):
  """Appended leaves become scalar tensors whose dtype follows `type_spec`.

  The type spec is an unnamed pair whose second element is a named tuple
  (`a`: bool, `b`: float32). Two values are appended: one gives the inner
  part as a dict, the other as a list. The test expects each leaf list to
  hold scalar `tf.Tensor`s with the spec's dtype, in append order. For
  example, the Python int 30 is expected to land in `b` as a float32 30.0.
  """
  type_spec = computation_types.to_type(
      [tf.int32, [('a', tf.bool), ('b', tf.float32)]])
  result = tuple([[], collections.OrderedDict([('a', []), ('b', [])])])
  # The named inner element may be supplied by name (dict) or by position.
  for value in [[10, {'a': True, 'b': 30}], (40, [False, 60])]:
    tensorflow_utils.append_to_list_structure_for_element_type_spec(
        result, value, type_spec)

  # Check the structure and each leaf directly instead of comparing against
  # str(result). The OrderedDict repr changed in Python 3.12, and the
  # eager-tensor repr has differed across TensorFlow releases.
  self.assertIsInstance(result, tuple)
  self.assertEqual(len(result), 2)
  unnamed_leaves, named = result
  self.assertIsInstance(named, collections.OrderedDict)
  self.assertEqual(list(named.keys()), ['a', 'b'])
  expectations = [
      (unnamed_leaves, tf.int32, [10, 40]),
      (named['a'], tf.bool, [True, False]),
      (named['b'], tf.float32, [30.0, 60.0]),
  ]
  for tensors, expected_dtype, expected_values in expectations:
    self.assertIsInstance(tensors, list)
    self.assertEqual(len(tensors), len(expected_values))
    for tensor, expected_value in zip(tensors, expected_values):
      self.assertIsInstance(tensor, tf.Tensor)
      self.assertEqual(tensor.dtype, expected_dtype)
      self.assertEqual(tensor.shape.as_list(), [])  # Scalar, as shape=().
      self.assertEqual(tensor.numpy(), expected_value)