def test_anon_tuple_without_names_promoted_to_container_with_names(self):
  """Unnamed `Struct` elements take their field names from the target type.

  Converting an all-unnamed `structure.Struct` into a named
  `StructWithPythonType` should yield the requested Python container, with
  the values carried over positionally under the type's field names.
  """
  anon_tuple = structure.Struct([(None, 1), (None, 2.0)])
  types = [('a', tf.int32), ('b', tf.float32)]

  dict_converted_value = type_conversions.type_to_py_container(
      anon_tuple, computation_types.StructWithPythonType(types, dict))
  odict_converted_value = type_conversions.type_to_py_container(
      anon_tuple,
      computation_types.StructWithPythonType(types, collections.OrderedDict))

  test_named_tuple = collections.namedtuple('TestNamedTuple', ['a', 'b'])
  named_tuple_converted_value = type_conversions.type_to_py_container(
      anon_tuple,
      computation_types.StructWithPythonType(types, test_named_tuple))

  @attr.s
  class TestFoo(object):
    a = attr.ib()
    b = attr.ib()

  attr_converted_value = type_conversions.type_to_py_container(
      anon_tuple, computation_types.StructWithPythonType(types, TestFoo))

  # The requested container types are produced...
  self.assertIsInstance(dict_converted_value, dict)
  self.assertIsInstance(odict_converted_value, collections.OrderedDict)
  self.assertIsInstance(named_tuple_converted_value, test_named_tuple)
  self.assertIsInstance(attr_converted_value, TestFoo)
  # ...and the values survive the promotion under the type's field names.
  self.assertEqual(dict_converted_value, {'a': 1, 'b': 2.0})
  self.assertEqual(odict_converted_value,
                   collections.OrderedDict([('a', 1), ('b', 2.0)]))
  self.assertEqual(named_tuple_converted_value, test_named_tuple(a=1, b=2.0))
  self.assertEqual(attr_converted_value, TestFoo(a=1, b=2.0))
def test_create_selection(self):
  """Builds a struct from one value twice, then selects its first element."""
  executor = executor_bindings.create_reference_resolving_executor(
      executor_bindings.create_tensorflow_executor())
  expected_type_spec = TensorType(shape=[3], dtype=tf.int64)
  value_pb, _ = value_serialization.serialize_value(
      tf.constant([1, 2, 3]), expected_type_spec)
  value = executor.create_value(value_pb)
  self.assertEqual(value.ref, 0)

  # 1. Create a struct from duplicated values.
  struct_value = executor.create_struct([value.ref, value.ref])
  self.assertEqual(struct_value.ref, 1)
  materialized_value = executor.materialize(struct_value.ref)
  deserialized_value, type_spec = value_serialization.deserialize_value(
      materialized_value)
  struct_type_spec = computation_types.to_type(
      [expected_type_spec, expected_type_spec])
  type_test_utils.assert_types_equivalent(type_spec, struct_type_spec)
  deserialized_value = type_conversions.type_to_py_container(
      deserialized_value, struct_type_spec)
  self.assertAllClose([(1, 2, 3), (1, 2, 3)], deserialized_value)

  # 2. Select the first value out of the struct.
  new_value = executor.create_selection(struct_value.ref, 0)
  materialized_value = executor.materialize(new_value.ref)
  deserialized_value, type_spec = value_serialization.deserialize_value(
      materialized_value)
  type_test_utils.assert_types_equivalent(type_spec, expected_type_spec)
  # The selection is a single tensor, so convert it with the tensor type that
  # was just asserted, not with the two-element struct type from step 1.
  deserialized_value = type_conversions.type_to_py_container(
      deserialized_value, expected_type_spec)
  self.assertAllClose((1, 2, 3), deserialized_value)
def _unpack_arg(arg_types, kwarg_types, arg) -> _Arguments:
  """Splits `arg` into positional and keyword arguments, checking types.

  Args:
    arg_types: Sequence of expected `computation_types.Type`s; the element at
      position `i` is read as `arg[i]`.
    kwarg_types: Mapping from keyword name to expected
      `computation_types.Type`; each element is read as `getattr(arg, name)`.
    arg: The packed argument to unpack.

  Returns:
    A pair `(args, kwargs)`: the positional values as a list and the keyword
    values as a dict.

  Raises:
    TypeError: If an element's inferred type is not assignable to its
      expected type.
  """
  positional = []
  for position, param_type in enumerate(arg_types):
    member = arg[position]
    member_type = type_conversions.infer_type(member)
    if not param_type.is_assignable_from(member_type):
      raise TypeError(
          'Expected element at position {} to be of type {}, found {}.'.format(
              position, param_type, member_type))
    # NOTE(review): positional elements are converted whenever they are a
    # `structure.Struct`, whereas keyword elements below are converted only
    # when `is_struct_with_py_container` holds. Preserved as-is; confirm the
    # asymmetry is intended.
    if isinstance(member, structure.Struct):
      member = type_conversions.type_to_py_container(member, param_type)
    positional.append(member)

  keyword = {}
  for kw_name, param_type in kwarg_types.items():
    member = getattr(arg, kw_name)
    member_type = type_conversions.infer_type(member)
    if not param_type.is_assignable_from(member_type):
      raise TypeError(
          'Expected element named {} to be of type {}, found {}.'.format(
              kw_name, param_type, member_type))
    if type_analysis.is_struct_with_py_container(member, param_type):
      member = type_conversions.type_to_py_container(member, param_type)
    keyword[kw_name] = member

  return positional, keyword
def test_anon_tuple_with_names_to_container_with_names(self):
  """Named `AnonymousTuple` elements convert into named Python containers."""
  anon_tuple = anonymous_tuple.AnonymousTuple([('a', 1), ('b', 2.0)])
  types = [('a', tf.int32), ('b', tf.float32)]

  self.assertDictEqual(
      type_conversions.type_to_py_container(
          anon_tuple, computation_types.StructWithPythonType(types, dict)), {
              'a': 1,
              'b': 2.0
          })

  # `assertSequenceEqual` suits sequences, not mappings: on a mismatch it
  # indexes the OrderedDict positionally and errors with `KeyError` instead of
  # reporting a failure. `assertEqual` between two OrderedDicts compares both
  # contents and key order.
  odict_value = type_conversions.type_to_py_container(
      anon_tuple,
      computation_types.StructWithPythonType(types, collections.OrderedDict))
  self.assertIsInstance(odict_value, collections.OrderedDict)
  self.assertEqual(odict_value,
                   collections.OrderedDict([('a', 1), ('b', 2.0)]))

  test_named_tuple = collections.namedtuple('TestNamedTuple', ['a', 'b'])
  named_tuple_value = type_conversions.type_to_py_container(
      anon_tuple,
      computation_types.StructWithPythonType(types, test_named_tuple))
  # A plain tuple compares equal to a namedtuple, so also check the type.
  self.assertIsInstance(named_tuple_value, test_named_tuple)
  self.assertSequenceEqual(named_tuple_value, test_named_tuple(a=1, b=2.0))

  @attr.s
  class TestFoo(object):
    a = attr.ib()
    b = attr.ib()

  self.assertEqual(
      type_conversions.type_to_py_container(
          anon_tuple, computation_types.StructWithPythonType(types, TestFoo)),
      TestFoo(a=1, b=2.0))
def test_anon_tuple_with_names_to_container_without_names_fails(self):
  """Named elements are rejected by containers that cannot hold names."""
  element_types = [tf.int32, tf.float32]

  # A mix of named and unnamed elements cannot become a `tuple`.
  mixed_names = structure.Struct([(None, 1), ('a', 2.0)])
  with self.assertRaisesRegex(ValueError,
                              'contains a mix of named and unnamed elements'):
    type_conversions.type_to_py_container(
        mixed_names,
        computation_types.StructWithPythonType(element_types, tuple))

  # Fully named elements cannot become a `list`.
  all_named = structure.Struct([('a', 1), ('b', 2.0)])
  with self.assertRaisesRegex(ValueError, 'which does not support names'):
    type_conversions.type_to_py_container(
        all_named,
        computation_types.StructWithPythonType(element_types, list))
def test_anon_tuple_without_names_to_container_without_names(self):
  """Unnamed `Struct` elements convert into `list` and `tuple` containers."""
  anon_tuple = structure.Struct([(None, 1), (None, 2.0)])
  element_types = [tf.int32, tf.float32]
  for container, expected in ((list, [1, 2.0]), (tuple, (1, 2.0))):
    converted = type_conversions.type_to_py_container(
        anon_tuple,
        computation_types.StructWithPythonType(element_types, container))
    self.assertSequenceEqual(converted, expected)
def test_fails_without_namedtupletype(self):
  """Converting to a non-struct (or federated non-struct) type raises."""
  test_anon_tuple = anonymous_tuple.AnonymousTuple([(None, 1)])
  non_struct_types = (
      computation_types.TensorType(tf.int32),
      computation_types.FederatedType(
          computation_types.TensorType(tf.int32), placement_literals.SERVER),
  )
  for bad_type in non_struct_types:
    with self.assertRaises(TypeError):
      type_conversions.type_to_py_container(test_anon_tuple, bad_type)
def test_anon_tuple_without_names_to_container_without_names(self):
  """Unnamed `AnonymousTuple` elements convert into `list` and `tuple`."""
  anon_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0)])
  element_types = [tf.int32, tf.float32]
  expectations = [(list, [1, 2.0]), (tuple, (1, 2.0))]
  for container, expected in expectations:
    target_type = computation_types.NamedTupleTypeWithPyContainerType(
        element_types, container)
    self.assertSequenceEqual(
        type_conversions.type_to_py_container(anon_tuple, target_type),
        expected)
def _unpack_and_call(fn, arg_types, kwarg_types, arg):
  """Unpacks `arg` into positional/keyword arguments, then calls `fn`.

  Every element is type-checked against its declared type right before the
  call, as a last-minute safeguard.

  Args:
    fn: The function or defun to invoke.
    arg_types: The list of positional argument types (guaranteed to all be
      instances of computation_types.Types).
    kwarg_types: The dictionary of keyword argument types (guaranteed to all
      be instances of computation_types.Types).
    arg: The argument to unpack; an `anonymous_tuple.AnonymousTuple` or a
      `value_base.Value`.

  Returns:
    The result of invoking `fn` on the unpacked arguments.

  Raises:
    TypeError: if types don't match.
  """
  py_typecheck.check_type(arg,
                          (anonymous_tuple.AnonymousTuple, value_base.Value))

  call_args = []
  for i, want_type in enumerate(arg_types):
    piece = arg[i]
    got_type = type_conversions.infer_type(piece)
    if not type_analysis.is_assignable_from(want_type, got_type):
      raise TypeError(
          'Expected element at position {} to be of type {}, found {}.'.format(
              i, want_type, got_type))
    if isinstance(piece, anonymous_tuple.AnonymousTuple):
      piece = type_conversions.type_to_py_container(piece, want_type)
    call_args.append(piece)

  call_kwargs = {}
  for key, want_type in kwarg_types.items():
    piece = getattr(arg, key)
    got_type = type_conversions.infer_type(piece)
    if not type_analysis.is_assignable_from(want_type, got_type):
      raise TypeError(
          'Expected element named {} to be of type {}, found {}.'.format(
              key, want_type, got_type))
    if type_analysis.is_anon_tuple_with_py_container(piece, want_type):
      piece = type_conversions.type_to_py_container(piece, want_type)
    call_kwargs[key] = piece

  return fn(*call_args, **call_kwargs)
def test_anon_tuple_without_names_to_container_with_names_fails(self):
  """Unnamed elements cannot be converted into containers that need names."""
  anon_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0)])
  types = [('a', tf.int32), ('b', tf.float32)]

  test_named_tuple = collections.namedtuple('TestNamedTuple', ['a', 'b'])

  @attr.s
  class TestFoo(object):
    a = attr.ib()
    b = attr.ib()

  named_containers = (dict, collections.OrderedDict, test_named_tuple, TestFoo)
  for container in named_containers:
    with self.assertRaisesRegex(ValueError, 'value.*with unnamed elements'):
      type_conversions.type_to_py_container(
          anon_tuple, computation_types.StructWithPythonType(types, container))
def test_succeeds_with_federated_namedtupletype(self):
  """A server-placed struct type converts using its member's container."""
  anon_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0)])
  types = [tf.int32, tf.float32]

  def at_server(container):
    return computation_types.FederatedType(
        computation_types.StructWithPythonType(types, container),
        placement_literals.SERVER)

  self.assertSequenceEqual(
      type_conversions.type_to_py_container(anon_tuple, at_server(list)),
      [1, 2.0])
  self.assertSequenceEqual(
      type_conversions.type_to_py_container(anon_tuple, at_server(tuple)),
      (1, 2.0))
def invoke(self, comp: computation_base.Computation, arg: Any) -> Any:
  """Invokes `comp` on `arg` through the wrapped context.

  Args:
    comp: The `tff.Computation` to invoke.
    arg: The optional argument of `comp`.

  Returns:
    The result of the invocation, converted to the Python containers described
    by `comp`'s result type. The result must contain only structures,
    server-placed values, or tensors.

  Raises:
    ValueError: If the result type of the invoked computation does not contain
      only structures, server-placed values, or tensors.
  """
  py_typecheck.check_type(comp, computation_base.Computation)
  result_type = comp.type_signature.result
  if not federated_context.contains_only_server_placed_data(result_type):
    raise ValueError(
        'Expected the result type of the invoked computation to contain only '
        'structures, server-placed values, or tensors, found '
        f'\'{result_type}\'.')

  async def _invoke_in_context(context: context_base.Context,
                               comp: computation_base.Computation,
                               arg: Any) -> Any:
    # Only computations that take a parameter get their argument materialized
    # before the call.
    if comp.type_signature.parameter is not None:
      arg = await _materialize_structure_of_value_references(
          arg, comp.type_signature.parameter)
    return await context.invoke(comp, arg)

  # The coroutine is not awaited here; it is handed, together with
  # `result_type`, to `_create_structure_of_coro_references`.
  pending = _invoke_in_context(self._context, comp, arg)
  references = _create_structure_of_coro_references(pending, result_type)
  return type_conversions.type_to_py_container(references, result_type)
def federated_secure_sum(arg):
  """Builds a federated aggregate that sums summands after a bound check.

  Args:
    arg: A `building_blocks.ComputationBuildingBlock`. Element 0 is the
      summand (its type signature has a `.member`, so presumably a federated
      value -- TODO confirm placement). Element 1 is the maximum allowed
      input value.

  Returns:
    The building block from `building_block_factory.create_federated_aggregate`.
    Its accumulate step asserts each summand is <= the max input before
    adding, and its report step keeps only the summation.
  """
  py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock)
  summand_arg = building_blocks.Selection(arg, index=0)
  summand_type = summand_arg.type_signature.member
  max_input_arg = building_blocks.Selection(arg, index=1)
  max_input_type = max_input_arg.type_signature

  # Add the max_value as a second value in the zero, so it can be read during
  # `accumulate` to ensure client summands are valid. The value will be
  # later dropped in `report`.
  #
  # While accumulating summands, we'll assert each summand is less than or
  # equal to max_input. Otherwise the computation should issue an error.
  summation_zero = building_block_factory.create_generic_constant(
      summand_type, 0)
  aggregation_zero = building_blocks.Struct([summation_zero, max_input_arg],
                                            container_type=tuple)

  def assert_less_equal_max_and_add(summation_and_max_input, summand):
    # Accumulator: the state is a (summation, max_input) pair; max_input is
    # passed through unchanged so later accumulate calls can still read it.
    summation, original_max_input = summation_and_max_input
    max_input = _ensure_structure(original_max_input, max_input_type,
                                  summand_type)

    # Assert that all coordinates in all tensors are less than or equal to the
    # secure sum allowed max input value.
    def assert_all_coordinates_less_equal(x, m):
      # Both sides are cast to int64 before the comparison.
      return tf.Assert(
          tf.reduce_all(
              tf.less_equal(tf.cast(x, tf.int64), tf.cast(m, tf.int64))), [
                  'client value larger than maximum specified for secure sum',
                  x, 'not less than or equal to', m
              ])

    assert_ops = structure.flatten(
        structure.map_structure(assert_all_coordinates_less_equal, summand,
                                max_input))
    # The additions are created under control dependencies on the asserts so
    # they cannot run until every bound check has executed.
    with tf.control_dependencies(assert_ops):
      return structure.map_structure(tf.add, summation,
                                     summand), original_max_input

  assert_less_equal_and_add = building_block_factory.create_tensorflow_binary_operator(
      assert_less_equal_max_and_add,
      operand_type=aggregation_zero.type_signature,
      second_operand_type=summand_type)

  # Merge: element-wise addition of two accumulator structures.
  def nested_plus(a, b):
    return structure.map_structure(tf.add, a, b)

  plus_op = building_block_factory.create_tensorflow_binary_operator(
      nested_plus, operand_type=aggregation_zero.type_signature)

  # In the `report` function we take the summation and drop the second element
  # of the struct (which was holding the max_value).
  drop_max_value_op = building_block_factory.create_tensorflow_unary_operator(
      lambda x: type_conversions.type_to_py_container(x[0], summand_type),
      aggregation_zero.type_signature)

  return building_block_factory.create_federated_aggregate(
      summand_arg, aggregation_zero, assert_less_equal_and_add, plus_op,
      drop_max_value_op)
def test_nested_py_containers(self):
  """Only struct members whose type carries a Python container are converted."""
  inner_pair = anonymous_tuple.AnonymousTuple([(None, 4), (None, 5)])
  anon_tuple = anonymous_tuple.AnonymousTuple([
      (None, 1),
      (None, 2.0),
      ('dict_key',
       anonymous_tuple.AnonymousTuple([('a', 3), ('b', inner_pair)])),
  ])

  pair_type = computation_types.StructWithPythonType([tf.int32, tf.int32],
                                                     tuple)
  dict_subtype = computation_types.StructWithPythonType(
      [('a', tf.int32), ('b', pair_type)], dict)
  type_spec = computation_types.StructType([
      (None, tf.int32),
      (None, tf.float32),
      ('dict_key', dict_subtype),
  ])

  # The outer `StructType` has no Python container, so the outer level stays
  # an `AnonymousTuple`; the `dict_key` member becomes a dict holding a tuple.
  expected_nested_structure = anonymous_tuple.AnonymousTuple([
      (None, 1),
      (None, 2.0),
      ('dict_key', {'a': 3, 'b': (4, 5)}),
  ])
  self.assertEqual(
      type_conversions.type_to_py_container(anon_tuple, type_spec),
      expected_nested_structure)
def test_anon_tuple_return(self):
  """A `StructType` with no Python container yields an equal anonymous tuple."""
  anon_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0)])
  plain_struct_type = computation_types.StructType([tf.int32, tf.float32])
  converted = type_conversions.type_to_py_container(anon_tuple,
                                                    plain_struct_type)
  self.assertEqual(converted, anon_tuple)
def invoke(self,
           comp: Union[MergeableCompForm, computation_base.Computation],
           arg: Optional[Any] = None):
  """Runs `comp` across the async execution contexts and returns the result.

  Args:
    comp: A `MergeableCompForm`, or a `computation_base.Computation` that the
      configured compiler pipeline must turn into one.
    arg: Optional argument; when given it is wrapped in a
      `MergeableCompExecutionContextValue`.

  Returns:
    The result, converted to the Python containers of
    `comp.after_merge.type_signature.result`.

  Raises:
    ValueError: If a `Computation` is given but no compiler pipeline is set,
      or compilation does not produce a `MergeableCompForm`.
  """
  py_typecheck.check_type(comp,
                          (MergeableCompForm, computation_base.Computation))
  if isinstance(comp, computation_base.Computation):
    if self._compiler_pipeline is None:
      raise ValueError('Without a compiler, mergeable comp execution context '
                       'can only invoke instances of MergeableCompForm. '
                       'Encountered a `tff.Computation`.')
    comp = self._compiler_pipeline.compile(comp)
    if not isinstance(comp, MergeableCompForm):
      raise ValueError('Expected compilation in mergeable comp execution '
                       'context to produce an instance of MergeableCompForm; '
                       f'found instead {comp} of Python type {type(comp)}.')

  if arg is not None:
    arg = MergeableCompExecutionContextValue(
        arg, comp.up_to_merge.type_signature.parameter,
        len(self._async_execution_contexts))

  invocation = _invoke_mergeable_comp_form(comp, arg,
                                           self._async_execution_contexts)
  raw_result = self._async_runner.run_coro_and_return_result(invocation)
  return type_conversions.type_to_py_container(
      raw_result, comp.after_merge.type_signature.result)
def test_not_anon_tuple_passthrough(self):
  """A value that is not a `Struct` is returned as-is, without conversion."""
  value = (1, 2.0)
  result = type_conversions.type_to_py_container(
      value,
      computation_types.StructWithPythonType([tf.int32, tf.float32],
                                             container_type=list))
  # Pass-through means the very same object comes back. In particular, it is
  # not re-packed into the `list` container named by the type.
  self.assertIs(result, value)
def _call(fn, parameter_type, arg):
  """Type-checks `arg` against `parameter_type`, then returns `fn(arg)`.

  If `type_analysis.is_anon_tuple_with_py_container(arg, parameter_type)`
  holds, `arg` is first converted with `type_conversions.type_to_py_container`.

  Raises:
    TypeError: If the inferred type of `arg` is not assignable to
      `parameter_type`.
  """
  inferred = type_conversions.infer_type(arg)
  if parameter_type.is_assignable_from(inferred):
    if type_analysis.is_anon_tuple_with_py_container(arg, parameter_type):
      arg = type_conversions.type_to_py_container(arg, parameter_type)
    return fn(arg)
  raise TypeError('Expected an argument of type {}, found {}.'.format(
      parameter_type, inferred))
def test_sequence_type_with_collections_sequence_elements(self):
  """A dataset of tuples yields the same elements after conversion."""
  source = tf.data.Dataset.range(5).map(lambda t: (t, t))
  sequence_type = computation_types.SequenceType((tf.int64, tf.int64))
  converted = type_conversions.type_to_py_container(source, sequence_type)
  self.assertAllEqual(list(converted), list(source))
def _train_on_one_batch(model, batch):
  """Applies one gradient update to `model` using `batch`.

  The model parameters and the gradients of `loss_fn` are both flattened
  recursively, so they line up leaf-for-leaf even when `model` contains nested
  Python containers. `_apply_update` is then applied to each (param, grad)
  pair.

  Args:
    model: The current model, a Python container matching `model_type`.
    batch: The batch passed to `loss_fn` along with `model`.

  Returns:
    The updated model, repacked into the Python containers of `model_type`.
  """
  params = structure.flatten(structure.from_container(model, recursive=True))
  # `jax.grad` (not the deprecated `jax.api.grad` alias) returns gradients
  # with the same nesting as `model`. They must be flattened recursively, like
  # `params`, or nested models would be zipped against non-leaf containers.
  grads = structure.flatten(
      structure.from_container(
          jax.grad(loss_fn)(model, batch), recursive=True))
  updated_params = [_apply_update(x, y) for x, y in zip(params, grads)]
  trained_model = structure.pack_sequence_as(model_type, updated_params)
  return type_conversions.type_to_py_container(trained_model, model_type)
def _ensure_arg_type(parameter_type, arg) -> _Arguments:
  """Checks `arg` against `parameter_type` and returns it as call arguments.

  Returns:
    A pair `([arg], {})`: `arg` is the sole positional argument (converted via
    `type_to_py_container` when `is_struct_with_py_container` holds), and
    there are no keyword arguments.

  Raises:
    TypeError: If the inferred type of `arg` is not assignable to
      `parameter_type`.
  """
  inferred = type_conversions.infer_type(arg)
  if not parameter_type.is_assignable_from(inferred):
    raise TypeError('Expected an argument of type {}, found {}.'.format(
        parameter_type, inferred))
  needs_container = type_analysis.is_struct_with_py_container(
      arg, parameter_type)
  converted = (
      type_conversions.type_to_py_container(arg, parameter_type)
      if needs_container else arg)
  return [converted], {}
def test_client_placed_tuple(self):
  """Each client's `Struct` converts to the member type's `tuple` container."""
  value = [
      structure.Struct([(None, 1), (None, 2)]),
      structure.Struct([(None, 3), (None, 4)]),
  ]
  member_type = computation_types.StructWithPythonType(
      [(None, tf.int32), (None, tf.int32)], tuple)
  type_spec = computation_types.FederatedType(member_type,
                                              placement_literals.CLIENTS)
  converted = type_conversions.type_to_py_container(value, type_spec)
  self.assertEqual([(1, 2), (3, 4)], converted)
def test_create_value(self):
  """Creates tensor, struct and computation values; checks IDs and contents."""
  executor = executor_bindings.create_reference_resolving_executor(
      executor_bindings.create_tensorflow_executor())

  def materialize_and_deserialize(ref):
    # Pulls a value out of the executor as a (Python value, TFF type) pair.
    return value_serialization.deserialize_value(executor.materialize(ref))

  # 1. Test a simple tensor.
  expected_type_spec = TensorType(shape=[3], dtype=tf.int64)
  value_pb, _ = value_serialization.serialize_value([1, 2, 3],
                                                    expected_type_spec)
  value = executor.create_value(value_pb)
  self.assertIsInstance(value, executor_bindings.OwnedValueId)
  self.assertEqual(value.ref, 0)
  self.assertEqual(str(value), '0')
  self.assertEqual(repr(value), r'<OwnedValueId: 0>')
  deserialized_value, type_spec = materialize_and_deserialize(value.ref)
  type_test_utils.assert_types_identical(type_spec, expected_type_spec)
  self.assertAllEqual(deserialized_value, [1, 2, 3])

  # 2. Test a struct of tensors, ensure that we get a different ID.
  expected_type_spec = StructType([
      ('a', TensorType(shape=[3], dtype=tf.int64)),
      ('b', TensorType(shape=[], dtype=tf.float32)),
  ])
  value_pb, _ = value_serialization.serialize_value(
      collections.OrderedDict(a=tf.constant([1, 2, 3]), b=tf.constant(42.0)),
      expected_type_spec)
  value = executor.create_value(value_pb)
  self.assertIsInstance(value, executor_bindings.OwnedValueId)
  # The value ID must have been incremented.
  self.assertEqual(value.ref, 1)
  self.assertEqual(str(value), '1')
  self.assertEqual(repr(value), r'<OwnedValueId: 1>')
  deserialized_value, type_spec = materialize_and_deserialize(value.ref)
  # Note: the names `a` and `b` are lost in the deserialized type, so only
  # assignability to `expected_type_spec` is checked, not identity.
  self.assertTrue(expected_type_spec.is_assignable_from(type_spec))
  deserialized_value = type_conversions.type_to_py_container(
      deserialized_value, expected_type_spec)
  self.assertAllClose(deserialized_value,
                      collections.OrderedDict(a=(1, 2, 3), b=42.0))

  # 3. Test creating a value from a computation.
  @tensorflow_computation.tf_computation(tf.int32, tf.int32)
  def foo(a, b):
    return tf.add(a, b)

  value_pb, _ = value_serialization.serialize_value(foo)
  value = executor.create_value(value_pb)
  self.assertIsInstance(value, executor_bindings.OwnedValueId)
  # The value ID must have been incremented again.
  self.assertEqual(value.ref, 2)
  self.assertEqual(str(value), '2')
  self.assertEqual(repr(value), '<OwnedValueId: 2>')
def invoke(self, comp, arg):
  """Calls a `tff.tf_computation` from inside another tf_computation.

  The computation's proto is deserialized and called on `arg` against
  `self._graph`. A returned initialization op, if any, is appended to
  `self._init_ops`.

  Returns:
    The call result, converted to the Python containers of `comp`'s result
    type.
  """
  py_typecheck.check_type(comp, computation_base.Computation)
  proto = computation_impl.ComputationImpl.get_proto(comp)
  init_op, result = (
      tensorflow_deserialization.deserialize_and_call_tf_computation(
          proto, arg, self._graph))
  if init_op:
    self._init_ops.append(init_op)
  return type_conversions.type_to_py_container(result,
                                               comp.type_signature.result)
def test_ragged_tensor(self):
  """A struct typed with `tf.RaggedTensor` converts to a `tf.RaggedTensor`."""
  row_splits_type = computation_types.StructWithPythonType(
      [(None, computation_types.TensorType(tf.int64, [3]))], tuple)
  value_type = computation_types.StructWithPythonType([
      ('flat_values', computation_types.TensorType(tf.int32, [4])),
      ('nested_row_splits', row_splits_type),
  ], tf.RaggedTensor)
  value = structure.Struct([
      ('flat_values', [0, 0, 0, 0]),
      ('nested_row_splits', [[0, 1, 4]]),
  ])

  ragged = type_conversions.type_to_py_container(value, value_type)

  self.assertIsInstance(ragged, tf.RaggedTensor)
  self.assertAllEqual(ragged.flat_values, [0, 0, 0, 0])
  self.assertEqual(len(ragged.nested_row_splits), 1)
  self.assertAllEqual(ragged.nested_row_splits[0], [0, 1, 4])
def test_sparse_tensor(self):
  """A struct typed with `tf.SparseTensor` converts to a `tf.SparseTensor`."""
  value_type = computation_types.StructWithPythonType([
      ('indices', computation_types.TensorType(tf.int64, [1, 1])),
      ('values', computation_types.TensorType(tf.int32, [1])),
      ('dense_shape', computation_types.TensorType(tf.int64, [1])),
  ], tf.SparseTensor)
  value = structure.Struct([
      ('indices', [[1]]),
      ('values', [2]),
      ('dense_shape', [5]),
  ])

  sparse = type_conversions.type_to_py_container(value, value_type)

  self.assertIsInstance(sparse, tf.SparseTensor)
  self.assertEqual(len(sparse.indices), 1)
  self.assertAllEqual(sparse.indices[0], [1])
  self.assertAllEqual(sparse.values, [2])
  self.assertAllEqual(sparse.dense_shape, [5])
async def invoke(self, comp, arg):
  """Compiles, serializes and executes `comp` on `arg`, returning the result.

  Args:
    comp: The computation to invoke; compiled with `self._compiler_pipeline`.
    arg: The argument to `comp`, or `None` if it is called without one.

  Returns:
    The materialized result, converted to the Python containers of
    `comp.type_signature.result`.

  Raises:
    TypeError: If `arg` cannot be serialized as the computation's parameter
      type.
    absl_status.StatusNotOk: Re-raised, after logging the signature and
      argument, when execution fails.
  """
  compiled_comp = self._compiler_pipeline.compile(comp)
  serialized_comp, _ = value_serialization.serialize_value(
      compiled_comp, comp.type_signature)
  cardinalities = self._cardinality_inference_fn(
      arg, comp.type_signature.parameter)

  try:
    with self._reset_factory_on_error(self._executor_factory,
                                      cardinalities) as executor:
      fn = executor.create_value(serialized_comp)
      if arg is None:
        call = executor.create_call(fn.ref, None)
      else:
        try:
          serialized_arg, _ = value_serialization.serialize_value(
              arg, comp.type_signature.parameter)
        except Exception as e:
          raise TypeError(
              f'Failed to serialize argument:\n{arg}\nas a value of type:\n'
              f'{comp.type_signature.parameter}') from e
        arg_value = executor.create_value(serialized_arg)
        call = executor.create_call(fn.ref, arg_value.ref)
      # Grab the event loop only now, so that the `materialize` call below is
      # attached to the loop that is running this `invoke`.
      running_loop = asyncio.get_running_loop()
      result_pb = await running_loop.run_in_executor(
          self._futures_executor_pool, lambda: executor.materialize(call.ref))
  except absl_status.StatusNotOk as e:

    def indent(s):
      return textwrap.indent(s, prefix='\t')

    if arg is None:
      arg_str = 'without any arguments'
    else:
      arg_str = f'with argument:\n{indent(pprint.pformat(arg))}'
    logging.error('Error invoking computation with signature:\n%s\n%s\n',
                  indent(comp.type_signature.formatted_representation()),
                  arg_str)
    logging.error('Error: \n%s', e.status.message())
    raise e

  result_value, _ = value_serialization.deserialize_value(
      result_pb, comp.type_signature.result)
  return type_conversions.type_to_py_container(result_value,
                                               comp.type_signature.result)
def invoke(self, comp, arg):
  """Synchronously executes `comp` on `arg`, driving `self._loop`.

  An executor is created for the cardinalities inferred from `arg`. An
  `ingestable_base.Ingestable` argument is first ingested into that executor
  and computed.

  Returns:
    The computed result, passed through `self._unwrap_tensors` and converted
    to the Python containers of `comp.type_signature.result`.
  """
  cardinalities = cardinalities_utils.infer_cardinalities(
      arg, comp.type_signature.parameter)
  executor = self._executor_factory.create_executor(cardinalities)
  run = self._loop.run_until_complete
  embedded_comp = run(executor.create_value(comp, comp.type_signature))

  embedded_arg = None
  if arg is not None:
    if isinstance(arg, ingestable_base.Ingestable):
      ingested = run(arg.ingest(executor))
      arg = run(ingested.compute())
    embedded_arg = run(
        executor.create_value(arg, comp.type_signature.parameter))

  embedded_call = run(executor.create_call(embedded_comp, embedded_arg))
  computed = run(embedded_call.compute())
  return type_conversions.type_to_py_container(
      self._unwrap_tensors(computed), comp.type_signature.result)
def test_computes_sum_of_all_values(self, arg, expected_sum):
  """The mergeable sum form over five executor factories sums all values."""
  up_to_merge = build_sum_client_arg_computation(
      computation_types.at_server(tf.int32),
      computation_types.at_clients(tf.int32))
  merge = build_sum_merge_computation(tf.int32)
  after_merge = build_sum_merge_with_first_arg_computation(
      up_to_merge.type_signature.parameter, merge.type_signature.result)
  mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
      up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)

  num_factories = 5
  ex_factories = [
      python_executor_stacks.local_executor_factory()
      for _ in range(num_factories)
  ]
  context = mergeable_comp_execution_context.MergeableCompExecutionContext(
      ex_factories)

  expected_result = type_conversions.type_to_py_container(
      expected_sum, after_merge.type_signature.result)
  self.assertEqual(expected_result,
                   context.invoke(mergeable_comp_form, arg))
async def _invoke(executor, comp, arg, result_type: computation_types.Type):
  """Embeds `comp` in `executor`, calls it on `arg` and returns the result.

  Args:
    executor: An instance of `executor_base.Executor`.
    comp: The first argument to `context_base.Context.invoke()`.
    arg: The optional second argument to `context_base.Context.invoke()`. When
      not `None`, it must already be an `executor_value_base.ExecutorValue`.
    result_type: The type signature of the result, used to convert the
      computed value into the proper Python container types.

  Returns:
    The computed result, passed through `_unwrap` and converted per
    `result_type`.
  """
  if arg is not None:
    py_typecheck.check_type(arg, executor_value_base.ExecutorValue)
  embedded_comp = await executor.create_value(comp)
  call_result = await executor.create_call(embedded_comp, arg)
  py_typecheck.check_type(call_result, executor_value_base.ExecutorValue)
  computed = await call_result.compute()
  return type_conversions.type_to_py_container(_unwrap(computed), result_type)