def test_scope_snapshot_called_lambdas(self): comp = computation_building_blocks.Tuple( [computation_building_blocks.Data('test', tf.int32)]) input1 = computation_building_blocks.Reference('input1', comp.type_signature) first_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input1', input1.type_signature, input1), comp) input2 = computation_building_blocks.Reference( 'input2', first_level_call.type_signature) second_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input2', input2.type_signature, input2), first_level_call) self.assertEqual(str(second_level_call), '(input2 -> input2)((input1 -> input1)(<test>))') global_snapshot = transformations.scope_count_snapshot( second_level_call) self.assertEqual( global_snapshot, { '(input2 -> input2)': { 'input2': 1 }, '(input1 -> input1)': { 'input1': 1 } })
def test_execute_with_nested_lambda(self):
  """Tests that a curried TF add, applied in two stages, executes correctly."""
  # TF computation implementing binary int32 addition.
  int32_add = computation_building_blocks.ComputationBuildingBlock.from_proto(
      computation_impl.ComputationImpl.get_proto(
          computations.tf_computation(tf.add, [tf.int32, tf.int32])))
  # Curried form: x -> (y -> add(<x, y>)).
  curried_int32_add = computation_building_blocks.Lambda(
      'x', tf.int32,
      computation_building_blocks.Lambda(
          'y', tf.int32,
          computation_building_blocks.Call(
              int32_add,
              computation_building_blocks.Tuple(
                  [(None, computation_building_blocks.Reference('x', tf.int32)),
                   (None, computation_building_blocks.Reference('y', tf.int32))
                  ]))))
  make_10 = computation_building_blocks.ComputationBuildingBlock.from_proto(
      computation_impl.ComputationImpl.get_proto(
          computations.tf_computation(lambda: tf.constant(10))))
  # Partially apply the curried add to 10, yielding a one-argument add-10 fn.
  add_10 = computation_building_blocks.Call(curried_int32_add,
                                            computation_building_blocks.Call(
                                                make_10))
  add_10_computation = computation_impl.ComputationImpl(
      add_10.proto, context_stack_impl.context_stack)
  self.assertEqual(add_10_computation(5), 15)
def test_nested_lambda_block_overwrite_scope_snapshot(self):
  """Verifies scope counts when blocks and lambdas rebind the same name."""
  # Innermost identity lambda (x -> x).
  innermost_x = computation_building_blocks.Reference('x', tf.int32)
  inner_lambda = computation_building_blocks.Lambda('x', tf.int32, innermost_x)
  # Apply it to a reference to a *different* binding of 'x'.
  second_x = computation_building_blocks.Reference('x', tf.int32)
  called_lambda = computation_building_blocks.Call(inner_lambda, second_x)
  # Block rebinding 'x' to the lambda's input, shadowing the outer 'x'.
  block_input = computation_building_blocks.Reference('block_in', tf.int32)
  lower_block = computation_building_blocks.Block([('x', block_input)],
                                                  called_lambda)
  second_lambda = computation_building_blocks.Lambda('block_in', tf.int32,
                                                     lower_block)
  third_x = computation_building_blocks.Reference('x', tf.int32)
  second_call = computation_building_blocks.Call(second_lambda, third_x)
  final_input = computation_building_blocks.Data('test_data', tf.int32)
  last_block = computation_building_blocks.Block([('x', final_input)],
                                                 second_call)
  global_snapshot = transformations.scope_count_snapshot(last_block)
  self.assertEqual(
      str(last_block),
      '(let x=test_data in (block_in -> (let x=block_in in (x -> x)(x)))(x))')
  # Four scopes: two lambdas and two blocks, each binding exactly one name.
  self.assertLen(global_snapshot, 4)
  self.assertEqual(global_snapshot[str(inner_lambda)], {'x': 1})
  self.assertEqual(global_snapshot[str(lower_block)], {'x': 1})
  self.assertEqual(global_snapshot[str(second_lambda)], {'block_in': 1})
  self.assertEqual(global_snapshot[str(last_block)], {'x': 1})
def create_dummy_called_federated_aggregate(accumulate_parameter_name,
                                            merge_parameter_name,
                                            report_parameter_name):
  r"""Returns a dummy called federated aggregate.

                       Call
                      /    \
  federated_aggregate      Tuple
                             |
      [data, data, Lambda(x), Lambda(x), Lambda(x)]
                       |          |          |
                      data       data       data

  Args:
    accumulate_parameter_name: The name of the accumulate parameter.
    merge_parameter_name: The name of the merge parameter.
    report_parameter_name: The name of the report parameter.

  Returns:
    A called federated aggregate, as constructed by
    `computation_constructing_utils.create_federated_aggregate`, over dummy
    `Data` arguments with the lambda parameter names given above.
  """
  value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  value = computation_building_blocks.Data('data', value_type)
  zero = computation_building_blocks.Data('data', tf.float32)
  # accumulate: <float32 accumulator, int32 element> -> float32.
  accumulate_type = computation_types.NamedTupleType((tf.float32, tf.int32))
  accumulate_result = computation_building_blocks.Data('data', tf.float32)
  accumulate = computation_building_blocks.Lambda(accumulate_parameter_name,
                                                  accumulate_type,
                                                  accumulate_result)
  # merge: <float32, float32> -> float32.
  merge_type = computation_types.NamedTupleType((tf.float32, tf.float32))
  merge_result = computation_building_blocks.Data('data', tf.float32)
  merge = computation_building_blocks.Lambda(merge_parameter_name, merge_type,
                                             merge_result)
  # report: float32 -> bool.
  report_result = computation_building_blocks.Data('data', tf.bool)
  report = computation_building_blocks.Lambda(report_parameter_name, tf.float32,
                                              report_result)
  return computation_constructing_utils.create_federated_aggregate(
      value, zero, accumulate, merge, report)
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  zero = value_impl.to_value(zero, None, self._context_stack)
  op = value_impl.to_value(op, None, self._context_stack)
  # `value` is either a plain sequence or a federated value whose member is a
  # sequence; extract the element type in either case.
  if isinstance(value.type_signature, computation_types.SequenceType):
    element_type = value.type_signature.element
  else:
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(value.type_signature.member,
                            computation_types.SequenceType)
    element_type = value.type_signature.member.element
  # The reduction operator must accept <accumulator, element> and produce a
  # new accumulator assignable from `zero`'s type.
  op_type_expected = type_constructors.reduction_op(zero.type_signature,
                                                    element_type)
  if not type_utils.is_assignable_from(op_type_expected, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        op_type_expected, op.type_signature))
  value = value_impl.ValueImpl.get_comp(value)
  zero = value_impl.ValueImpl.get_comp(zero)
  op = value_impl.ValueImpl.get_comp(op)
  if isinstance(value.type_signature, computation_types.SequenceType):
    # Plain sequence: reduce it directly.
    return computation_constructing_utils.create_sequence_reduce(
        value, zero, op)
  else:
    # Federated sequence: build a lambda that reduces one local sequence,
    # then apply/map it point-wise at the value's placement.
    value_type = computation_types.SequenceType(element_type)
    intrinsic_type = computation_types.FunctionType((
        value_type,
        zero.type_signature,
        op.type_signature,
    ), op.type_signature.result)
    intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.SEQUENCE_REDUCE.uri, intrinsic_type)
    ref = computation_building_blocks.Reference('arg', value_type)
    tup = computation_building_blocks.Tuple((ref, zero, op))
    call = computation_building_blocks.Call(intrinsic, tup)
    fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, call)
    fn_impl = value_impl.ValueImpl(fn, self._context_stack)
    # NOTE(review): `is` relies on placement literals being singletons —
    # presumably true for the `placements` module; confirm.
    if value.type_signature.placement is placements.SERVER:
      return self.federated_apply(fn_impl, value)
    elif value.type_signature.placement is placements.CLIENTS:
      return self.federated_map(fn_impl, value)
    else:
      raise TypeError('Unsupported placement {}.'.format(
          value.type_signature.placement))
def test_propogates_dependence_up_through_lambda(self):
  """A lambda whose body consumes an intrinsic is itself marked dependent."""
  # (Test name typo 'propogates' kept: renaming would change the test id.)
  intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    tf.int32)
  fn = computation_building_blocks.Lambda('x', tf.int32, intrinsic)
  consuming_nodes = tree_analysis.extract_nodes_consuming(
      fn, dummy_intrinsic_predicate)
  self.assertIn(fn, consuming_nodes)
def test_replace_chained_federated_maps_replaces_multiple_federated_maps(
    self):
  """Ten chained federated_maps collapse into a single federated_map."""
  federated_int_type = computation_types.FederatedType(tf.int32,
                                                       placements.CLIENTS)
  outer_ref = computation_building_blocks.Reference('arg', federated_int_type)
  # Build a chain of ten add-one maps, each consuming the previous call.
  current = outer_ref
  current_member_type = outer_ref.type_signature.member
  for _ in range(10):
    add_one = _create_lambda_to_add_one(current_member_type)
    current = _create_call_to_federated_map(add_one, current)
    current_member_type = current.function.type_signature.result.member
  comp = computation_building_blocks.Lambda(outer_ref.name,
                                            outer_ref.type_signature, current)
  uri = intrinsic_defs.FEDERATED_MAP.uri
  self.assertEqual(_get_number_of_intrinsics(comp, uri), 10)
  self.assertEqual(_to_comp(comp)([(1)]), [11])
  transformed = transformations.replace_chained_federated_maps_with_federated_map(
      comp)
  # One fused map remains, with behavior unchanged.
  self.assertEqual(_get_number_of_intrinsics(transformed, uri), 1)
  self.assertEqual(_to_comp(transformed)([(1)]), [11])
def test_replace_chained_federated_maps_with_different_arg_types(self):
  """Chained maps with differing element types still fuse into one map."""
  map_arg_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  map_arg = computation_building_blocks.Reference('arg_1', map_arg_type)
  # First map casts int32 -> float32; second adds one to the float.
  inner_lambda = _create_lambda_to_cast(tf.int32, tf.float32)
  inner_call = _create_call_to_federated_map(inner_lambda, map_arg)
  outer_lambda = _create_lambda_to_add_one(inner_call.type_signature.member)
  outer_call = _create_call_to_federated_map(outer_lambda, inner_call)
  map_lambda = computation_building_blocks.Lambda(map_arg.name,
                                                  map_arg.type_signature,
                                                  outer_call)
  comp = map_lambda
  self.assertEqual(
      _get_number_of_intrinsics(comp, intrinsic_defs.FEDERATED_MAP.uri), 2)
  comp_impl = _to_comp(comp)
  self.assertEqual(comp_impl([(1)]), [2.0])
  transformed_comp = transformations.replace_chained_federated_maps_with_federated_map(
      comp)
  # The two maps collapse into one; behavior is preserved.
  self.assertEqual(
      _get_number_of_intrinsics(transformed_comp,
                                intrinsic_defs.FEDERATED_MAP.uri), 1)
  transformed_comp_impl = _to_comp(transformed_comp)
  self.assertEqual(transformed_comp_impl([(1)]), [2.0])
def test_replace_called_lambda_replaces_called_lambda(self):
  """Replacing a called lambda removes one Call and one Lambda, adds a Block."""
  arg = computation_building_blocks.Reference('arg', tf.int32)
  lam = _create_lambda_to_add_one(arg.type_signature)
  call = computation_building_blocks.Call(lam, arg)
  calling_lambda = computation_building_blocks.Lambda(arg.name,
                                                      arg.type_signature, call)
  comp = calling_lambda
  # No blocks exist before the transformation.
  self.assertEqual(
      _get_number_of_computations(comp, computation_building_blocks.Block), 0)
  comp_impl = _to_comp(comp)
  self.assertEqual(comp_impl(1), 2)
  transformed_comp = transformations.replace_called_lambda_with_block(comp)
  # Exactly one Call and one Lambda are consumed by the rewrite...
  self.assertEqual(
      _get_number_of_computations(transformed_comp,
                                  computation_building_blocks.Call),
      _get_number_of_computations(comp, computation_building_blocks.Call) - 1)
  self.assertEqual(
      _get_number_of_computations(transformed_comp,
                                  computation_building_blocks.Lambda),
      _get_number_of_computations(comp, computation_building_blocks.Lambda) -
      1)
  # ...and replaced by a single Block, preserving behavior.
  self.assertEqual(
      _get_number_of_computations(transformed_comp,
                                  computation_building_blocks.Block), 1)
  transformed_comp_impl = _to_comp(transformed_comp)
  self.assertEqual(transformed_comp_impl(1), 2)
def test_replace_intrinsic_replaces_multiple_intrinsics(self):
  """All ten GENERIC_PLUS intrinsics are replaced by the supplied body."""
  outer_ref = computation_building_blocks.Reference('arg', tf.int32)
  # Chain ten add-one calls, each wrapping the previous result.
  current = outer_ref
  current_type = outer_ref.type_signature
  for _ in range(10):
    add_one = _create_lambda_to_add_one(current_type)
    current = computation_building_blocks.Call(add_one, current)
    current_type = current.function.type_signature.result
  comp = computation_building_blocks.Lambda(outer_ref.name,
                                            outer_ref.type_signature, current)
  uri = intrinsic_defs.GENERIC_PLUS.uri
  body = lambda x: 100
  self.assertEqual(_get_number_of_intrinsics(comp, uri), 10)
  self.assertEqual(_to_comp(comp)(1), 11)
  transformed = transformations.replace_intrinsic_with_callable(
      comp, uri, body, context_stack_impl.context_stack)
  # Every intrinsic is gone, and the constant body now drives the result.
  self.assertEqual(_get_number_of_intrinsics(transformed, uri), 0)
  self.assertEqual(_to_comp(transformed)(1), 100)
def _create_lambda_to_add_one(dtype):
  r"""Builds `(arg -> generic_plus(<arg, 1>))` for the given tensor dtype.

  Lambda
        \
         Call
        /    \
  Intrinsic   Tuple
             /     \
    Reference       Computation

  Args:
    dtype: The type of the argument; a `tf.dtypes.DType` or a
      `computation_types.TensorType` wrapping one.

  Returns:
    An instance of `computation_building_blocks.Lambda` wrapping a function
    that adds 1 to an argument.
  """
  if isinstance(dtype, computation_types.TensorType):
    dtype = dtype.dtype
  py_typecheck.check_type(dtype, tf.dtypes.DType)
  plus_type = computation_types.FunctionType([dtype, dtype], dtype)
  plus_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_PLUS.uri, plus_type)
  parameter = computation_building_blocks.Reference('arg', dtype)
  # The constant 1, cast to the argument dtype inside a TF computation.
  one = _create_call_to_py_fn(lambda: tf.cast(tf.constant(1), dtype))
  plus_args = computation_building_blocks.Tuple([parameter, one])
  summed = computation_building_blocks.Call(plus_intrinsic, plus_args)
  return computation_building_blocks.Lambda(parameter.name,
                                            parameter.type_signature, summed)
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  zero = value_impl.to_value(zero, None, self._context_stack)
  op = value_impl.to_value(op, None, self._context_stack)
  # `value` is either a plain sequence or a federated value whose member is a
  # sequence; extract the element type in either case.
  if isinstance(value.type_signature, computation_types.SequenceType):
    element_type = value.type_signature.element
  else:
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(value.type_signature.member,
                            computation_types.SequenceType)
    element_type = value.type_signature.member.element
  op_type_expected = type_constructors.reduction_op(zero.type_signature,
                                                    element_type)
  if not type_utils.is_assignable_from(op_type_expected, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        str(op_type_expected), str(op.type_signature)))
  # Intrinsic of type <T*, U, (<U,T> -> U)> -> U.
  sequence_reduce_building_block = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_REDUCE.uri,
      computation_types.FunctionType([
          computation_types.SequenceType(element_type), zero.type_signature,
          op.type_signature
      ], zero.type_signature))
  if isinstance(value.type_signature, computation_types.SequenceType):
    # Plain sequence: invoke the intrinsic directly.
    sequence_reduce_intrinsic = value_impl.ValueImpl(
        sequence_reduce_building_block, self._context_stack)
    return sequence_reduce_intrinsic(value, zero, op)
  else:
    # Federated sequence: wrap the reduction in a lambda over one local
    # sequence, then apply/map it at the value's placement.
    federated_mapping_fn_building_block = computation_building_blocks.Lambda(
        'arg', computation_types.SequenceType(element_type),
        computation_building_blocks.Call(
            sequence_reduce_building_block,
            computation_building_blocks.Tuple([
                computation_building_blocks.Reference(
                    'arg', computation_types.SequenceType(element_type)),
                value_impl.ValueImpl.get_comp(zero),
                value_impl.ValueImpl.get_comp(op)
            ])))
    federated_mapping_fn = value_impl.ValueImpl(
        federated_mapping_fn_building_block, self._context_stack)
    # NOTE(review): `is` relies on placement literals being singletons —
    # presumably true for the `placements` module; confirm.
    if value.type_signature.placement is placements.SERVER:
      return self.federated_apply(federated_mapping_fn, value)
    elif value.type_signature.placement is placements.CLIENTS:
      return self.federated_map(federated_mapping_fn, value)
    else:
      raise TypeError('Unsupported placement {}.'.format(
          str(value.type_signature.placement)))
def test_with_block(self):
  """Evaluates a lambda whose body is a block computing f(f(x))."""
  ex = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  loop = asyncio.get_event_loop()
  f_type = computation_types.FunctionType(tf.int32, tf.int32)
  # Argument is a tuple <f: (int32 -> int32), x: int32>.
  a = computation_building_blocks.Reference(
      'a', computation_types.NamedTupleType([('f', f_type), ('x', tf.int32)]))
  # Block binds f and x out of the tuple, then applies f twice to x.
  ret = computation_building_blocks.Block(
      [('f', computation_building_blocks.Selection(a, name='f')),
       ('x', computation_building_blocks.Selection(a, name='x'))],
      computation_building_blocks.Call(
          computation_building_blocks.Reference('f', f_type),
          computation_building_blocks.Call(
              computation_building_blocks.Reference('f', f_type),
              computation_building_blocks.Reference('x', tf.int32))))
  comp = computation_building_blocks.Lambda(a.name, a.type_signature, ret)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  v1 = loop.run_until_complete(
      ex.create_value(comp.proto, comp.type_signature))
  v2 = loop.run_until_complete(ex.create_value(add_one))
  v3 = loop.run_until_complete(ex.create_value(10, tf.int32))
  v4 = loop.run_until_complete(
      ex.create_tuple(
          anonymous_tuple.AnonymousTuple([('f', v2), ('x', v3)])))
  v5 = loop.run_until_complete(ex.create_call(v1, v4))
  result = loop.run_until_complete(v5.compute())
  # add_one applied twice to 10 yields 12.
  self.assertEqual(result.numpy(), 12)
def test_basic_functionality_of_lambda_class(self):
  """Exercises type signature, accessors, repr, and proto round-trip."""
  arg_name = 'arg'
  arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]
  arg = computation_building_blocks.Reference(arg_name, arg_type)
  arg_f = computation_building_blocks.Selection(arg, name='f')
  arg_x = computation_building_blocks.Selection(arg, name='x')
  # Lambda computing arg.f(arg.f(arg.x)).
  x = computation_building_blocks.Lambda(
      arg_name, arg_type,
      computation_building_blocks.Call(
          arg_f, computation_building_blocks.Call(arg_f, arg_x)))
  self.assertEqual(
      str(x.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(x.parameter_name, arg_name)
  self.assertEqual(str(x.parameter_type), '<f=(int32 -> int32),x=int32>')
  self.assertEqual(x.result.tff_repr, 'arg.f(arg.f(arg.x))')
  arg_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  self.assertEqual(
      repr(x), 'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
          arg_type_repr))
  self.assertEqual(x.tff_repr, '(arg -> arg.f(arg.f(arg.x)))')
  # Serialization: 'lambda' is a Python keyword, hence getattr for the oneof.
  x_proto = x.proto
  self.assertEqual(
      type_serialization.deserialize_type(x_proto.type), x.type_signature)
  self.assertEqual(x_proto.WhichOneof('computation'), 'lambda')
  self.assertEqual(getattr(x_proto, 'lambda').parameter_name, arg_name)
  self.assertEqual(
      str(getattr(x_proto, 'lambda').result), str(x.result.proto))
  self._serialize_deserialize_roundtrip_test(x)
def _create_lambda_to_dummy_intrinsic(type_spec, uri='dummy'):
  r"""Builds `(arg -> intrinsic(arg))` for a dummy intrinsic.

  Lambda
        \
         Call
        /    \
  Intrinsic   Ref(arg)

  Args:
    type_spec: The type of the argument.
    uri: The URI of the intrinsic.

  Returns:
    A `computation_building_blocks.Lambda`.

  Raises:
    TypeError: If `type_spec` is not a `tf.dtypes.DType`.
  """
  py_typecheck.check_type(type_spec, tf.dtypes.DType)
  fn_type = computation_types.FunctionType(type_spec, type_spec)
  dummy = computation_building_blocks.Intrinsic(uri, fn_type)
  parameter = computation_building_blocks.Reference('arg', type_spec)
  applied = computation_building_blocks.Call(dummy, parameter)
  return computation_building_blocks.Lambda(parameter.name,
                                            parameter.type_signature, applied)
def get_curried(func):
  """Returns a curried version of function `func` that takes a parameter tuple.

  For functions `func` of types <T1,T2,....,Tn> -> U, the result is a function
  of the form T1 -> (T2 -> (T3 -> .... (Tn -> U) ... )).

  NOTE: No attempt is made at avoiding naming conflicts in cases where `func`
  contains references. The arguments of the curried function are named `argN`
  with `N` starting at 0.

  Args:
    func: A value of a functional TFF type.

  Returns:
    A value that represents the curried form of `func`.
  """
  py_typecheck.check_type(func, value_base.Value)
  py_typecheck.check_type(func.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(func.type_signature.parameter,
                          computation_types.NamedTupleType)
  param_elements = anonymous_tuple.to_elements(func.type_signature.parameter)
  # One fresh reference per tuple element; these become the curried arguments.
  references = []
  for idx, (_, elem_type) in enumerate(param_elements):
    references.append(
        computation_building_blocks.Reference('arg{}'.format(idx), elem_type))
  # Call the original function on the full tuple of references...
  result = computation_building_blocks.Call(
      value_impl.ValueImpl.get_comp(func),
      computation_building_blocks.Tuple(references))
  # ...then wrap in lambdas from the innermost argument outwards.
  for ref in references[::-1]:
    result = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                                result)
  return value_impl.ValueImpl(result,
                              value_impl.ValueImpl.get_context_stack(func))
def construct_federated_getattr_comp(comp, name):
  """Builds the selection lambda used by `federated_apply` for `__getattr__`.

  The returned lambda accepts a value of type `comp.type_signature.member`
  (an instance of `computation_types.NamedTupleType`) and selects the element
  called `name` from it.

  Args:
    comp: Instance of `ValueImpl` or
      `computation_building_blocks.ComputationBuildingBlock` with type
      signature `computation_types.FederatedType` whose `member` attribute is
      of type `computation_types.NamedTupleType`.
    name: String name of attribute to grab.

  Returns:
    Instance of `computation_building_blocks.Lambda` which grabs attribute
    according to `name` of its argument.
  """
  py_typecheck.check_type(comp.type_signature, computation_types.FederatedType)
  member_type = comp.type_signature.member
  py_typecheck.check_type(member_type, computation_types.NamedTupleType)
  available_names = [
      elem_name for elem_name, _ in anonymous_tuple.to_elements(member_type)
  ]
  if name not in available_names:
    raise ValueError('The federated value {} has no element of name {}'.format(
        comp, name))
  parameter = computation_building_blocks.Reference('x', member_type)
  selection = computation_building_blocks.Selection(parameter, name=name)
  return computation_building_blocks.Lambda('x', parameter.type_signature,
                                            selection)
def test_returns_string_for_comp_with_right_overhang(self):
  """Checks compact, formatted, and structural renderings of a called lambda."""
  ref = computation_building_blocks.Reference('a', tf.int32)
  data = computation_building_blocks.Data('data', tf.int32)
  tup = computation_building_blocks.Tuple([ref, data, data, data, data])
  sel = computation_building_blocks.Selection(tup, index=0)
  fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, sel)
  comp = computation_building_blocks.Call(fn, data)
  compact_string = comp.compact_representation()
  self.assertEqual(compact_string, '(a -> <a,data,data,data,data>[0])(data)')
  formatted_string = comp.formatted_representation()
  # pyformat: disable
  self.assertEqual(
      formatted_string,
      '(a -> <\n'
      ' a,\n'
      ' data,\n'
      ' data,\n'
      ' data,\n'
      ' data\n'
      '>[0])(data)')
  # pyformat: enable
  structural_string = comp.structural_representation()
  # pyformat: disable
  self.assertEqual(
      structural_string,
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Sel(0)\n'
      '|\n'
      'Tuple\n'
      '|\n'
      '[Ref(a), data, data, data, data]')
  # pyformat: enable
def test_replace_chained_federated_maps_does_not_replace_unchained_federated_maps(
    self):
  """Maps separated by an intervening tuple/selection are left untouched."""
  map_arg_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  map_arg = computation_building_blocks.Reference('arg', map_arg_type)
  inner_lambda = _create_lambda_to_add_one(map_arg.type_signature.member)
  inner_call = _create_call_to_federated_map(inner_lambda, map_arg)
  # Break the chain: wrap the first map in a tuple and select it back out.
  dummy_tuple = computation_building_blocks.Tuple([inner_call])
  dummy_selection = computation_building_blocks.Selection(
      dummy_tuple, index=0)
  outer_lambda = _create_lambda_to_add_one(
      inner_call.function.type_signature.result.member)
  outer_call = _create_call_to_federated_map(outer_lambda, dummy_selection)
  map_lambda = computation_building_blocks.Lambda(map_arg.name,
                                                  map_arg.type_signature,
                                                  outer_call)
  comp = map_lambda
  uri = intrinsic_defs.FEDERATED_MAP.uri
  self.assertEqual(_get_number_of_intrinsics(comp, uri), 2)
  comp_impl = _to_comp(comp)
  self.assertEqual(comp_impl([(1)]), [3])
  transformed_comp = transformations.replace_chained_federated_maps_with_federated_map(
      comp)
  # Both maps survive; behavior is unchanged.
  self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 2)
  transformed_comp_impl = _to_comp(transformed_comp)
  self.assertEqual(transformed_comp_impl([(1)]), [3])
def test_flatten_fn_with_names(self, n):
  """flatten_first_index flattens both unnamed and named appended elements."""
  # Identity function over a tuple of n named int32 elements.
  input_reference = computation_building_blocks.Reference(
      'test', [(str(k), tf.int32) for k in range(n)])
  input_fn = computation_building_blocks.Lambda(
      'test', input_reference.type_signature, input_reference)
  # Case 1: appended element is unnamed.
  unnamed_type_to_add = (None, computation_types.to_type(tf.int32))
  unnamed_input_type = computation_types.NamedTupleType(
      [input_reference.type_signature, unnamed_type_to_add])
  unnamed_desired_output_type = computation_types.to_type(
      [(str(k), tf.int32) for k in range(n)] + [tf.int32])
  unnamed_desired_fn_type = computation_types.FunctionType(
      unnamed_input_type, unnamed_desired_output_type)
  unnamed_new_fn = value_utils.flatten_first_index(
      value_impl.to_value(input_fn, None, _context_stack),
      unnamed_type_to_add, _context_stack)
  self.assertEqual(
      str(unnamed_new_fn.type_signature), str(unnamed_desired_fn_type))
  # Case 2: appended element carries the name 'new'.
  named_type_to_add = ('new', tf.int32)
  named_input_type = computation_types.NamedTupleType(
      [input_reference.type_signature, named_type_to_add])
  named_types = [(str(k), tf.int32) for k in range(n)] + [('new', tf.int32)]
  named_desired_output_type = computation_types.to_type(named_types)
  named_desired_fn_type = computation_types.FunctionType(
      named_input_type, named_desired_output_type)
  new_named_fn = value_utils.flatten_first_index(
      value_impl.to_value(input_fn, None, _context_stack), named_type_to_add,
      _context_stack)
  self.assertEqual(
      str(new_named_fn.type_signature), str(named_desired_fn_type))
def test_raises_type_error_with_nonfederated_arg(self):
  """create_federated_map rejects an argument of non-federated type."""
  param = computation_building_blocks.Reference('x', tf.int32)
  identity = computation_building_blocks.Lambda(param.name,
                                                param.type_signature, param)
  plain_arg = computation_building_blocks.Data('y', tf.int32)
  with self.assertRaises(TypeError):
    computation_constructing_utils.create_federated_map(identity, plain_arg)
def test_raises_type_error_with_none_accumulate(self):
  """create_federated_aggregate rejects a `None` accumulate function."""
  federated_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS, False)
  value = computation_building_blocks.Data('v', federated_type)
  zero = computation_building_blocks.Data('z', tf.int32)
  merge_param_type = computation_types.NamedTupleType((tf.int32, tf.int32))
  merge_body = computation_building_blocks.Data('m', tf.int32)
  merge_fn = computation_building_blocks.Lambda('x', merge_param_type,
                                                merge_body)
  report_param = computation_building_blocks.Reference('r', tf.int32)
  report_fn = computation_building_blocks.Lambda(
      report_param.name, report_param.type_signature, report_param)
  with self.assertRaises(TypeError):
    computation_constructing_utils.create_federated_aggregate(
        value, zero, None, merge_fn, report_fn)
def test_flatten_fn_comp_raises_typeerror(self):
  """flatten_first_index rejects a raw building block (expects a Value)."""
  param = computation_building_blocks.Reference('test', [tf.int32] * 5)
  identity_fn = computation_building_blocks.Lambda('test',
                                                   param.type_signature, param)
  extra_type = computation_types.NamedTupleType([tf.int32])
  with self.assertRaisesRegexp(TypeError, '(Expected).*(Value)'):
    _ = value_utils.flatten_first_index(identity_fn, extra_type,
                                        _context_stack)
def test_raises_type_error_with_none_value(self):
  """create_sequence_reduce rejects a `None` value."""
  zero = computation_building_blocks.Data('z', tf.int32)
  op_param_type = computation_types.NamedTupleType((tf.int32, tf.int32))
  op_body = computation_building_blocks.Data('o', tf.int32)
  reduce_op = computation_building_blocks.Lambda('x', op_param_type, op_body)
  with self.assertRaises(TypeError):
    computation_constructing_utils.create_sequence_reduce(
        None, zero, reduce_op)
def test_identity_lambda_executes_as_identity(self):
  """The wrapped identity lambda returns its input for a range of ints."""
  identity = computation_building_blocks.Lambda(
      'x', tf.int32, computation_building_blocks.Reference('x', tf.int32))
  wrapped = computation_wrapper_instances.building_block_to_computation(
      identity)
  for value in range(10):
    self.assertEqual(wrapped(value), value)
def test_converts_building_block_to_computation(self):
  """building_block_to_computation yields a ComputationImpl instance."""
  identity = computation_building_blocks.Lambda(
      'x', tf.int32, computation_building_blocks.Reference('x', tf.int32))
  wrapped = computation_wrapper_instances.building_block_to_computation(
      identity)
  self.assertIsInstance(wrapped, computation_impl.ComputationImpl)
def construct_binary_operator_with_upcast(type_signature, operator):
  """Constructs lambda upcasting its argument and applying `operator`.

  The concept of upcasting is explained further in the docstring for
  `apply_binary_operator_with_upcast`.

  Notice that since we are constructing a function here, e.g. for the body of
  an intrinsic, the function we are constructing must be reducible to
  TensorFlow. Therefore `type_signature` can only have named tuple or tensor
  type elements; that is, we cannot handle federated types here in a generic
  way.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`,
      with two elements, both of the same type or the second able to be upcast
      to the first, as explained in `apply_binary_operator_with_upcast`, and
      both containing only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `computation_building_blocks.Lambda` encapsulating a function which
    upcasts the second element of its argument and applies the binary
    operator.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  ref_to_arg = computation_building_blocks.Reference('binary_operator_arg',
                                                     type_signature)

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    # NOTE(review): falls through (returns None) for a type that is neither a
    # named tuple nor a tensor; presumably unreachable once
    # `_check_generic_operator_type` has validated the input — confirm.
    if isinstance(type_spec, computation_types.NamedTupleType):
      elems = anonymous_tuple.to_elements(type_spec)
      packed_elems = [(elem_name, _pack_into_type(to_pack, elem_type))
                      for elem_name, elem_type in elems]
      return computation_building_blocks.Tuple(packed_elems)
    elif isinstance(type_spec, computation_types.TensorType):
      # Broadcast the scalar out to the target tensor shape.
      expand_fn = computation_constructing_utils.construct_tensorflow_to_broadcast_scalar(
          to_pack.type_signature.dtype, type_spec.shape)
      return computation_building_blocks.Call(expand_fn, to_pack)

  y_ref = computation_building_blocks.Selection(ref_to_arg, index=1)
  first_arg = computation_building_blocks.Selection(ref_to_arg, index=0)
  # Upcast the second element only when the two element types differ.
  if type_utils.are_equivalent_types(first_arg.type_signature,
                                     y_ref.type_signature):
    second_arg = y_ref
  else:
    second_arg = _pack_into_type(y_ref, first_arg.type_signature)
  fn = computation_constructing_utils.construct_tensorflow_binary_operator(
      first_arg.type_signature, operator)
  packed = computation_building_blocks.Tuple([first_arg, second_arg])
  operated = computation_building_blocks.Call(fn, packed)
  lambda_encapsulating_op = computation_building_blocks.Lambda(
      ref_to_arg.name, ref_to_arg.type_signature, operated)
  return lambda_encapsulating_op
def test_returns_sequence_map(self):
  """create_sequence_map over an identity fn yields the expected repr/type."""
  param = computation_building_blocks.Reference('x', tf.int32)
  identity = computation_building_blocks.Lambda(param.name,
                                                param.type_signature, param)
  sequence = computation_building_blocks.Data(
      'y', computation_types.SequenceType(tf.int32))
  result = computation_constructing_utils.create_sequence_map(
      identity, sequence)
  self.assertEqual(result.tff_repr, 'sequence_map(<(x -> x),y>)')
  self.assertEqual(str(result.type_signature), 'int32*')
def test_raises_type_error_with_none_value(self):
  """create_federated_aggregate rejects a `None` value."""
  zero = computation_building_blocks.Data('z', tf.int32)
  acc_param_type = computation_types.NamedTupleType((tf.int32, tf.int32))
  acc_body = computation_building_blocks.Data('a', tf.int32)
  accumulate_fn = computation_building_blocks.Lambda('x', acc_param_type,
                                                     acc_body)
  merge_param_type = computation_types.NamedTupleType((tf.int32, tf.int32))
  merge_body = computation_building_blocks.Data('m', tf.int32)
  merge_fn = computation_building_blocks.Lambda('x', merge_param_type,
                                                merge_body)
  report_param = computation_building_blocks.Reference('r', tf.int32)
  report_fn = computation_building_blocks.Lambda(
      report_param.name, report_param.type_signature, report_param)
  with self.assertRaises(TypeError):
    computation_constructing_utils.create_federated_aggregate(
        None, zero, accumulate_fn, merge_fn, report_fn)
def zero_or_one_arg_func_to_building_block(func,
                                           parameter_name,
                                           parameter_type,
                                           context_stack,
                                           suggested_name=None):
  """Converts a zero- or one-argument `func` into a computation building block.

  Args:
    func: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `values_base.Value` and returns a result
      convertible to the same.
    parameter_name: The name of the parameter, or `None` if there isn't any.
    parameter_type: The TFF type of the parameter, or `None` if there's none.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated
      context that will be used to serialize this function's body (ideally the
      name of the underlying Python function). It might be modified to avoid
      conflicts. If not `None`, it must be a string.

  Returns:
    An instance of `computation_building_blocks.ComputationBuildingBlock` that
    contains the logic from `func`.

  Raises:
    ValueError: if `func` is incompatible with `parameter_type`.
  """
  py_typecheck.check_callable(func)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if suggested_name is not None:
    py_typecheck.check_type(suggested_name, six.string_types)
  parameter_type = computation_types.to_type(parameter_type)
  # Nest inside the current federated context, when there is one, so the new
  # context's name can be made unique relative to its parent.
  if isinstance(context_stack.current,
                federated_computation_context.FederatedComputationContext):
    parent_context = context_stack.current
  else:
    parent_context = None
  context = federated_computation_context.FederatedComputationContext(
      context_stack, suggested_name=suggested_name, parent=parent_context)
  if parameter_name is not None:
    py_typecheck.check_type(parameter_name, six.string_types)
    # Prefix with the context name to avoid capture across nested contexts.
    parameter_name = '{}_{}'.format(context.name, str(parameter_name))
  with context_stack.install(context):
    if parameter_type is not None:
      # Trace `func` over a symbolic reference to the parameter.
      result = func(
          value_impl.ValueImpl(
              computation_building_blocks.Reference(parameter_name,
                                                    parameter_type),
              context_stack))
    else:
      result = func()
    result = value_impl.to_value(result, None, context_stack)
    result_comp = value_impl.ValueImpl.get_comp(result)
    if parameter_type is None:
      return result_comp
    else:
      return computation_building_blocks.Lambda(parameter_name, parameter_type,
                                                result_comp)