def test_with_two_level_tuple(self):
    """Check dtypes and shapes produced for a two-level named struct type.

    The outer struct holds a `tf.bool` field 'a', a nested struct 'b' (a
    float32 tensor 'c' and an int32 tensor 'd' of shape [20]) and an empty
    struct 'e'. Both results are expected to mirror that nesting, with the
    empty struct appearing as an empty tuple.
    """
    # NOTE(review): a second definition with this same name is visible in this
    # file; if both live in one class the later one shadows the earlier.
    ordered = collections.OrderedDict
    inner_struct = computation_types.StructWithPythonType(
        [
            ('c', computation_types.TensorType(tf.float32)),
            ('d', computation_types.TensorType(tf.int32, [20])),
        ],
        ordered,
    )
    outer_struct = computation_types.StructWithPythonType(
        [
            ('a', tf.bool),
            ('b', inner_struct),
            ('e', computation_types.StructType([])),
        ],
        ordered,
    )

    actual_dtypes, actual_shapes = (
        type_conversions.type_to_tf_dtypes_and_shapes(outer_struct)
    )

    expected_dtypes = {
        'a': tf.bool,
        'b': {'c': tf.float32, 'd': tf.int32},
        'e': (),
    }
    test.assert_nested_struct_eq(actual_dtypes, expected_dtypes)

    expected_shapes = {
        'a': tf.TensorShape([]),
        'b': {'c': tf.TensorShape([]), 'd': tf.TensorShape([20])},
        'e': (),
    }
    test.assert_nested_struct_eq(actual_shapes, expected_shapes)
def test_with_two_level_tuple(self):
    """Check dtypes and shapes for a two-level spec given as plain Python.

    The type is passed as nested `OrderedDict`s of raw specs: a bare dtype
    for 'a' and 'c', a `(dtype, shape)` pair for 'd', and an empty tuple for
    'e'. Both results are expected to mirror that nesting, with 'e' staying
    an empty tuple.
    """
    # NOTE(review): a second definition with this same name is visible in this
    # file; if both live in one class the later one shadows the earlier.
    nested_spec = collections.OrderedDict([
        ('c', tf.float32),
        ('d', (tf.int32, [20])),
    ])
    type_spec = collections.OrderedDict([
        ('a', tf.bool),
        ('b', nested_spec),
        ('e', ()),
    ])

    actual_dtypes, actual_shapes = (
        type_conversions.type_to_tf_dtypes_and_shapes(type_spec)
    )

    expected_dtypes = {
        'a': tf.bool,
        'b': {'c': tf.float32, 'd': tf.int32},
        'e': (),
    }
    test.assert_nested_struct_eq(actual_dtypes, expected_dtypes)

    expected_shapes = {
        'a': tf.TensorShape([]),
        'b': {'c': tf.TensorShape([]), 'd': tf.TensorShape([20])},
        'e': (),
    }
    test.assert_nested_struct_eq(actual_shapes, expected_shapes)
def test_with_tensor_triple(self):
    """Check dtypes and shapes for a flat three-field spec in plain Python.

    Fields 'a' and 'c' are `(dtype, shape)` pairs and 'b' is a bare dtype;
    'b' is expected to come back with an empty `TensorShape`.
    """
    # NOTE(review): a second definition with this same name is visible in this
    # file; if both live in one class the later one shadows the earlier.
    triple_spec = collections.OrderedDict([
        ('a', (tf.int32, [5])),
        ('b', tf.bool),
        ('c', (tf.float32, [3])),
    ])

    actual_dtypes, actual_shapes = (
        type_conversions.type_to_tf_dtypes_and_shapes(triple_spec)
    )

    expected_dtypes = {'a': tf.int32, 'b': tf.bool, 'c': tf.float32}
    test.assert_nested_struct_eq(actual_dtypes, expected_dtypes)

    expected_shapes = {
        'a': tf.TensorShape([5]),
        'b': tf.TensorShape([]),
        'c': tf.TensorShape([3]),
    }
    test.assert_nested_struct_eq(actual_shapes, expected_shapes)
def test_with_tensor_triple(self):
    """Check dtypes and shapes for a flat struct of three tensor types.

    The struct is built from explicit `TensorType` fields; 'b' is created
    without a shape argument and is expected to come back with an empty
    `TensorShape`.
    """
    # NOTE(review): a second definition with this same name is visible in this
    # file; if both live in one class the later one shadows the earlier.
    fields = [
        ('a', computation_types.TensorType(tf.int32, [5])),
        ('b', computation_types.TensorType(tf.bool)),
        ('c', computation_types.TensorType(tf.float32, [3])),
    ]
    triple_type = computation_types.StructWithPythonType(
        fields, collections.OrderedDict
    )

    actual_dtypes, actual_shapes = (
        type_conversions.type_to_tf_dtypes_and_shapes(triple_type)
    )

    expected_dtypes = {'a': tf.int32, 'b': tf.bool, 'c': tf.float32}
    self.assert_nested_struct_eq(actual_dtypes, expected_dtypes)

    expected_shapes = {
        'a': tf.TensorShape([5]),
        'b': tf.TensorShape([]),
        'c': tf.TensorShape([3]),
    }
    self.assert_nested_struct_eq(actual_shapes, expected_shapes)
def test_with_int_vector(self):
    """Check dtype and shape for a single int32 `TensorType` of shape [10].

    With a lone tensor type the two results are expected to be a bare dtype
    and a bare `TensorShape` rather than containers.
    """
    # NOTE(review): a second definition with this same name is visible in this
    # file; if both live in one class the later one shadows the earlier.
    vector_type = computation_types.TensorType(tf.int32, [10])

    actual_dtype, actual_shape = (
        type_conversions.type_to_tf_dtypes_and_shapes(vector_type)
    )

    test.assert_nested_struct_eq(actual_dtype, tf.int32)
    test.assert_nested_struct_eq(actual_shape, tf.TensorShape([10]))
def test_with_int_vector(self):
    """Check dtype and shape for an int32 vector given as `(dtype, shape)`.

    The type is passed as a plain tuple; the two results are expected to be
    a bare dtype and a bare `TensorShape` rather than containers.
    """
    # NOTE(review): a second definition with this same name is visible in this
    # file; if both live in one class the later one shadows the earlier.
    vector_spec = (tf.int32, [10])

    actual_dtype, actual_shape = (
        type_conversions.type_to_tf_dtypes_and_shapes(vector_spec)
    )

    test.assert_nested_struct_eq(actual_dtype, tf.int32)
    test.assert_nested_struct_eq(actual_shape, tf.TensorShape([10]))