def test_basic_functionality_of_lambda_class(self):
  """Checks type, reprs and proto serialization of a `Lambda`.

  The lambda under test takes a named-tuple parameter `arg` of type
  `<f=(int32 -> int32),x=int32>` and returns `arg.f(arg.f(arg.x))`.
  """
  arg_name = 'arg'
  arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]
  arg = computation_building_blocks.Reference(arg_name, arg_type)
  arg_f = computation_building_blocks.Selection(arg, name='f')
  arg_x = computation_building_blocks.Selection(arg, name='x')
  # The body applies `arg.f` twice to `arg.x`.
  x = computation_building_blocks.Lambda(
      arg_name, arg_type,
      computation_building_blocks.Call(
          arg_f, computation_building_blocks.Call(arg_f, arg_x)))
  self.assertEqual(
      str(x.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(x.parameter_name, arg_name)
  self.assertEqual(str(x.parameter_type), '<f=(int32 -> int32),x=int32>')
  self.assertEqual(x.result.tff_repr, 'arg.f(arg.f(arg.x))')
  # `repr` of the parameter type; it is substituted for every `{0}` below.
  arg_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  # The adjacent literals are concatenated before `.format` is applied.
  self.assertEqual(
      repr(x), 'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
          arg_type_repr))
  self.assertEqual(x.tff_repr, '(arg -> arg.f(arg.f(arg.x)))')
  # The proto must carry the same type, the `lambda` oneof, the parameter
  # name, and the serialized body.
  x_proto = x.proto
  self.assertEqual(
      type_serialization.deserialize_type(x_proto.type), x.type_signature)
  self.assertEqual(x_proto.WhichOneof('computation'), 'lambda')
  # `lambda` is a Python keyword, so the proto field is read via `getattr`.
  self.assertEqual(getattr(x_proto, 'lambda').parameter_name, arg_name)
  self.assertEqual(
      str(getattr(x_proto, 'lambda').result), str(x.result.proto))
  self._serialize_deserialize_roundtrip_test(x)
def test_nested_lambda_block_overwrite_scope_snapshot(self):
  """Scope snapshot counts one binding per scope when `x` is rebound."""
  # Innermost scope: `(x -> x)` applied to a free reference to `x`.
  identity_fn = computation_building_blocks.Lambda(
      'x', tf.int32, computation_building_blocks.Reference('x', tf.int32))
  applied_identity = computation_building_blocks.Call(
      identity_fn, computation_building_blocks.Reference('x', tf.int32))
  # `let x=block_in in (x -> x)(x)`
  inner_block = computation_building_blocks.Block(
      [('x', computation_building_blocks.Reference('block_in', tf.int32))],
      applied_identity)
  # `(block_in -> ...)` applied to the outer `x`.
  outer_fn = computation_building_blocks.Lambda(
      'block_in', tf.int32, inner_block)
  applied_outer = computation_building_blocks.Call(
      outer_fn, computation_building_blocks.Reference('x', tf.int32))
  outermost_block = computation_building_blocks.Block(
      [('x', computation_building_blocks.Data('test_data', tf.int32))],
      applied_outer)

  snapshot = transformations.scope_count_snapshot(outermost_block)

  self.assertEqual(
      str(outermost_block),
      '(let x=test_data in (block_in -> (let x=block_in in (x -> x)(x)))(x))'
  )
  self.assertLen(snapshot, 4)
  expected_counts = (
      (identity_fn, {'x': 1}),
      (inner_block, {'x': 1}),
      (outer_fn, {'block_in': 1}),
      (outermost_block, {'x': 1}),
  )
  for scope, counts in expected_counts:
    self.assertEqual(snapshot[str(scope)], counts)
def test_with_block(self):
  """Runs `(a -> let f=a.f, x=a.x in f(f(x)))` on `<f=add_one, x=10>`."""
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  run = asyncio.get_event_loop().run_until_complete
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  param = computation_building_blocks.Reference(
      'a', computation_types.NamedTupleType([('f', fn_type), ('x', tf.int32)]))
  bindings = [
      ('f', computation_building_blocks.Selection(param, name='f')),
      ('x', computation_building_blocks.Selection(param, name='x')),
  ]
  # f(f(x))
  inner_call = computation_building_blocks.Call(
      computation_building_blocks.Reference('f', fn_type),
      computation_building_blocks.Reference('x', tf.int32))
  outer_call = computation_building_blocks.Call(
      computation_building_blocks.Reference('f', fn_type), inner_call)
  body = computation_building_blocks.Block(bindings, outer_call)
  comp = computation_building_blocks.Lambda(
      param.name, param.type_signature, body)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  fn_value = run(executor.create_value(comp.proto, comp.type_signature))
  add_one_value = run(executor.create_value(add_one))
  ten_value = run(executor.create_value(10, tf.int32))
  arg_value = run(
      executor.create_tuple(
          anonymous_tuple.AnonymousTuple([('f', add_one_value),
                                          ('x', ten_value)])))
  call_value = run(executor.create_call(fn_value, arg_value))
  result = run(call_value.compute())
  self.assertEqual(result.numpy(), 12)
def test_scope_snapshot_called_lambdas(self): comp = computation_building_blocks.Tuple( [computation_building_blocks.Data('test', tf.int32)]) input1 = computation_building_blocks.Reference('input1', comp.type_signature) first_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input1', input1.type_signature, input1), comp) input2 = computation_building_blocks.Reference( 'input2', first_level_call.type_signature) second_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input2', input2.type_signature, input2), first_level_call) self.assertEqual(str(second_level_call), '(input2 -> input2)((input1 -> input1)(<test>))') global_snapshot = transformations.scope_count_snapshot( second_level_call) self.assertEqual( global_snapshot, { '(input2 -> input2)': { 'input2': 1 }, '(input1 -> input1)': { 'input1': 1 } })
def test_returns_string_for_comp_with_left_overhang(self):
  """Checks the string representations of `a(comp#bbbbb())`.

  The outer call's argument is itself a call of a compiled computation. The
  test pins the compact, formatted and structural representations.
  """
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  fn = computation_building_blocks.Reference('a', fn_type)
  # Serialize a TF function returning the constant 1 (parameter type `None`)
  # to obtain a proto for the compiled computation.
  proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(1), None, context_stack_impl.context_stack)
  compiled = computation_building_blocks.CompiledComputation(proto, 'bbbbb')
  arg = computation_building_blocks.Call(compiled)
  comp = computation_building_blocks.Call(fn, arg)
  compact_string = computation_building_blocks.compact_representation(comp)
  self.assertEqual(compact_string, 'a(comp#bbbbb())')
  formatted_string = computation_building_blocks.formatted_representation(
      comp)
  self.assertEqual(formatted_string, 'a(comp#bbbbb())')
  structural_string = computation_building_blocks.structural_representation(
      comp)
  # pyformat: disable
  self.assertEqual(
      structural_string,
      ' Call\n'
      ' / \\\n'
      ' Ref(a) Call\n'
      ' /\n'
      'Compiled(bbbbb)')
  # pyformat: enable
def test_execute_with_nested_lambda(self):
  """Executes the curried adder `(x -> (y -> add(<x, y>)))` applied to 10."""

  def _building_block(tf_comp):
    # Round-trips a `tf_computation` through its proto into a building block.
    return computation_building_blocks.ComputationBuildingBlock.from_proto(
        computation_impl.ComputationImpl.get_proto(tf_comp))

  add_int32 = _building_block(
      computations.tf_computation(tf.add, [tf.int32, tf.int32]))
  x_ref = computation_building_blocks.Reference('x', tf.int32)
  y_ref = computation_building_blocks.Reference('y', tf.int32)
  add_pair = computation_building_blocks.Call(
      add_int32,
      computation_building_blocks.Tuple([(None, x_ref), (None, y_ref)]))
  curried_add = computation_building_blocks.Lambda(
      'x', tf.int32,
      computation_building_blocks.Lambda('y', tf.int32, add_pair))
  make_ten = _building_block(
      computations.tf_computation(lambda: tf.constant(10)))
  add_ten = computation_building_blocks.Call(
      curried_add, computation_building_blocks.Call(make_ten))
  add_ten_computation = computation_impl.ComputationImpl(
      add_ten.proto, context_stack_impl.context_stack)
  self.assertEqual(add_ten_computation(5), 15)
def test_basic_functionality_of_call_class(self):
  """Checks type, accessors, reprs, type errors and proto of a `Call`."""
  fn = computation_building_blocks.Reference(
      'foo', computation_types.FunctionType(tf.int32, tf.bool))
  int_arg = computation_building_blocks.Reference('bar', tf.int32)
  call = computation_building_blocks.Call(fn, int_arg)
  self.assertEqual(str(call.type_signature), 'bool')
  self.assertIs(call.function, fn)
  self.assertIs(call.argument, int_arg)
  self.assertEqual(
      repr(call), 'Call(Reference(\'foo\', '
      'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
      'Reference(\'bar\', TensorType(tf.int32)))')
  self.assertEqual(call.tff_repr, 'foo(bar)')
  # A function that expects an argument cannot be called without one...
  with self.assertRaises(TypeError):
    computation_building_blocks.Call(fn)
  # ...nor with an argument of the wrong type.
  float_arg = computation_building_blocks.Reference('bak', tf.float32)
  with self.assertRaises(TypeError):
    computation_building_blocks.Call(fn, float_arg)
  call_proto = call.proto
  self.assertEqual(
      type_serialization.deserialize_type(call_proto.type),
      call.type_signature)
  self.assertEqual(call_proto.WhichOneof('computation'), 'call')
  self.assertEqual(str(call_proto.call.function), str(fn.proto))
  self.assertEqual(str(call_proto.call.argument), str(int_arg.proto))
  self._serialize_deserialize_roundtrip_test(call)
def test_execute_with_block(self):
  """Executes a block binding `x` to 10 and then rebinding it to `x + 1` x3."""

  def _building_block(tf_comp):
    # Round-trips a `tf_computation` through its proto into a building block.
    return computation_building_blocks.ComputationBuildingBlock.from_proto(
        computation_impl.ComputationImpl.get_proto(tf_comp))

  add_one = _building_block(
      computations.tf_computation(lambda x: x + 1, tf.int32))
  make_10 = _building_block(
      computations.tf_computation(lambda: tf.constant(10)))
  bindings = [('x', computation_building_blocks.Call(make_10))]
  for _ in range(3):
    bindings.append(
        ('x',
         computation_building_blocks.Call(
             add_one, computation_building_blocks.Reference('x', tf.int32))))
  make_13 = computation_building_blocks.Block(
      bindings, computation_building_blocks.Reference('x', tf.int32))
  make_13_computation = computation_impl.ComputationImpl(
      make_13.proto, context_stack_impl.context_stack)
  self.assertEqual(make_13_computation(), 13)
def transform(self, comp):
  """Replaces a selection from a called graph with a call of a narrowed graph.

  `comp.source` is used as a call. Its function is narrowed with
  `select_graph_output` to the selected output: by `comp.index` when that is
  set, otherwise by `comp.name`. The narrowed function is then called on the
  original `comp.source.argument`.

  Returns:
    A new `computation_building_blocks.Call`.
  """
  if comp.index is None:
    narrowed_fn = select_graph_output(comp.source.function, name=comp.name)
  else:
    narrowed_fn = select_graph_output(comp.source.function, index=comp.index)
  return computation_building_blocks.Call(narrowed_fn, comp.source.argument)
def construct_federated_getitem_call(arg, idx):
  """Constructs a building block passing `__getitem__` to a federated value.

  This is the main piece of orchestration that plugs a `__getitem__` call
  together with a federated value.

  Args:
    arg: Instance of `computation_building_blocks.ComputationBuildingBlock` of
      `computation_types.FederatedType` with member of type
      `computation_types.NamedTupleType` from which we wish to pick out item
      `idx`.
    idx: Index, instance of `int` or `slice` used to address the
      `computation_types.NamedTupleType` underlying `arg`.

  Returns:
    A `computation_building_blocks.Call` with type signature
    `computation_types.FederatedType` of the same placement as `arg`. It is the
    result of applying or mapping the appropriate `__getitem__` function, as
    defined by `idx`.

  Raises:
    TypeError: If `arg`, `idx`, or the types of `arg` are not as described
      above.
  """
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(idx, (int, slice))
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  py_typecheck.check_type(arg.type_signature.member,
                          computation_types.NamedTupleType)
  getitem_comp = construct_federated_getitem_comp(arg, idx)
  # NOTE(review): the result of `construct_map_or_apply` is used below as the
  # function of a further `Call` on `<getitem_comp, arg>`. This assumes it is
  # an uncalled intrinsic. If that helper already returns the called
  # intrinsic, the extra call is wrong; confirm against its definition.
  intrinsic = construct_map_or_apply(getitem_comp, arg)
  call = computation_building_blocks.Call(
      intrinsic, computation_building_blocks.Tuple([getitem_comp, arg]))
  return call
def construct_federated_getattr_call(arg, name):
  """Builds a federated call that applies `__getattr__(name)` to `arg`.

  Args:
    arg: A `computation_building_blocks.ComputationBuildingBlock` whose type is
      a `computation_types.FederatedType` with a
      `computation_types.NamedTupleType` member; the attribute `name` is picked
      out of that member.
    name: String naming the element of the member tuple to select.

  Returns:
    A `computation_building_blocks.Call` of federated type, with the same
    placement as `arg`, that applies or maps the `__getattr__` function
    selected by `name`.

  Raises:
    TypeError: If `arg`, `name`, or the types of `arg` are not as described.
  """
  py_typecheck.check_type(
      arg, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(name, six.string_types)
  arg_type = arg.type_signature
  py_typecheck.check_type(arg_type, computation_types.FederatedType)
  py_typecheck.check_type(arg_type.member, computation_types.NamedTupleType)
  getattr_fn = construct_federated_getattr_comp(arg, name)
  # NOTE(review): the map/apply result is called again on `<getattr_fn, arg>`.
  # That is only valid if `construct_map_or_apply` returns an uncalled
  # intrinsic; confirm.
  map_or_apply = construct_map_or_apply(getattr_fn, arg)
  return computation_building_blocks.Call(
      map_or_apply, computation_building_blocks.Tuple([getattr_fn, arg]))
def _create_lambda_to_add_one(dtype):
  r"""Builds the lambda `(arg -> GENERIC_PLUS(<arg, 1>))` over `dtype`.

  Lambda
        \
         Call
        /    \
  Intrinsic   Tuple
             /     \
     Reference      Computation

  Args:
    dtype: A `tf.dtypes.DType`, or a `computation_types.TensorType` whose
      `dtype` is used instead.

  Returns:
    A `computation_building_blocks.Lambda` whose body calls the
    `GENERIC_PLUS` intrinsic on the parameter and on a call built by
    `_create_call_to_py_fn` from a function returning `1` cast to the dtype.

  Raises:
    TypeError: If the resolved dtype is not a `tf.dtypes.DType`.
  """
  element_dtype = (
      dtype.dtype if isinstance(dtype, computation_types.TensorType) else dtype)
  py_typecheck.check_type(element_dtype, tf.dtypes.DType)
  param = computation_building_blocks.Reference('arg', element_dtype)
  plus = computation_building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_PLUS.uri,
      computation_types.FunctionType([element_dtype, element_dtype],
                                     element_dtype))
  one = _create_call_to_py_fn(lambda: tf.cast(tf.constant(1), element_dtype))
  body = computation_building_blocks.Call(
      plus, computation_building_blocks.Tuple([param, one]))
  return computation_building_blocks.Lambda(param.name, param.type_signature,
                                            body)
def get_curried(func):
  """Returns a curried version of function `func` that takes a parameter tuple.

  For functions `func` of types <T1,T2,....,Tn> -> U, the result is a function
  of the form T1 -> (T2 -> (T3 -> .... (Tn -> U) ... )).

  NOTE: No attempt is made at avoiding naming conflicts in cases where `func`
  contains references. The arguments of the curried function are named `argN`
  with `N` starting at 0.

  Args:
    func: A value of a functional TFF type whose parameter is a named tuple.

  Returns:
    A value that represents the curried form of `func`. If the parameter tuple
    is empty, no lambdas are added and the result is `func` called on `<>`.

  Raises:
    TypeError: Via `py_typecheck.check_type`, if `func` is not a
      `value_base.Value`, or its type is not a function taking a named tuple.
  """
  py_typecheck.check_type(func, value_base.Value)
  py_typecheck.check_type(func.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(func.type_signature.parameter,
                          computation_types.NamedTupleType)
  param_elements = anonymous_tuple.to_elements(func.type_signature.parameter)
  # One reference per tuple element, named positionally; element names are
  # intentionally ignored.
  references = [
      computation_building_blocks.Reference('arg{}'.format(idx), elem_type)
      for idx, (_, elem_type) in enumerate(param_elements)
  ]
  result = computation_building_blocks.Call(
      value_impl.ValueImpl.get_comp(func),
      computation_building_blocks.Tuple(references))
  # Wrap from the last parameter outwards so that `arg0` is the outermost
  # lambda.
  for ref in reversed(references):
    result = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                                result)
  return value_impl.ValueImpl(result,
                              value_impl.ValueImpl.get_context_stack(func))
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  Args:
    value: As in `api/intrinsics.py`. Either a sequence, or a federated value
      whose member is a sequence.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`. Every branch returns a value (a `ValueImpl`, or
    the result of `federated_apply` / `federated_map`), never a bare building
    block.

  Raises:
    TypeError: As in `api/intrinsics.py`. Also raised for a federated `value`
      placed neither at SERVER nor at CLIENTS.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  zero = value_impl.to_value(zero, None, self._context_stack)
  op = value_impl.to_value(op, None, self._context_stack)
  if isinstance(value.type_signature, computation_types.SequenceType):
    element_type = value.type_signature.element
  else:
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(value.type_signature.member,
                            computation_types.SequenceType)
    element_type = value.type_signature.member.element
  op_type_expected = type_constructors.reduction_op(zero.type_signature,
                                                    element_type)
  if not type_utils.is_assignable_from(op_type_expected, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        op_type_expected, op.type_signature))

  value = value_impl.ValueImpl.get_comp(value)
  zero = value_impl.ValueImpl.get_comp(zero)
  op = value_impl.ValueImpl.get_comp(op)
  if isinstance(value.type_signature, computation_types.SequenceType):
    comp = computation_constructing_utils.create_sequence_reduce(
        value, zero, op)
    # Wrap the building block so this branch returns a value, like the
    # federated branches below.
    return value_impl.ValueImpl(comp, self._context_stack)

  # Federated case: build `(arg -> SEQUENCE_REDUCE(<arg, zero, op>))` over a
  # single sequence, then apply or map it according to the placement.
  value_type = computation_types.SequenceType(element_type)
  intrinsic_type = computation_types.FunctionType((
      value_type,
      zero.type_signature,
      op.type_signature,
  ), op.type_signature.result)
  intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_REDUCE.uri, intrinsic_type)
  ref = computation_building_blocks.Reference('arg', value_type)
  tup = computation_building_blocks.Tuple((ref, zero, op))
  call = computation_building_blocks.Call(intrinsic, tup)
  fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, call)
  fn_impl = value_impl.ValueImpl(fn, self._context_stack)
  if value.type_signature.placement is placements.SERVER:
    return self.federated_apply(fn_impl, value)
  elif value.type_signature.placement is placements.CLIENTS:
    return self.federated_map(fn_impl, value)
  else:
    raise TypeError('Unsupported placement {}.'.format(
        value.type_signature.placement))
def test_replace_called_lambda_replaces_called_lambda(self):
  """Rewriting `(arg -> add_one(arg))` turns the called lambda into a block.

  The rewrite removes one `Call` and one `Lambda`, introduces exactly one
  `Block`, and keeps the computed value (`1 + 1 == 2`).
  """
  count = _get_number_of_computations
  param = computation_building_blocks.Reference('arg', tf.int32)
  add_one = _create_lambda_to_add_one(param.type_signature)
  comp = computation_building_blocks.Lambda(
      param.name, param.type_signature,
      computation_building_blocks.Call(add_one, param))
  self.assertEqual(count(comp, computation_building_blocks.Block), 0)
  self.assertEqual(_to_comp(comp)(1), 2)

  transformed = transformations.replace_called_lambda_with_block(comp)

  for building_block_cls in (computation_building_blocks.Call,
                             computation_building_blocks.Lambda):
    self.assertEqual(
        count(transformed, building_block_cls),
        count(comp, building_block_cls) - 1)
  self.assertEqual(count(transformed, computation_building_blocks.Block), 1)
  self.assertEqual(_to_comp(transformed)(1), 2)
def _create_call_to_federated_map(fn, arg):
  r"""Builds `FEDERATED_MAP(<fn, arg>)`.

            Call
           /    \
  Intrinsic      Tuple
                /     \
     Computation       Computation

  The result type is `fn`'s result type placed at CLIENTS, whatever the
  placement of `arg`.

  Args:
    fn: A functional `computation_building_blocks.ComputationBuildingBlock`
      used as the map function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` used as the
      map argument.

  Returns:
    A `computation_building_blocks.Call` of the federated map intrinsic.

  Raises:
    TypeError: If `fn` or `arg` is not a building block.
  """
  for comp in (fn, arg):
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
  mapped_type = computation_types.FederatedType(fn.type_signature.result,
                                                placements.CLIENTS)
  map_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri,
      computation_types.FunctionType([fn.type_signature, arg.type_signature],
                                     mapped_type))
  return computation_building_blocks.Call(
      map_intrinsic, computation_building_blocks.Tuple((fn, arg)))
def test_replace_intrinsic_replaces_multiple_intrinsics(self):
  """All ten chained GENERIC_PLUS intrinsics are replaced by the callable."""
  param = computation_building_blocks.Reference('arg', tf.int32)
  chained = param
  chained_type = param.type_signature
  # Build add_one(add_one(...add_one(arg)...)) ten levels deep.
  for _ in range(10):
    chained = computation_building_blocks.Call(
        _create_lambda_to_add_one(chained_type), chained)
    chained_type = chained.function.type_signature.result
  comp = computation_building_blocks.Lambda(param.name, param.type_signature,
                                            chained)
  uri = intrinsic_defs.GENERIC_PLUS.uri
  replacement = lambda x: 100

  self.assertEqual(_get_number_of_intrinsics(comp, uri), 10)
  self.assertEqual(_to_comp(comp)(1), 11)
  transformed = transformations.replace_intrinsic_with_callable(
      comp, uri, replacement, context_stack_impl.context_stack)
  self.assertEqual(_get_number_of_intrinsics(transformed, uri), 0)
  self.assertEqual(_to_comp(transformed)(1), 100)
def create_chained_calls(functions, arg):
  r"""Creates a chain of `n` calls.

       Call
      /    \
  Comp      ...
               \
                Call
               /    \
           Comp      Comp

  The first functional computation in `functions` must have a parameter type
  that is assignable from the type of `arg`. Each later functional computation
  must have a parameter type that is assignable from the previous one's result
  type. `functions[0]` is applied first, so the last element of `functions`
  becomes the outermost call.

  Args:
    functions: A non-empty Python list of functional computations.
    arg: A `computation_building_blocks.ComputationBuildingBlock`.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    ValueError: If `functions` is empty.
    TypeError: If a function's parameter type is not assignable from the type
      of the value it would be applied to.
  """
  if not functions:
    # Otherwise the loop never runs and there is no call to return.
    raise ValueError('Expected at least one function to chain, got none.')
  call = arg
  for fn in functions:
    # Read the parameter from the function type so any functional building
    # block works, not only blocks exposing `parameter_type` (e.g. `Lambda`).
    parameter_type = fn.type_signature.parameter
    if not type_utils.is_assignable_from(parameter_type, call.type_signature):
      raise TypeError(
          'The parameter of the function is of type {}, and the argument is of '
          'an incompatible type {}.'.format(
              str(parameter_type), str(call.type_signature)))
    call = computation_building_blocks.Call(fn, call)
  return call
def transform(self, comp):
  """Rewrites `comp` when `self.should_transform(comp)` approves it.

  `comp.source` is used as a call. Its function is narrowed with
  `select_graph_output` to the output addressed by `comp.index` / `comp.name`,
  and the narrowed function is called on the original argument.

  Returns:
    A `(computation, modified)` pair: `(comp, False)` when no rewrite applies,
    otherwise the new `computation_building_blocks.Call` and `True`.
  """
  if self.should_transform(comp):
    narrowed_fn = select_graph_output(
        comp.source.function, index=comp.index, name=comp.name)
    rewritten = computation_building_blocks.Call(narrowed_fn,
                                                 comp.source.argument)
    return rewritten, True
  return comp, False
def test_returns_string_for_comp_with_right_overhang(self):
  """Checks string representations of `(a -> <a,data,...>[0])(data)`.

  The called lambda's body selects element 0 from a five-element tuple. The
  test pins the compact, formatted and structural representations.
  """
  ref = computation_building_blocks.Reference('a', tf.int32)
  data = computation_building_blocks.Data('data', tf.int32)
  tup = computation_building_blocks.Tuple([ref, data, data, data, data])
  sel = computation_building_blocks.Selection(tup, index=0)
  fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, sel)
  comp = computation_building_blocks.Call(fn, data)
  compact_string = comp.compact_representation()
  self.assertEqual(compact_string, '(a -> <a,data,data,data,data>[0])(data)')
  formatted_string = comp.formatted_representation()
  # pyformat: disable
  self.assertEqual(
      formatted_string,
      '(a -> <\n'
      ' a,\n'
      ' data,\n'
      ' data,\n'
      ' data,\n'
      ' data\n'
      '>[0])(data)')
  # pyformat: enable
  structural_string = comp.structural_representation()
  # pyformat: disable
  self.assertEqual(
      structural_string,
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Sel(0)\n'
      '|\n'
      'Tuple\n'
      '|\n'
      '[Ref(a), data, data, data, data]')
  # pyformat: enable
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  A plain sequence is reduced by calling the SEQUENCE_REDUCE intrinsic
  directly. A federated sequence is reduced member-wise: the intrinsic is
  wrapped in a lambda over one sequence, which is then applied (SERVER) or
  mapped (CLIENTS).

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  zero = value_impl.to_value(zero, None, self._context_stack)
  op = value_impl.to_value(op, None, self._context_stack)
  value_type = value.type_signature
  is_plain_sequence = isinstance(value_type, computation_types.SequenceType)
  if is_plain_sequence:
    element_type = value_type.element
  else:
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    py_typecheck.check_type(value_type.member, computation_types.SequenceType)
    element_type = value_type.member.element
  expected_op_type = type_constructors.reduction_op(zero.type_signature,
                                                    element_type)
  if not type_utils.is_assignable_from(expected_op_type, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        str(expected_op_type), str(op.type_signature)))

  sequence_type = computation_types.SequenceType(element_type)
  reduce_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_REDUCE.uri,
      computation_types.FunctionType(
          [sequence_type, zero.type_signature, op.type_signature],
          zero.type_signature))
  if is_plain_sequence:
    reduce_value = value_impl.ValueImpl(reduce_intrinsic, self._context_stack)
    return reduce_value(value, zero, op)

  arg_ref = computation_building_blocks.Reference('arg', sequence_type)
  reduce_call = computation_building_blocks.Call(
      reduce_intrinsic,
      computation_building_blocks.Tuple([
          arg_ref,
          value_impl.ValueImpl.get_comp(zero),
          value_impl.ValueImpl.get_comp(op),
      ]))
  mapping_fn = value_impl.ValueImpl(
      computation_building_blocks.Lambda('arg', sequence_type, reduce_call),
      self._context_stack)
  placement = value_type.placement
  if placement is placements.SERVER:
    return self.federated_apply(mapping_fn, value)
  if placement is placements.CLIENTS:
    return self.federated_map(mapping_fn, value)
  raise TypeError('Unsupported placement {}.'.format(str(placement)))
def _create_lambda_to_dummy_intrinsic(type_spec, uri='dummy'):
  r"""Builds `(arg -> uri(arg))` around an intrinsic of type `T -> T`.

  Lambda
        \
         Call
        /    \
  Intrinsic   Ref(arg)

  Args:
    type_spec: The `tf.dtypes.DType` of both the parameter and the result.
    uri: The URI given to the intrinsic.

  Returns:
    A `computation_building_blocks.Lambda`.

  Raises:
    TypeError: If `type_spec` is not a `tf.dtypes.DType`.
  """
  py_typecheck.check_type(type_spec, tf.dtypes.DType)
  param = computation_building_blocks.Reference('arg', type_spec)
  dummy = computation_building_blocks.Intrinsic(
      uri, computation_types.FunctionType(type_spec, type_spec))
  return computation_building_blocks.Lambda(
      param.name, param.type_signature,
      computation_building_blocks.Call(dummy, param))
def create_sequence_map(fn, arg):
  r"""Builds `SEQUENCE_MAP(<fn, arg>)`.

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Comp, Comp]

  Args:
    fn: A `computation_building_blocks.ComputationBuildingBlock` of functional
      type, used as the function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` used as the
      argument.

  Returns:
    A `computation_building_blocks.Call` whose type is a sequence of `fn`'s
    result type.

  Raises:
    TypeError: If `fn` or `arg` is not a building block (further type
      mismatches may surface from the type constructors; not checked here).
  """
  py_typecheck.check_type(
      fn, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(
      arg, computation_building_blocks.ComputationBuildingBlock)
  fn_type = fn.type_signature
  mapped_type = computation_types.SequenceType(fn_type.result)
  sequence_map = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_MAP.uri,
      computation_types.FunctionType((fn_type, arg.type_signature),
                                     mapped_type))
  return computation_building_blocks.Call(
      sequence_map, computation_building_blocks.Tuple((fn, arg)))
def construct_map_or_apply(fn, arg):
  """Applies `fn` to federated `arg` with the intrinsic matching its placement.

  Args:
    fn: An instance of `computation_building_blocks.ComputationBuildingBlock`
      of functional type, to be applied to the member of `arg`.
    arg: A `computation_building_blocks.ComputationBuildingBlock` of federated
      type. The `member` of its `type_signature` must be assignable to the
      `parameter` of `fn`'s `type_signature`.

  Returns:
    A `computation_building_blocks.Call` applying `fn` to `arg`:
    - if `arg` is placed at SERVER, a call of the `FEDERATED_APPLY` intrinsic
      on `<fn, arg>`, keeping `arg`'s placement and `all_equal` bit;
    - if `arg` is placed at CLIENTS, the result of
      `create_federated_map(fn, arg)`.

  Raises:
    TypeError: If `fn` or `arg` is not a building block of the expected kind of
      type, or if `arg`'s placement is neither SERVER nor CLIENTS.
      `type_utils.check_assignable_from` presumably raises it too when the
      member type does not fit `fn`'s parameter.
  """
  py_typecheck.check_type(fn,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  type_utils.check_assignable_from(fn.type_signature.parameter,
                                   arg.type_signature.member)
  placement = arg.type_signature.placement
  if placement == placement_literals.SERVER:
    result_type = computation_types.FederatedType(fn.type_signature.result,
                                                  placement,
                                                  arg.type_signature.all_equal)
    intrinsic_type = computation_types.FunctionType(
        [fn.type_signature, arg.type_signature], result_type)
    intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_APPLY.uri, intrinsic_type)
    tup = computation_building_blocks.Tuple((fn, arg))
    return computation_building_blocks.Call(intrinsic, tup)
  elif placement == placement_literals.CLIENTS:
    return create_federated_map(fn, arg)
  # Fail loudly rather than implicitly returning `None`.
  raise TypeError('Unsupported placement {}.'.format(placement))
def create_federated_value(value, placement):
  r"""Builds a call placing `value` at `placement` as an all-equal value.

            Call
           /    \
  Intrinsic      Comp

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the value.
    placement: A `placement_literals.PlacementLiteral`, either CLIENTS or
      SERVER.

  Returns:
    A `computation_building_blocks.Call` of the FEDERATED_VALUE_AT_CLIENTS or
    FEDERATED_VALUE_AT_SERVER intrinsic on `value`.

  Raises:
    TypeError: If `value` is not a building block or `placement` is not
      supported.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  if placement is placement_literals.CLIENTS:
    intrinsic_def = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS
  elif placement is placement_literals.SERVER:
    intrinsic_def = intrinsic_defs.FEDERATED_VALUE_AT_SERVER
  else:
    raise TypeError('Unsupported placement {}.'.format(placement))
  value_type = value.type_signature
  federated_type = computation_types.FederatedType(value_type, placement,
                                                   True)
  return computation_building_blocks.Call(
      computation_building_blocks.Intrinsic(
          intrinsic_def.uri,
          computation_types.FunctionType(value_type, federated_type)), value)
def create_federated_sum(value):
  r"""Creates a called federated sum.

            Call
           /    \
  Intrinsic      Comp

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` of
      `computation_types.FederatedType` to use as the value.

  Returns:
    A `computation_building_blocks.Call` of the `FEDERATED_SUM` intrinsic on
    `value`. Its type is `value`'s member type placed at SERVER, with the
    all-equal flag set to `True`.

  Raises:
    TypeError: If `value` is not a building block, or its type signature is not
      a `computation_types.FederatedType`.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  # Validate up front so a non-federated `value` fails with the documented
  # `TypeError` instead of an `AttributeError` on `.member` below.
  py_typecheck.check_type(value.type_signature,
                          computation_types.FederatedType)
  result_type = computation_types.FederatedType(value.type_signature.member,
                                                placement_literals.SERVER, True)
  intrinsic_type = computation_types.FunctionType(value.type_signature,
                                                  result_type)
  intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_SUM.uri, intrinsic_type)
  return computation_building_blocks.Call(intrinsic, value)
def create_federated_map(fn, arg):
  r"""Builds `FEDERATED_MAP(<fn, arg>)` producing a CLIENTS-placed result.

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Comp, Comp]

  Args:
    fn: A `computation_building_blocks.ComputationBuildingBlock` of functional
      type, used as the mapping function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` used as the
      argument.

  Returns:
    A `computation_building_blocks.Call` whose type is `fn`'s result type
    placed at CLIENTS with the all-equal flag set to `False`.

  Raises:
    TypeError: If `fn` or `arg` is not a building block.
  """
  py_typecheck.check_type(
      fn, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(
      arg, computation_building_blocks.ComputationBuildingBlock)
  mapped_type = computation_types.FederatedType(fn.type_signature.result,
                                                placement_literals.CLIENTS,
                                                False)
  map_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri,
      computation_types.FunctionType((fn.type_signature, arg.type_signature),
                                     mapped_type))
  return computation_building_blocks.Call(
      map_intrinsic, computation_building_blocks.Tuple((fn, arg)))
def _create_zip_two_values(value):
  r"""Creates a called federated zip with two values.

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Comp1, Comp2]

  Notice that this function will drop any names associated to the two-tuple it
  is processing. This is necessary due to the type signature of the underlying
  federated zip intrinsic, `<T@P,U@P>-><T,U>@P`. Keeping names here would
  violate this type signature. The names are cached at a higher level than this
  function, and appended to the resulting tuple in a single call to
  `federated_map` or `federated_apply` before the resulting structure is sent
  back to the caller.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing
      exactly two elements.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If any of the types do not match, or the placement of the first
      element is neither CLIENTS nor SERVER.
    ValueError: If `value` does not contain exactly two elements.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  length = len(named_type_signatures)
  if length != 2:
    # Report the element count, not the list of (name, type) pairs.
    raise ValueError(
        'Expected a value with exactly two elements, received {} elements.'
        .format(length))
  # Read the placement from the first element's type rather than by indexing
  # `value`, which would require the block itself to support `__getitem__`
  # (e.g. a `Tuple`). For such blocks the two types presumably coincide.
  placement = named_type_signatures[0][1].placement
  if placement is placement_literals.CLIENTS:
    uri = intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri
    all_equal = False
  elif placement is placement_literals.SERVER:
    uri = intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri
    all_equal = True
  else:
    raise TypeError('Unsupported placement {}.'.format(placement))
  # Names are dropped (`None`) on both the parameter and the result, matching
  # the intrinsic signature `<T@P,U@P> -> <T,U>@P`.
  parameter_type = computation_types.NamedTupleType([
      (None,
       computation_types.FederatedType(type_signature.member, placement,
                                       all_equal))
      for _, type_signature in named_type_signatures
  ])
  result_type = computation_types.FederatedType(
      [(None, e.member) for _, e in named_type_signatures], placement,
      all_equal)
  intrinsic_type = computation_types.FunctionType(parameter_type, result_type)
  intrinsic = computation_building_blocks.Intrinsic(uri, intrinsic_type)
  return computation_building_blocks.Call(intrinsic, value)
def _transform(comp):
  """Merges two chained calls into a single call, when applicable.

  When `_should_transform(comp)` is falsy, `comp` is handed back untouched.
  Otherwise `comp` is presumably a call to a federated map/apply intrinsic whose
  second argument is itself such a call (the shape `_should_transform` is
  expected to check). The two mapping functions are fused into one function
  that applies the inner function first and the outer function second. A new
  call is then built to an intrinsic with the same URI as `comp.function`,
  applied to `<fused_fn, inner value>`.

  Args:
    comp: The computation building block to inspect.

  Returns:
    A pair `(computation, modified)`. `modified` is `True` only when a new
    computation was constructed.
  """
  if not _should_transform(comp):
    return comp, False

  def _create_block_to_chained_calls(comps):
    r"""Builds a block binding `fn` to a tuple of `comps` and chaining them.

                             Block
                            /     \
                 [fn=Tuple]         Lambda(arg)
                     |                 \
      [Comp(y), Comp(x)]                 Call
                                        /    \
                                  Sel(1)      Call
                                 /           /    \
                           Ref(fn)      Sel(0)     Ref(arg)
                                       /
                                 Ref(fn)

    (let fn=<y, x> in (arg -> fn[1](fn[0](arg))))

    Args:
      comps: a Python list of computations, applied in order (the first one
        receives the lambda's argument).

    Returns:
      A `computation_building_blocks.Block`.
    """
    fn_tuple = computation_building_blocks.Tuple(comps)
    fn_ref = computation_building_blocks.Reference(
        'fn', fn_tuple.type_signature)
    # The fused lambda takes whatever the first function in the chain takes.
    param_ref = computation_building_blocks.Reference(
        'arg', comps[0].type_signature.parameter)
    chained = param_ref
    for position in range(len(comps)):
      selected_fn = computation_building_blocks.Selection(
          fn_ref, index=position)
      chained = computation_building_blocks.Call(selected_fn, chained)
    body = computation_building_blocks.Lambda(
        param_ref.name, param_ref.type_signature, chained)
    return computation_building_blocks.Block([('fn', fn_tuple)], body)

  inner_call = comp.argument[1]
  # Inner function runs first, outer function second.
  fused_fn = _create_block_to_chained_calls((
      inner_call.argument[0],
      comp.argument[0],
  ))
  new_arg = computation_building_blocks.Tuple([
      fused_fn,
      inner_call.argument[1],
  ])
  new_intrinsic = computation_building_blocks.Intrinsic(
      comp.function.uri,
      computation_types.FunctionType(new_arg.type_signature,
                                     comp.function.type_signature.result))
  return computation_building_blocks.Call(new_intrinsic, new_arg), True
def construct_binary_operator_with_upcast(type_signature, operator):
  """Constructs lambda upcasting its argument and applying `operator`.

  The concept of upcasting is explained further in the docstring for
  `apply_binary_operator_with_upcast`.

  Notice that since we are constructing a function here, e.g. for the body
  of an intrinsic, the function we are constructing must be reducible to
  TensorFlow. Therefore `type_signature` can only have named tuple or tensor
  type elements; that is, we cannot handle federated types here in a generic
  way.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`,
      with two elements, both of the same type or the second able to be upcast
      to the first, as explained in `apply_binary_operator_with_upcast`, and
      both containing only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `computation_building_blocks.Lambda` encapsulating a function which
    upcasts the second element of its argument and applies the binary
    operator.

  Raises:
    TypeError: If upcasting is required and the first element's type tree
      contains a type that is neither a named tuple nor a tensor.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  ref_to_arg = computation_building_blocks.Reference('binary_operator_arg',
                                                     type_signature)

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if isinstance(type_spec, computation_types.NamedTupleType):
      elems = anonymous_tuple.to_elements(type_spec)
      packed_elems = [(elem_name, _pack_into_type(to_pack, elem_type))
                      for elem_name, elem_type in elems]
      return computation_building_blocks.Tuple(packed_elems)
    elif isinstance(type_spec, computation_types.TensorType):
      # Broadcast the value to this leaf's tensor shape.
      expand_fn = (
          computation_constructing_utils
          .construct_tensorflow_to_broadcast_scalar(
              to_pack.type_signature.dtype, type_spec.shape))
      return computation_building_blocks.Call(expand_fn, to_pack)
    # Previously this fell through and returned `None`, which only failed
    # later with an obscure error while building a `Tuple`.
    raise TypeError(
        'Cannot upcast into type {}; only named tuple and tensor types are '
        'supported.'.format(type_spec))

  y_ref = computation_building_blocks.Selection(ref_to_arg, index=1)
  first_arg = computation_building_blocks.Selection(ref_to_arg, index=0)

  # Only pack (upcast) the second operand when its type differs from the
  # first; otherwise pass it through unchanged.
  if type_utils.are_equivalent_types(first_arg.type_signature,
                                     y_ref.type_signature):
    second_arg = y_ref
  else:
    second_arg = _pack_into_type(y_ref, first_arg.type_signature)

  fn = computation_constructing_utils.construct_tensorflow_binary_operator(
      first_arg.type_signature, operator)
  packed = computation_building_blocks.Tuple([first_arg, second_arg])
  operated = computation_building_blocks.Call(fn, packed)
  lambda_encapsulating_op = computation_building_blocks.Lambda(
      ref_to_arg.name, ref_to_arg.type_signature, operated)
  return lambda_encapsulating_op