def test_returns_computation(self, py_fn, type_signature, arg,
                             expected_result):
  """Runs the proto built from `py_fn` on `arg` and compares the output.

  `create_computation_for_py_fn` must return a `pb.Computation` as the first
  element of its result pair, and executing that proto on `arg` must produce
  `expected_result`. The parameters are presumably supplied by a
  parameterization decorator that is not visible here.
  """
  computation_proto, _ = (
      tensorflow_computation_factory.create_computation_for_py_fn(
          py_fn, type_signature))
  self.assertIsInstance(computation_proto, pb.Computation)
  self.assertEqual(
      tensorflow_computation_test_utils.run_tensorflow(computation_proto, arg),
      expected_result)
def test_transform_compiled_computation_semantic_equivalence(self):
  """Checks that `TensorFlowOptimizer` preserves a compiled identity's outputs.

  Builds a compiled int32 identity, transforms it with a default
  `tf.compat.v1.ConfigProto`, and asserts that:
    * the transform reports a mutation,
    * the result is still a `building_blocks.CompiledComputation`, and
    * the original and transformed protos agree on several sample inputs.
  """
  tensor_type = computation_types.TensorType(tf.int32)
  compiled_computation = building_block_factory.create_compiled_identity(
      tensor_type)
  config = tf.compat.v1.ConfigProto()
  tf_optimizer = compiled_computation_transformations.TensorFlowOptimizer(
      config)
  transformed_comp, mutated = tf_optimizer.transform(compiled_computation)
  self.assertTrue(mutated)
  self.assertIsInstance(transformed_comp, building_blocks.CompiledComputation)
  # Comparing only on input 0 would also pass if the transform collapsed the
  # computation to a zero constant, so several distinct values are checked.
  for value in (0, 1, -7):
    result_before_transform = tensorflow_computation_test_utils.run_tensorflow(
        compiled_computation.proto, value)
    result_after_transform = tensorflow_computation_test_utils.run_tensorflow(
        transformed_comp.proto, value)
    self.assertEqual(
        result_before_transform,
        result_after_transform,
        msg='Outputs differ for input {}'.format(value))
def test_returns_computation(self, type_signature, value):
  """Checks that `create_identity` builds a correctly typed identity computation.

  The deserialized proto type must equal `type_factory.unary_op(type_signature)`,
  and running the proto on `value` must return `value` unchanged.
  """
  identity_proto, _ = tensorflow_computation_factory.create_identity(
      type_signature)
  self.assertIsInstance(identity_proto, pb.Computation)

  # Type check: the serialized signature must match the unary-op type.
  self.assertEqual(
      type_serialization.deserialize_type(identity_proto.type),
      type_factory.unary_op(type_signature))

  # Behavior check: the identity must echo its input back.
  echoed = tensorflow_computation_test_utils.run_tensorflow(
      identity_proto, value)
  self.assertEqual(echoed, value)
def test_returns_computation(self):
  """Checks that `create_empty_tuple` builds a computation producing an empty struct.

  The deserialized proto type must be assignable to
  `FunctionType(None, [])`, and running the proto with no argument must yield
  `structure.Struct([])`.
  """
  empty_tuple_proto, _ = tensorflow_computation_factory.create_empty_tuple()
  self.assertIsInstance(empty_tuple_proto, pb.Computation)

  deserialized_type = type_serialization.deserialize_type(
      empty_tuple_proto.type)
  computation_types.FunctionType(None, []).check_assignable_from(
      deserialized_type)

  output = tensorflow_computation_test_utils.run_tensorflow(empty_tuple_proto)
  self.assertEqual(output, structure.Struct([]))
def test_returns_computation(self, value, type_signature, expected_result):
  """Checks that `create_constant` builds a no-argument computation of `value`.

  The deserialized proto type must be assignable to
  `FunctionType(None, type_signature)`, and running the proto must yield
  `expected_result`. When `expected_result` is a list, element order is
  ignored in the comparison.
  """
  constant_proto, _ = tensorflow_computation_factory.create_constant(
      value, type_signature)
  self.assertIsInstance(constant_proto, pb.Computation)

  deserialized_type = type_serialization.deserialize_type(constant_proto.type)
  computation_types.FunctionType(
      None, type_signature).check_assignable_from(deserialized_type)

  output = tensorflow_computation_test_utils.run_tensorflow(constant_proto)
  # List expectations use an order-insensitive comparison.
  compare = (
      self.assertCountEqual
      if isinstance(expected_result, list) else self.assertEqual)
  compare(output, expected_result)
def test_returns_computation(self, type_signature, count, value):
  """Checks that `create_replicate_input` copies its input `count` times.

  The deserialized proto type must be assignable to a function from
  `type_signature` to a `count`-element list of `type_signature`. Running the
  proto on `value` must yield a `structure.Struct` of `count` unnamed copies
  of `value`.
  """
  replicate_proto, _ = tensorflow_computation_factory.create_replicate_input(
      type_signature, count)
  self.assertIsInstance(replicate_proto, pb.Computation)

  deserialized_type = type_serialization.deserialize_type(replicate_proto.type)
  replicated_types = [type_signature] * count
  wanted_type = computation_types.FunctionType(type_signature,
                                               replicated_types)
  wanted_type.check_assignable_from(deserialized_type)

  output = tensorflow_computation_test_utils.run_tensorflow(
      replicate_proto, value)
  unnamed_copies = [(None, value)] * count
  self.assertEqual(output, structure.Struct(unnamed_copies))
def test_returns_computation(self, operator, type_signature, operands,
                             expected_result):
  """Checks that `create_binary_operator_with_upcast` builds a working computation.

  Verifies that the returned proto is a `pb.Computation` whose deserialized
  type is a `FunctionType` with parameter `_StructType(type_signature)`, and
  that running it on `operands` produces `expected_result`.
  """
  proto, _ = tensorflow_computation_factory.create_binary_operator_with_upcast(
      type_signature, operator)
  self.assertIsInstance(proto, pb.Computation)
  actual_type = type_serialization.deserialize_type(proto.type)
  self.assertIsInstance(actual_type, computation_types.FunctionType)
  # Note: Only the parameter type is worth checking here. The result type is
  # determined by the `operator` passed in, not by the implementation of
  # `create_binary_operator_with_upcast`.
  expected_parameter_type = _StructType(type_signature)
  self.assertEqual(actual_type.parameter, expected_parameter_type)
  actual_result = tensorflow_computation_test_utils.run_tensorflow(
      proto, operands)
  self.assertEqual(actual_result, expected_result)