def test_serialize_deserialize_round_trip(self):
    """Serializing then deserializing a computation yields an equal one."""
    int_type = computation_types.TensorType(tf.int32)
    add_proto, _ = tensorflow_computation_factory.create_binary_operator(
        tf.add, int_type, int_type)
    original = computation_impl.ConcreteComputation(
        add_proto, context_stack_impl.context_stack)

    # Round-trip: serialize first, then feed the result straight back in.
    round_tripped = computation_serialization.deserialize_computation(
        computation_serialization.serialize_computation(original))

    self.assertIsInstance(round_tripped, computation_base.Computation)
    self.assertEqual(round_tripped, original)
def test_returns_computation(self, operator, type_signature, operands,
                             expected_result):
    """Checks the proto, parameter type and result of a binary operator.

    Args:
      operator: The binary operator passed to `create_binary_operator`.
      type_signature: The operand type passed to `create_binary_operator`.
      operands: The argument fed to the generated TensorFlow computation.
      expected_result: The value the computation is expected to produce.
    """
    proto, _ = tensorflow_computation_factory.create_binary_operator(
        operator, type_signature)
    self.assertIsInstance(proto, pb.Computation)

    function_type = type_serialization.deserialize_type(proto.type)
    self.assertIsInstance(function_type, computation_types.FunctionType)
    # Only the parameter type is worth checking here: the result type is
    # determined by `operator` itself, not by the implementation of
    # `create_binary_operator`.
    self.assertEqual(
        function_type.parameter,
        computation_types.StructType([type_signature, type_signature]))

    self.assertEqual(
        test_utils.run_tensorflow(proto, operands), expected_result)
def test_get_curried(self):
    """Currying a two-argument add produces nested one-argument lambdas."""
    int_type = computation_types.TensorType(tf.int32)
    add_proto, add_type = (
        tensorflow_computation_factory.create_binary_operator(
            tf.add, int_type, int_type))
    add_value = value_impl.Value(
        building_blocks.CompiledComputation(
            proto=add_proto, name='test', type_signature=add_type))

    curried = value_utils.get_curried(add_value)

    expected_type = '(int32 -> (int32 -> int32))'
    expected_comp = '(arg0 -> (arg1 -> comp#test(<arg0,arg1>)))'
    self.assertEqual(curried.type_signature.compact_representation(),
                     expected_type)
    self.assertEqual(curried.comp.compact_representation(), expected_comp)
async def embed_tf_binary_operator(executor, type_spec, op):
    """Builds a TF binary operator over `type_spec` and embeds it in `executor`.

    Args:
      executor: An instance of `tff.framework.Executor` that receives the
        operator.
      type_spec: An instance of `tff.Type` describing the values the operator
        accepts as input and produces as output.
      op: A pointwise operator function (for example `tf.add` or
        `tf.multiply`) applied to the tensor-level constituents of the values.

    Returns:
      An instance of `tff.framework.ExecutorValue` holding the embedded
      operator.
    """
    operator_proto, operator_type = (
        tensorflow_computation_factory.create_binary_operator(op, type_spec))
    embedded = await executor.create_value(operator_proto, operator_type)
    return embedded
def test_raises_type_error(self, operator, type_signature):
    """`create_binary_operator` raises `TypeError` for these arguments."""
    self.assertRaises(
        TypeError,
        tensorflow_computation_factory.create_binary_operator,
        operator,
        type_signature,
    )
def _create_computation_add() -> computation_base.Computation:
    """Returns a `ConcreteComputation` that applies `tf.add` to two int32s."""
    scalar_int_type = computation_types.TensorType(tf.int32)
    add_proto, _ = tensorflow_computation_factory.create_binary_operator(
        tf.add, scalar_int_type, scalar_int_type)
    computation = computation_impl.ConcreteComputation(
        add_proto, context_stack_impl.context_stack)
    return computation