def test_make_empty_list_structure_for_element_type_spec_w_tuple_dict(self):
    """Checks the empty list structure made for a tuple with a named part.

    The type spec is built from an unnamed `tf.int32` entry followed by a
    nested list of the named entries `a` (`tf.bool`) and `b` (`tf.float32`).
    The string form of the resulting structure is compared against a fixed
    expected string.
    """
    element_spec = [tf.int32, [('a', tf.bool), ('b', tf.float32)]]
    element_type = computation_types.to_type(element_spec)
    empty_structure = (
        tensorflow_utils.make_empty_list_structure_for_element_type_spec(
            element_type))
    actual_repr = str(empty_structure)
    expected_repr = '([], OrderedDict([(\'a\', []), (\'b\', [])]))'
    self.assertEqual(actual_repr, expected_repr)
def _test_list_structure(self, type_spec, elements, expected_output_str):
    """Builds a list structure from `elements` and checks its string form.

    Starts from the structure returned by
    `tensorflow_utils.make_empty_list_structure_for_element_type_spec`,
    appends every element in order, passes the result through
    `tensorflow_utils.replace_empty_leaf_lists_with_numpy_arrays`, and then
    compares `str()` of that result with the expected string.

    Args:
      type_spec: The element type spec handed to each `tensorflow_utils` call.
      elements: Iterable of element values to append, in iteration order.
      expected_output_str: Expected string form of the final structure.
        Space characters are removed from both sides before comparing.
    """
    accumulated = (
        tensorflow_utils.make_empty_list_structure_for_element_type_spec(
            type_spec))
    for item in elements:
        # The helper's return value is not used here; `accumulated` is
        # passed in again on every iteration.
        tensorflow_utils.append_to_list_structure_for_element_type_spec(
            accumulated, item, type_spec)
    finalized = tensorflow_utils.replace_empty_leaf_lists_with_numpy_arrays(
        accumulated, type_spec)
    # Drop spaces on both sides so spacing differences do not fail the check.
    actual_compact = str(finalized).replace(' ', '')
    expected_compact = expected_output_str.replace(' ', '')
    self.assertEqual(actual_compact, expected_compact)
def _test_list_structure(self, type_spec, elements, expected_output_str):
    """Appends `elements` to a list structure and checks the converted result.

    The structure is created with
    `tensorflow_utils.make_empty_list_structure_for_element_type_spec`, each
    element is appended in order, and the structure is then passed through
    `tensorflow_utils.to_tensor_slices_from_list_structure_for_element_type_spec`.
    The string form of that final value is compared with the expected string.

    Args:
      type_spec: The element type spec handed to each `tensorflow_utils` call.
      elements: Iterable of element values to append, in iteration order.
      expected_output_str: Expected string form of the final value. Space
        characters are removed from both sides before comparing.
    """
    list_structure = (
        tensorflow_utils.make_empty_list_structure_for_element_type_spec(
            type_spec))
    for value in elements:
        # The helper's return value is not used here; `list_structure` is
        # passed in again on every iteration.
        tensorflow_utils.append_to_list_structure_for_element_type_spec(
            list_structure, value, type_spec)
    tensor_slices = (
        tensorflow_utils
        .to_tensor_slices_from_list_structure_for_element_type_spec(
            list_structure, type_spec))
    # Compare with spaces removed so formatting differences are ignored.
    observed = str(tensor_slices).replace(' ', '')
    wanted = expected_output_str.replace(' ', '')
    self.assertEqual(observed, wanted)