def test_federated_min_on_nested_scalars(self):
  # federated_min over a CLIENTS-placed struct of two float32 scalars must
  # return the per-field minimum across all client values.
  member_type = computation_types.StructType([
      ('x', tf.float32),
      ('y', tf.float32),
  ])
  clients_type = computation_types.FederatedType(member_type,
                                                 placements.CLIENTS)

  @computations.federated_computation(clients_type)
  def call_federated_min(value):
    return federated_aggregations.federated_min(value)

  nested_scalars = collections.namedtuple('NestedScalars', ['x', 'y'])
  client_values = [
      nested_scalars(0.0, 1.0),
      nested_scalars(-1.0, 5.0),
      nested_scalars(2.0, -10.0),
  ]
  result = call_federated_min(client_values)
  self.assertDictEqual(result._asdict(), {'x': -1.0, 'y': -10.0})
def test_returns_value_with_source_and_index_structure(self):
  # Selecting index 0 from a three-element struct must yield an executor
  # value whose type and computed result match the struct's first element.
  executor = create_test_executor()
  raw_element, element_type = (
      executor_test_utils.create_dummy_value_unplaced())
  embedded = self.run_sync(executor.create_value(raw_element, element_type))
  expected_struct_type = computation_types.StructType([element_type] * 3)
  source = self.run_sync(executor.create_struct([embedded] * 3))

  result = self.run_sync(executor.create_selection(source, 0))

  self.assertIsInstance(result, executor_value_base.ExecutorValue)
  self.assertEqual(result.type_signature.compact_representation(),
                   expected_struct_type[0].compact_representation())
  selected = self.run_sync(result.compute())
  whole_struct = self.run_sync(source.compute())
  self.assertEqual(selected, whole_struct[0])
def test_returns_value_with_elements_fn_and_arg(self, fn, fn_type, arg,
                                                arg_type):
  # Builds a struct out of three copies of the result of calling `fn` on
  # `arg`, then checks the struct's type and its computed contents.
  executor = create_test_executor()
  embedded_fn = self.run_sync(executor.create_value(fn, fn_type))
  embedded_arg = self.run_sync(executor.create_value(arg, arg_type))
  call_result = self.run_sync(executor.create_call(embedded_fn, embedded_arg))
  expected_type = computation_types.StructType([fn_type.result] * 3)

  result = self.run_sync(executor.create_struct([call_result] * 3))

  self.assertIsInstance(result, executor_value_base.ExecutorValue)
  self.assertEqual(result.type_signature.compact_representation(),
                   expected_type.compact_representation())
  actual_result = self.run_sync(result.compute())
  expected_result = [self.run_sync(call_result.compute())] * 3
  self.assertCountEqual(actual_result, expected_result)
def test_federated_max_on_nested_scalars(self):
  # federated_max over a CLIENTS-placed struct of two int32 scalars must
  # return the per-field maximum across all client values.
  member_type = computation_types.StructType([
      ('a', tf.int32),
      ('b', tf.int32),
  ])
  clients_type = computation_types.FederatedType(member_type,
                                                 placements.CLIENTS)

  @computations.federated_computation(clients_type)
  def call_federated_max(value):
    return federated_aggregations.federated_max(value)

  nested_scalars = collections.namedtuple('NestedScalars', ['a', 'b'])
  client_values = [
      nested_scalars(1, 5),
      nested_scalars(2, 3),
      nested_scalars(1, 8),
  ]
  result = call_federated_max(client_values)
  self.assertDictEqual(result._asdict(), {'a': 2, 'b': 8})
def test_generic_add_with_unplaced_named_tuple_and_tensor(self):
  # GENERIC_PLUS applied to <<a,b>, scalar> is expected to add the scalar to
  # every field of the named struct.
  bodies = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  parameter_type = computation_types.StructType([
      [('a', tf.float32), ('b', tf.float32)],
      tf.float32,
  ])

  @computations.federated_computation(parameter_type)
  def foo(x):
    return bodies[intrinsic_defs.GENERIC_PLUS.uri](x)

  self.assertEqual(
      str(foo.type_signature),
      '(<<a=float32,b=float32>,float32> -> <a=float32,b=float32>)')
  actual = foo([[1., 1.], 1.])
  expected = anonymous_tuple.AnonymousTuple([('a', 2.), ('b', 2.)])
  self.assertEqual(actual, expected)
def test_create_selection_does_not_cache_error(self):
  # The wrapped executor always fails on create_selection; the caching
  # executor must surface that error on every attempt rather than on the
  # first one only.
  event_loop = asyncio.get_event_loop()
  target_executor = mock.create_autospec(executor_base.Executor)
  target_executor.create_value.side_effect = create_test_value
  target_executor.create_selection.side_effect = raise_error
  cached_executor = caching_executor.CachingExecutor(target_executor)
  pair_type = computation_types.StructType((tf.int32, tf.int32))
  value = event_loop.run_until_complete(
      cached_executor.create_value((1, 2), pair_type))

  for _ in range(2):
    with self.assertRaises(TestError):
      event_loop.run_until_complete(
          cached_executor.create_selection(value, 1))

  # NOTE(review): an empty expected-call list is satisfied by any call
  # history, so this line does not actually pin how many times the mock's
  # create_selection was invoked -- consider asserting on call_count.
  target_executor.create_selection.assert_has_calls([])
def test_call_tf_comp_with_int_tuple(self):
  # Calls a two-argument TF addition with a named <a=int32,b=int32> argument
  # through the sequence executor and checks the int32 result.
  executor = self._sequence_executor
  comp = computations.tf_computation(lambda x, y: x + y, tf.int32, tf.int32)
  comp_pb = computation_impl.ComputationImpl.get_proto(comp)
  comp_val = _run_sync(executor.create_value(comp_pb, comp.type_signature))

  arg = collections.OrderedDict([('a', 10), ('b', 20)])
  arg_type = computation_types.StructType(
      collections.OrderedDict([('a', tf.int32), ('b', tf.int32)]))
  arg_val = _run_sync(executor.create_value(arg, arg_type))

  result_val = _run_sync(executor.create_call(comp_val, arg_val))

  self.assertIsInstance(result_val, sequence_executor.SequenceExecutorValue)
  self.assertEqual(str(result_val.type_signature), 'int32')
  self.assertEqual(_run_sync(result_val.compute()), 30)
async def compute_federated_zip_at_server(
    self,
    arg: FederatedComposingStrategyValue) -> FederatedComposingStrategyValue:
  """Zips a pair of SERVER-placed values into one SERVER-placed struct.

  Args:
    arg: A value whose type is a two-element `StructType` and whose internal
      representation is a two-element `structure.Struct`; both element types
      must be all-equal federated types placed at SERVER.

  Returns:
    A `FederatedComposingStrategyValue` wrapping the struct created on the
    server executor, typed as a SERVER-placed struct of the two member types.
  """
  arg_type = arg.type_signature
  py_typecheck.check_type(arg_type, computation_types.StructType)
  py_typecheck.check_len(arg_type, 2)
  arg_repr = arg.internal_representation
  py_typecheck.check_type(arg_repr, structure.Struct)
  py_typecheck.check_len(arg_repr, 2)
  for index in (0, 1):
    type_analysis.check_federated_type(
        arg_type[index], placement=placement_literals.SERVER, all_equal=True)

  zipped = await self._server_executor.create_struct(
      [arg_repr[0], arg_repr[1]])
  member_struct_type = computation_types.StructType(
      [arg_type[0].member, arg_type[1].member])
  return FederatedComposingStrategyValue(
      zipped, computation_types.at_server(member_struct_type))
def test_n_tuple_federated_zip_tensor_args(self, n):
  # Zipping an n-tuple of CLIENTS-placed int32 values should produce a single
  # CLIENTS-placed value whose member is an n-tuple of int32.
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  unzipped_type = computation_types.StructType([clients_int] * n)
  zipped_type = computation_types.FederatedType([tf.int32] * n,
                                                placements.CLIENTS)
  expected_fn_type = computation_types.FunctionType(unzipped_type,
                                                    zipped_type)

  @computations.federated_computation(
      [computation_types.FederatedType(tf.int32, placements.CLIENTS)] * n)
  def foo(x):
    zipped = intrinsics.federated_zip(x)
    self.assertIsInstance(zipped, value_base.Value)
    return zipped

  self.assert_type(foo, expected_fn_type.compact_representation())
async def compute_federated_sum(
    self,
    arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
  """Computes a federated sum by delegating to `compute_federated_reduce`.

  A constant `0` and a `tf.add` binary operator, both over the member type of
  `arg`, are embedded concurrently and passed along with `arg` as the
  three-element reduce argument.

  Args:
    arg: A value whose type signature is a `FederatedType`.

  Returns:
    Whatever `compute_federated_reduce` returns for the assembled argument.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  member_type = arg.type_signature.member
  zero, plus = await asyncio.gather(
      executor_utils.embed_tf_constant(self._executor, member_type, 0),
      executor_utils.embed_tf_binary_operator(self._executor, member_type,
                                              tf.add))
  reduce_arg = structure.Struct([
      (None, arg.internal_representation),
      (None, zero.internal_representation),
      (None, plus.internal_representation),
  ])
  reduce_arg_type = computation_types.StructType(
      (arg.type_signature, zero.type_signature, plus.type_signature))
  return await self.compute_federated_reduce(
      FederatedResolvingStrategyValue(reduce_arg, reduce_arg_type))
def test_with_selection_by_index(self):
  # Repeating a selection by index must hand back the very same value object,
  # and the selected value must still compute correctly.
  ex, _ = _make_executor_and_tracer_for_test()
  run = asyncio.get_event_loop().run_until_complete
  pair_type = computation_types.StructType([tf.int32, tf.int32])

  source = run(ex.create_value([10, 20], pair_type))
  self.assertEqual(str(source.identifier), '1')

  first = run(ex.create_selection(source, index=0))
  self.assertEqual(str(first.identifier), '1[0]')
  second = run(ex.create_selection(source, index=1))
  self.assertEqual(str(second.identifier), '1[1]')

  first_again = run(ex.create_selection(source, index=0))
  self.assertIs(first_again, first)
  second_again = run(ex.create_selection(source, index=1))
  self.assertIs(second_again, second)

  computed = run(second_again.compute())
  self.assertEqual(computed.numpy(), 20)
def test_returns_value_with_source_and_name_structure(self):
  # Selecting field 'a' from a struct with fields a, b, c must yield an
  # executor value whose type and computed result match that field.
  executor = create_test_executor()
  raw_element, element_type = (
      executor_test_utils.create_dummy_value_unplaced())
  names = ['a', 'b', 'c']
  embedded = self.run_sync(executor.create_value(raw_element, element_type))
  elements = structure.Struct([(name, embedded) for name in names])
  expected_struct_type = computation_types.StructType(
      [(name, element_type) for name in names])
  source = self.run_sync(executor.create_struct(elements))

  result = self.run_sync(executor.create_selection(source, name='a'))

  self.assertIsInstance(result, executor_value_base.ExecutorValue)
  self.assertEqual(result.type_signature.compact_representation(),
                   expected_struct_type['a'].compact_representation())
  selected = self.run_sync(result.compute())
  whole_struct = self.run_sync(source.compute())
  self.assertEqual(selected, whole_struct['a'])
def test_is_assignable_from(self):
  # Exercises StructType.is_assignable_from against variations of
  # <int32,a=bool>: same, renamed field, retyped field, shorter, unnamed.
  make_struct = computation_types.StructType
  reference = make_struct([tf.int32, ('a', tf.bool)])
  identical = make_struct([tf.int32, ('a', tf.bool)])
  renamed_field = make_struct([tf.int32, ('b', tf.bool)])
  retyped_field = make_struct([tf.int32, ('a', tf.string)])
  shorter = make_struct([tf.int32])
  unnamed = make_struct([tf.int32, tf.bool])

  self.assertTrue(reference.is_assignable_from(identical))
  self.assertFalse(reference.is_assignable_from(renamed_field))
  self.assertFalse(reference.is_assignable_from(retyped_field))
  self.assertFalse(reference.is_assignable_from(shorter))
  # Unnamed elements can be assigned to named ones, but not the reverse.
  self.assertTrue(reference.is_assignable_from(unnamed))
  self.assertFalse(unnamed.is_assignable_from(reference))
def test_to_representation_for_type_with_noarg_to_2xint32_comp(self):
  # Hand-builds an XLA computation returning the tuple (10, 20) and checks
  # that its representation is a callable producing <a=10,b=20>.
  builder = xla_client.XlaBuilder('comp')
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))
  constants = [
      xla_client.ops.Constant(builder, np.int32(number))
      for number in (10, 20)
  ]
  xla_client.ops.Tuple(builder, constants)
  xla_comp = builder.build()

  result_type = computation_types.StructType([('a', np.int32),
                                              ('b', np.int32)])
  comp_type = computation_types.FunctionType(None, result_type)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_comp, [0, 1], comp_type)

  rep = executor.to_representation_for_type(comp_pb, comp_type, self._backend)

  self.assertTrue(callable(rep))
  self.assertEqual(str(rep()), '<a=10,b=20>')
async def create_struct(self, elements):
  """Asks the remote service to build a struct from `RemoteValue` elements.

  Args:
    elements: A container convertible via `structure.from_container`, whose
      leaves must all be `RemoteValue` instances.

  Returns:
    A `RemoteValue` holding the reference returned by the service and a
    `StructType` assembled from the element names and type signatures.
  """
  struct = structure.from_container(elements)
  element_protos = []
  element_types = []
  for name, value in structure.iter_elements(struct):
    py_typecheck.check_type(value, RemoteValue)
    # Falsy names are sent as unset and produce unnamed type elements.
    element_protos.append(
        executor_pb2.CreateStructRequest.Element(
            name=(name if name else None), value_ref=value.value_ref))
    if name:
      element_types.append((name, value.type_signature))
    else:
      element_types.append(value.type_signature)
  result_type = computation_types.StructType(element_types)
  request = executor_pb2.CreateStructRequest(element=element_protos)

  if self._bidi_stream is not None:
    reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_struct=request))
    response = reply.create_struct
  else:
    response = _request(self._stub.CreateStruct, request)

  py_typecheck.check_type(response, executor_pb2.CreateStructResponse)
  return RemoteValue(response.value_ref, result_type, self)
def test_federated_max_nested_tensor_value(self):
  # federated_max over a CLIENTS-placed struct of int32 vectors, compared
  # against the element-wise maxima of the two clients' tensors.
  member_type = computation_types.StructType([
      ('a', (tf.int32, [2])),
      ('b', (tf.int32, [3])),
  ])
  clients_type = computation_types.FederatedType(member_type,
                                                 placements.CLIENTS)

  @computations.federated_computation(clients_type)
  def call_federated_max(value):
    return federated_aggregations.federated_max(value)

  def _int32_array(values):
    return np.array(values, dtype=np.int32)

  nested = collections.namedtuple('NestedScalars', ['a', 'b'])
  client1 = nested(_int32_array([4, 5]), _int32_array([1, -2, 3]))
  client2 = nested(_int32_array([9, 0]), _int32_array([5, 1, -2]))

  result = call_federated_max([client1, client2])

  self.assertCountEqual(result[0], [9, 5])
  self.assertCountEqual(result[1], [5, 1, 3])
def test_ordered_dict(self):
  # Creating an OrderedDict-valued struct through the sizing executor should
  # record one broadcast entry per tensor: total characters for the string
  # tensor, and 6 for the 2x3 int64 tensor.
  string_type = computation_types.TensorType(tf.string, [4])
  int_type = computation_types.TensorType(tf.int64, [2, 3])
  struct_type = computation_types.StructType([('a', string_type),
                                              ('b', int_type)])
  ex = sizing_executor.SizingExecutor(eager_tf_executor.EagerTFExecutor())
  od = collections.OrderedDict([
      ('a', ['some', 'arbitrary', 'string', 'here']),
      ('b', [[3, 4, 1], [6, 8, -5]]),
  ])
  total_string_length = sum(len(s) for s in od['a'])

  async def _make():
    created = await ex.create_value(od, struct_type)
    return await created.compute()

  asyncio.get_event_loop().run_until_complete(_make())

  expected_history = [[total_string_length, tf.string], [6, tf.int64]]
  self.assertCountEqual(ex.broadcast_history, expected_history)
class IsAverageCompatibleTest(parameterized.TestCase):
  """Parameterized checks of `type_analysis.is_average_compatible`."""

  # Floating-point tensors, and structs / federated types built from them,
  # are expected to be accepted.
  @parameterized.named_parameters(
      ('tensor_type_float32', computation_types.TensorType(tf.float32)),
      ('tensor_type_float64', computation_types.TensorType(tf.float64)),
      ('tuple_type',
       computation_types.StructType([('x', tf.float32), ('y', tf.float64)])),
      ('federated_type',
       computation_types.FederatedType(tf.float32, placements.CLIENTS)),
  )
  def test_returns_true(self, type_spec):
    is_compatible = type_analysis.is_average_compatible(type_spec)
    self.assertTrue(is_compatible)

  # Integer tensors and sequences are expected to be rejected.
  @parameterized.named_parameters(
      ('tensor_type_int32', computation_types.TensorType(tf.int32)),
      ('tensor_type_int64', computation_types.TensorType(tf.int64)),
      ('sequence_type', computation_types.SequenceType(tf.float32)),
  )
  def test_returns_false(self, type_spec):
    is_compatible = type_analysis.is_average_compatible(type_spec)
    self.assertFalse(is_compatible)
def test_serialize_jax_with_two_args(self):
  # Serializes a two-argument JAX addition whose parameter is the struct
  # <a=int32,b=int32>, then checks the resulting proto: its computation kind,
  # its TFF type, the HLO text and the parameter/result tensor bindings.
  ctx_stack = context_stack_impl.context_stack
  param_type = computation_types.StructType([('a', np.int32),
                                             ('b', np.int32)])
  # Presumably returns (positional args, keyword args) for `traced_func`,
  # routing struct elements 0 and 1 to `x` and `y` -- confirm against
  # `serialize_jax_computation`.
  arg_func = lambda arg: ([], {'x': arg[0], 'y': arg[1]})

  def traced_func(x, y):
    return x + y

  comp_pb = jax_serialization.serialize_jax_computation(
      traced_func, arg_func, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(<a=int32,b=int32> -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
  # The expected HLO text embeds the name `traced_func` and the exact op
  # numbering, so renaming or restructuring the traced function is likely to
  # change it.
  # NOTE(review): the indentation inside the expected strings below is a
  # single space per line; verify that this matches the real HLO / proto text
  # output and was not collapsed by whitespace mangling.
  self.assertEqual(
      xla_comp.as_hlo_text(),
      # pylint: disable=line-too-long
      'HloModule xla_computation_traced_func__3.7\n\n'
      'ENTRY xla_computation_traced_func__3.7 {\n'
      ' constant.4 = pred[] constant(false)\n'
      ' parameter.1 = (s32[], s32[]) parameter(0)\n'
      ' get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
      ' get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
      ' add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
      ' ROOT tuple.6 = (s32[]) tuple(add.5)\n'
      '}\n\n')
  # Both struct elements of the parameter map to their own tensor index.
  self.assertEqual(
      str(comp_pb.xla.parameter), 'struct {\n'
      ' element {\n'
      ' tensor {\n'
      ' index: 0\n'
      ' }\n'
      ' }\n'
      ' element {\n'
      ' tensor {\n'
      ' index: 1\n'
      ' }\n'
      ' }\n'
      '}\n')
  # The result is a single tensor at index 0.
  self.assertEqual(str(comp_pb.xla.result), 'tensor {\n' ' index: 0\n' '}\n')
async def create_struct(self, elements):
  """Builds (or reuses) a cached struct value from `CachedValue` elements.

  The cache key is the string `<...>` assembled from the element identifiers
  (`name=identifier` for named elements). On a cache miss the struct is
  created on the target executor from the already-resolved element values.

  Args:
    elements: A `structure.Struct`, or a container convertible to one, whose
      elements are `CachedValue` instances.

  Returns:
    The `CachedValue` for the struct.

  Raises:
    Exception: Whatever the target future raises; in that case the whole
      cache is reset before re-raising.
  """
  if not isinstance(elements, structure.Struct):
    elements = structure.from_container(elements)
  named_values = structure.to_elements(elements)

  identifier_parts = []
  pending_futures = []
  type_elements = []
  for name, cached in named_values:
    py_typecheck.check_type(cached, CachedValue)
    pending_futures.append(cached.target_future)
    if name is None:
      identifier_parts.append(str(cached.identifier))
      type_elements.append(cached.type_signature)
    else:
      py_typecheck.check_type(name, str)
      identifier_parts.append('{}={}'.format(name, cached.identifier))
      type_elements.append((name, cached.type_signature))
  type_spec = computation_types.StructType(type_elements)

  # Resolve every element on the target executor before consulting the cache.
  target_values = await asyncio.gather(*pending_futures)
  identifier = CachedValueIdentifier('<{}>'.format(
      ','.join(identifier_parts)))
  try:
    cached_value = self._cache[identifier]
  except KeyError:
    names = [name for name, _ in named_values]
    target_struct = structure.Struct(list(zip(names, target_values)))
    target_future = asyncio.ensure_future(
        self._target_executor.create_struct(target_struct))
    cached_value = CachedValue(identifier, None, type_spec, target_future)
    self._cache[identifier] = cached_value

  try:
    target_value = await cached_value.target_future
  except Exception:
    # TODO(b/145514490): This is a bit heavy handed, there may be caches where
    # only the current cache item needs to be invalidated; however this
    # currently only occurs when an inner RemoteExecutor has the backend go
    # down.
    self._cache = {}
    raise
  type_spec.check_assignable_from(target_value.type_signature)
  return cached_value
def test_serialize_type_with_tensor_tuple(self):
  # A struct mixing named and unnamed scalar tensors must serialize to a
  # struct proto with one element per field, in order.
  type_signature = computation_types.StructType([
      ('x', tf.int32),
      ('y', tf.string),
      tf.float32,
      ('z', tf.bool),
  ])

  actual_proto = type_serialization.serialize_type(type_signature)

  element = pb.StructType.Element
  expected_elements = [
      element(name='x', value=_create_scalar_tensor_type(tf.int32)),
      element(name='y', value=_create_scalar_tensor_type(tf.string)),
      element(value=_create_scalar_tensor_type(tf.float32)),
      element(name='z', value=_create_scalar_tensor_type(tf.bool)),
  ]
  expected_proto = pb.Type(struct=pb.StructType(element=expected_elements))
  self.assertEqual(actual_proto, expected_proto)
def _create_xla_binary_op_computation(type_spec, xla_binary_op_constructor):
  """Builds an XLA computation applying a binary operator to two operands.

  The computation has type `(<T,T> -> T)`, where `T` is `type_spec`. Its XLA
  parameter is one flat tuple holding the tensors of the first operand
  followed by those of the second; the operator is applied to each matching
  pair of tensors.

  Args:
    type_spec: The type of a single operand; must be a tensor or a structure
      of tensors.
    xla_binary_op_constructor: A two-argument callable that constructs a
      binary XLA op from tensor parameters (such as `xla_client.ops.Add`).

  Returns:
    A `(computation proto, computation type)` pair, i.e. an instance of
    `local_computation_factory_base.ComputationProtoAndType`.

  Raises:
    ValueError: If `type_spec` is not a tensor or a structure of tensors.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not type_analysis.is_structure_of_tensors(type_spec):
    raise ValueError('Not a tensor or a structure of tensors: {}'.format(
        str(type_spec)))

  operand_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
      type_spec)
  num_tensors = len(operand_shapes)
  builder = xla_client.XlaBuilder('comp')
  param = xla_client.ops.Parameter(
      builder, 0, xla_client.Shape.tuple_shape(operand_shapes * 2))
  # Tensor `idx` of the left operand pairs with tensor `idx + num_tensors`,
  # its counterpart in the right operand.
  result_tensors = [
      xla_binary_op_constructor(
          xla_client.ops.GetTupleElement(param, idx),
          xla_client.ops.GetTupleElement(param, idx + num_tensors))
      for idx in range(num_tensors)
  ]
  xla_client.ops.Tuple(builder, result_tensors)
  xla_computation = builder.build()

  parameter_type = computation_types.StructType([(None, type_spec)] * 2)
  comp_type = computation_types.FunctionType(parameter_type, type_spec)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_computation, list(range(2 * num_tensors)), comp_type)
  return (comp_pb, comp_type)
def test_with_two_level_tuple(self):
  # Converts a nested OrderedDict-backed struct type (including an empty
  # struct) into the matching nest of tf.TensorSpec objects.
  inner_type = computation_types.StructWithPythonType([
      ('c', computation_types.TensorType(tf.float32)),
      ('d', computation_types.TensorType(tf.int32, [20])),
  ], collections.OrderedDict)
  type_signature = computation_types.StructWithPythonType([
      ('a', tf.bool),
      ('b', inner_type),
      ('e', computation_types.StructType([])),
  ], collections.OrderedDict)

  tensor_specs = type_conversions.type_to_tf_tensor_specs(type_signature)

  expected_specs = {
      'a': tf.TensorSpec([], tf.bool),
      'b': {
          'c': tf.TensorSpec([], tf.float32),
          'd': tf.TensorSpec([20], tf.int32),
      },
      'e': (),
  }
  test.assert_nested_struct_eq(tensor_specs, expected_specs)
async def create_struct(self, elements):
  """Assembles `EagerValue` elements into a single struct-valued `EagerValue`.

  Args:
    elements: As documented in `executor_base.Executor`.

  Returns:
    An instance of `EagerValue` whose representation is a `structure.Struct`
    of the elements' internal representations and whose type is the
    `StructType` of their type signatures (names kept where present).
  """
  named_values = structure.to_elements(structure.from_container(elements))
  representations = []
  element_types = []
  for name, value in named_values:
    py_typecheck.check_type(value, EagerValue)
    representations.append((name, value.internal_representation))
    # Unnamed elements contribute a bare type instead of a (name, type) pair.
    if name is None:
      element_types.append(value.type_signature)
    else:
      element_types.append((name, value.type_signature))
  return EagerValue(
      structure.Struct(representations), self._tf_function_cache,
      computation_types.StructType(element_types))
def test_serialize_type_with_nested_tuple(self):
  # A three-level nested struct x -> y -> z=bool must serialize to equally
  # nested struct protos.
  type_signature = computation_types.StructType([
      ('x', [('y', [('z', tf.bool)])]),
  ])

  actual_proto = type_serialization.serialize_type(type_signature)

  def _struct_proto(elements):
    return pb.Type(struct=pb.StructType(element=elements))

  z_element = pb.StructType.Element(
      name='z', value=_create_scalar_tensor_type(tf.bool))
  y_element = pb.StructType.Element(
      name='y', value=_struct_proto([z_element]))
  x_element = pb.StructType.Element(
      name='x', value=_struct_proto([y_element]))
  expected_proto = _struct_proto([x_element])
  self.assertEqual(actual_proto, expected_proto)
def test_str(self):
  # Each case pairs the argument given to StructType with the compact string
  # the resulting type is expected to print as.
  make_struct = computation_types.StructType
  cases = [
      ([tf.int32], '<int32>'),
      ([('a', tf.int32)], '<a=int32>'),
      (('a', tf.int32), '<a=int32>'),
      ([tf.int32, tf.bool], '<int32,bool>'),
      ([('a', tf.int32), tf.float32], '<a=int32,float32>'),
      ([('a', tf.int32), ('b', tf.float32)], '<a=int32,b=float32>'),
      ([('a', tf.int32),
        ('b', make_struct([('x', tf.string), ('y', tf.bool)]))],
       '<a=int32,b=<x=string,y=bool>>'),
  ]
  for elements, expected in cases:
    self.assertEqual(str(make_struct(elements)), expected)
def test_returns_computation(self, operator, type_signature, operands,
                             expected_result):
  # TODO(b/142795960): arguments in parameterized are called before test main.
  # `tf.constant` will error out on GPU and TPU without proper initialization.
  # A suggested workaround is to use numpy as argument and transform to TF
  # tensor inside the function.
  tensor_operands = tf.nest.map_structure(tf.constant, operands)

  factory_fn = tensorflow_computation_factory.create_binary_operator_with_upcast
  proto, _ = factory_fn(type_signature, operator)

  self.assertIsInstance(proto, pb.Computation)
  actual_type = type_serialization.deserialize_type(proto.type)
  self.assertIsInstance(actual_type, computation_types.FunctionType)
  # Note: Only the parameter type is worth checking here; the result type is
  # determined by `operator`, not by the implementation of
  # `create_binary_operator_with_upcast`.
  self.assertEqual(actual_type.parameter,
                   computation_types.StructType(type_signature))
  actual_result = test_utils.run_tensorflow(proto, tensor_operands)
  self.assertEqual(actual_result, expected_result)
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type_with_names(
    self):
  # Wraps a compiled identity, called on a tuple of two selections from a
  # named three-field argument, in a lambda; parsing it to TF must yield a
  # single compiled computation with the same type and the same behavior.
  identity_tf_block_type = computation_types.StructType([tf.int32, tf.bool])
  identity_tf_block = building_block_factory.create_compiled_identity(
      identity_tf_block_type)
  tuple_ref = building_blocks.Reference('x', [('a', tf.int32),
                                              ('b', tf.float32),
                                              ('c', tf.bool)])
  selected_int = building_blocks.Selection(tuple_ref, index=0)
  selected_bool = building_blocks.Selection(tuple_ref, index=2)
  created_tuple = building_blocks.Struct([selected_int, selected_bool])
  called_tf_block = building_blocks.Call(identity_tf_block, created_tuple)
  lambda_wrapper = building_blocks.Lambda('x', [('a', tf.int32),
                                                ('b', tf.float32),
                                                ('c', tf.bool)],
                                          called_tf_block)

  parsed, modified = parse_tff_to_tf(lambda_wrapper)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)

  # Both executables are built once, right before they are compared; the
  # earlier duplicate conversions were overwritten before ever being used.
  exec_lambda = computation_wrapper_instances.building_block_to_computation(
      lambda_wrapper)
  exec_tf = computation_wrapper_instances.building_block_to_computation(
      parsed)
  self.assertEqual(
      exec_lambda({
          'a': 9,
          'b': 10.,
          'c': False
      }), exec_tf({
          'a': 9,
          'b': 10.,
          'c': False
      }))
def test_fails_with_bad_types(self):
  # tensorflow_wrapper must reject parameter types that cannot be TF
  # computation parameters: federated, functional, placement and abstract
  # types, and structs containing them. Each case checks that the error
  # message names the offending type.
  function = computation_types.FunctionType(
      None, computation_types.TensorType(tf.int32))
  federated = computation_types.FederatedType(tf.int32,
                                              placement_literals.CLIENTS)
  tuple_on_function = computation_types.StructType([federated, function])

  def foo(x):  # pylint: disable=unused-variable
    del x  # Unused.

  with self.assertRaisesRegex(
      TypeError,
      r'you have attempted to create one with the type {int32}@CLIENTS'):
    computation_wrapper_instances.tensorflow_wrapper(foo, federated)

  with self.assertRaisesRegex(
      TypeError,
      r'you have attempted to create one with the type \( -> int32\)'):
    computation_wrapper_instances.tensorflow_wrapper(foo, function)

  with self.assertRaisesRegex(
      TypeError, r'you have attempted to create one with the type placement'):
    computation_wrapper_instances.tensorflow_wrapper(
        foo, computation_types.PlacementType())

  with self.assertRaisesRegex(
      TypeError, r'you have attempted to create one with the type T'):
    computation_wrapper_instances.tensorflow_wrapper(
        foo, computation_types.AbstractType('T'))

  # Both halves of the pattern are raw strings: a non-raw '\)' is an invalid
  # escape sequence that newer Python versions warn about. The resulting
  # pattern text is unchanged.
  with self.assertRaisesRegex(
      TypeError,
      r'you have attempted to create one with the type <{int32}@CLIENTS,\( '
      r'-> int32\)>'):
    computation_wrapper_instances.tensorflow_wrapper(foo, tuple_on_function)
def test_create_selection_does_not_cache_error_avoids_double_cache_delete(
    self):
  # Runs two identical failing selections concurrently; both must resolve to
  # TestError without the error handling itself blowing up.
  event_loop = asyncio.get_event_loop()
  target_executor = mock.create_autospec(executor_base.Executor)
  target_executor.create_value.side_effect = create_test_value
  target_executor.create_selection.side_effect = raise_error
  cached_executor = caching_executor.CachingExecutor(target_executor)
  pair_type = computation_types.StructType((tf.int32, tf.int32))
  value = event_loop.run_until_complete(
      cached_executor.create_value((1, 2), pair_type))

  first_attempt = cached_executor.create_selection(value, 1)
  second_attempt = cached_executor.create_selection(value, 1)
  results = event_loop.run_until_complete(
      asyncio.gather(first_attempt, second_attempt, return_exceptions=True))

  # NOTE(review): an empty expected-call list is satisfied by any call
  # history, so this line does not pin how often the mock's create_selection
  # was invoked -- consider asserting on call_count instead.
  target_executor.create_selection.assert_has_calls([])
  self.assertLen(results, 2)
  for outcome in results:
    self.assertIsInstance(outcome, TestError)