Ejemplo n.º 1
0
 def test_returns_correct_structure_with_no_unbound_references(self):
     """A block with no free references becomes a no-argument TF call."""
     one = building_block_factory.create_tensorflow_constant(tf.int32, 1)
     inner_identity = building_block_factory.create_compiled_identity(tf.int32)
     inner_call = building_blocks.Call(inner_identity, one)
     inner_call_ref = building_blocks.Reference('call',
                                                inner_call.type_signature)
     outer_identity = building_block_factory.create_compiled_identity(tf.int32)
     outer_call = building_blocks.Call(outer_identity, inner_call_ref)
     # Typed with the inner call's signature; both calls wrap int32
     # identities, so this should equal the outer call's type.
     outer_call_ref = building_blocks.Reference('second_call',
                                                inner_call.type_signature)
     block = building_blocks.Block(
         [('call', inner_call), ('second_call', outer_call)],
         building_blocks.Tuple([outer_call_ref, outer_call_ref]))

     tf_block, _ = transformations.create_tensorflow_representing_block(block)

     self.assertEqual(tf_block.type_signature, block.type_signature)
     self.assertIsInstance(tf_block, building_blocks.Call)
     self.assertIsInstance(tf_block.function,
                           building_blocks.CompiledComputation)
     # Nothing is left unbound, so the generated call takes no argument.
     self.assertIsNone(tf_block.argument)
Ejemplo n.º 2
0
 def test_returns_single_called_graph_after_resolving_multiple_variables(
         self):
     """Two chained TF identities over a free `var` fold into one TF call."""
     var_ref = building_blocks.Reference('var', tf.int32)
     first_id = building_block_factory.create_compiled_identity(tf.int32)
     first_call = building_blocks.Call(first_id, var_ref)
     first_call_ref = building_blocks.Reference('call',
                                                first_call.type_signature)
     second_id = building_block_factory.create_compiled_identity(tf.int32)
     second_call = building_blocks.Call(second_id, first_call_ref)
     result_ref = building_blocks.Reference('second_call',
                                            first_call.type_signature)
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)], result_ref)

     tf_block, _ = transformations.create_tensorflow_representing_block(block)

     self.assertEqual(tf_block.type_signature, block.type_signature)
     self.assertIsInstance(tf_block, building_blocks.Call)
     self.assertIsInstance(tf_block.function,
                           building_blocks.CompiledComputation)
     # The free variable `var` is expected to remain as the call argument.
     self.assertIsInstance(tf_block.argument, building_blocks.Reference)
     self.assertEqual(tf_block.argument.name, 'var')
Ejemplo n.º 3
0
 def __setattr__(self, name, value):
     """Sets the struct element `name` of this value to `value`.

     Nothing is mutated in place. Instead, `self._comp` is rebound to a new
     building block that produces the updated struct. For a federated struct
     that block comes from `create_federated_setattr_call`. Otherwise it is a
     call of the lambda built by `create_named_tuple_setattr_lambda`.

     Args:
       name: The element name to set. It is validated with
         `py_typecheck.check_type(name, str)`.
       value: The new element value. It is converted with `to_value` using
         this value's context stack.
     """
     py_typecheck.check_type(name, str)
     _check_is_optionally_federated_named_tuple(
         self, "__setattr__('{}', {})".format(name, value))
     value_comp = ValueImpl.get_comp(
         to_value(value, None, self._context_stack))
     if _is_federated_named_tuple(self):
         new_comp = building_block_factory.create_federated_setattr_call(
             self._comp, name, value_comp)
     else:
         setattr_lambda = building_block_factory.create_named_tuple_setattr_lambda(
             self._comp.type_signature, name, value_comp)
         new_comp = building_blocks.Call(setattr_lambda, self._comp)
     # Assign through the base class. A plain `self._comp = ...` would
     # re-enter this very `__setattr__` override.
     super().__setattr__('_comp', new_comp)
Ejemplo n.º 4
0
  def test_returns_string_for_call_with_arg(self):
    """Checks the compact, formatted and structural forms of `a(data)`."""
    ref_a = building_blocks.Reference(
        'a', computation_types.FunctionType(tf.int32, tf.int32))
    data_arg = building_blocks.Data('data', tf.int32)
    call = building_blocks.Call(ref_a, data_arg)

    self.assertEqual(call.compact_representation(), 'a(data)')
    self.assertEqual(call.formatted_representation(), 'a(data)')
    # pyformat: disable
    self.assertEqual(
        call.structural_representation(),
        '       Call\n'
        '      /    \\\n'
        'Ref(a)      data'
    )
    # pyformat: enable
Ejemplo n.º 5
0
 def test_strip_placement_with_called_lambda(self):
     """Stripping placement turns an int32@SERVER identity call into int32."""
     unplaced_int = computation_types.TensorType(tf.int32)
     placed_int = computation_types.at_server(unplaced_int)
     outer_ref = building_blocks.Reference('outer', placed_int)
     identity = building_blocks.Lambda(
         'inner', placed_int, building_blocks.Reference('inner', placed_int))
     before = building_blocks.Call(identity, outer_ref)

     after, modified = tree_transformations.strip_placement(before)

     self.assertTrue(modified)
     self.assert_has_no_intrinsics_nor_federated_types(after)
     type_test_utils.assert_types_identical(before.type_signature, placed_int)
     type_test_utils.assert_types_identical(after.type_signature, unplaced_int)
Ejemplo n.º 6
0
 def test_basic_functionality_of_lambda_class(self):
     """Checks type, repr, compact form and proto of `arg -> arg.f(arg.f(arg.x))`."""
     param_name = 'arg'
     param_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                   ('x', tf.int32)]
     param_ref = building_blocks.Reference(param_name, param_type)
     select_f = building_blocks.Selection(param_ref, name='f')
     select_x = building_blocks.Selection(param_ref, name='x')
     inner_call = building_blocks.Call(select_f, select_x)
     lam = building_blocks.Lambda(param_name, param_type,
                                  building_blocks.Call(select_f, inner_call))

     # Type and body renderings.
     self.assertEqual(str(lam.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(lam.parameter_name, param_name)
     self.assertEqual(str(lam.parameter_type), '<f=(int32 -> int32),x=int32>')
     self.assertEqual(lam.result.compact_representation(),
                      'arg.f(arg.f(arg.x))')

     # The repr spells out the parameter type once per Reference.
     type_repr = (
         'NamedTupleType(['
         '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
         '(\'x\', TensorType(tf.int32))])')
     expected_repr = (
         'Lambda(\'arg\', {0}, '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Selection(Reference(\'arg\', {0}), name=\'x\'))))').format(type_repr)
     self.assertEqual(repr(lam), expected_repr)
     self.assertEqual(lam.compact_representation(),
                      '(arg -> arg.f(arg.f(arg.x)))')

     # Serialized form.
     lam_proto = lam.proto
     self.assertEqual(type_serialization.deserialize_type(lam_proto.type),
                      lam.type_signature)
     self.assertEqual(lam_proto.WhichOneof('computation'), 'lambda')
     lambda_proto = getattr(lam_proto, 'lambda')
     self.assertEqual(lambda_proto.parameter_name, param_name)
     self.assertEqual(str(lambda_proto.result), str(lam.result.proto))
     self._serialize_deserialize_roundtrip_test(lam)
Ejemplo n.º 7
0
    def test_block_lambda_block_lambda(self):
        """Uniquifies names through nested blocks/lambdas that all bind `a`."""
        ref_a = building_blocks.Reference('a', tf.int32)
        identity_a = building_blocks.Lambda('a', tf.int32, ref_a)
        inner_block = building_blocks.Block(
            [('a', ref_a), ('a', ref_a)],
            building_blocks.Call(identity_a, ref_a))
        outer_lambda = building_blocks.Lambda('a', tf.int32, inner_block)
        outer_call = building_blocks.Call(outer_lambda, ref_a)
        outer_block = building_blocks.Block(
            [('a', building_blocks.Data('data', tf.int32)), ('a', ref_a)],
            outer_call)

        uniquified, modified = tree_transformations.uniquify_reference_names(
            outer_block)

        # The input tree still renders with its duplicate `a` bindings.
        self.assertEqual(
            outer_block.compact_representation(),
            '(let a=data,a=a in (a -> (let a=a,a=a in (a -> a)(a)))(a))')
        self.assertEqual(
            uniquified.compact_representation(),
            '(let a=data,_var1=a in (_var2 -> (let _var3=_var2,_var4=_var3 in (_var5 -> _var5)(_var4)))(_var1))'
        )
        tree_analysis.check_has_unique_names(uniquified)
        self.assertTrue(modified)
Ejemplo n.º 8
0
    def test_does_not_remove_called_lambda(self):
        """A directly called identity lambda is reported as not modified."""
        identity = building_block_test_utils.create_identity_function(
            'a', tf.int32)
        comp = building_blocks.Call(identity,
                                    building_blocks.Data('data', tf.int32))

        transformed, modified = (
            tree_transformations.remove_mapped_or_applied_identity(comp))

        self.assertEqual(transformed.compact_representation(),
                         comp.compact_representation())
        self.assertEqual(transformed.compact_representation(),
                         '(a -> a)(data)')
        self.assertEqual(transformed.type_signature, comp.type_signature)
        self.assertFalse(modified)
Ejemplo n.º 9
0
 def test_executes_correctly_after_resolving_multiple_variables(self):
     """The folded TF for two chained identities returns its input."""
     int_type = computation_types.TensorType(tf.int32)
     var_ref = building_blocks.Reference('var', tf.int32)
     first_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(int_type), var_ref)
     first_call_ref = building_blocks.Reference('call',
                                                first_call.type_signature)
     second_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(int_type),
         first_call_ref)
     result_ref = building_blocks.Reference('second_call',
                                            first_call.type_signature)
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)], result_ref)

     tf_block, _ = transformations.create_tensorflow_representing_block(block)
     proto = tf_block.function.proto

     for arg in (1, 0):
         self.assertEqual(test_utils.run_tensorflow(proto, arg), arg)
    def test_broadcast_dependent_on_aggregate_fails_well(self):
        """Calling `next` twice in one round cannot be put in canonical form."""
        cf = test_utils.get_temperature_sensor_example()
        it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
        next_comp = test_utils.computation_to_building_block(it.next)
        param_name = next_comp.parameter_name
        param_type = next_comp.parameter_type
        outer_param = building_blocks.Reference(param_name, param_type)

        # The second call receives element 0 of the first call's result
        # together with element 1 of the original parameter.
        first_round = building_blocks.Call(next_comp, outer_param)
        second_round_arg = building_blocks.Struct([
            building_blocks.Selection(first_round, index=0),
            building_blocks.Selection(outer_param, index=1),
        ])
        two_rounds = building_blocks.Lambda(
            param_name, param_type,
            building_blocks.Call(next_comp, second_round_arg))
        not_reducible_it = iterative_process.IterativeProcess(
            it.initialize,
            computation_wrapper_instances.building_block_to_computation(
                two_rounds))

        with self.assertRaisesRegex(ValueError,
                                    'broadcast dependent on aggregate'):
            canonical_form_utils.get_canonical_form_for_iterative_process(
                not_reducible_it)
Ejemplo n.º 11
0
    def sequence_reduce(self, value, zero, op):
        """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

        Args:
          value: A value of `SequenceType`, or a federated value whose member
            is a `SequenceType`.
          zero: The initial value of the reduction.
          op: The reduction operator. Its type must be assignable to
            `type_factory.reduction_op(zero_type, element_type)`.

        Returns:
          A `value_impl.ValueImpl`. For a plain sequence it wraps the building
          block from `create_sequence_reduce`. For a federated sequence it is
          the result of `self.federated_map` applied with a lambda that invokes
          the `SEQUENCE_REDUCE` intrinsic.

        Raises:
          TypeError: If `op` has the wrong type, or if a federated `value` is
            placed somewhere other than SERVER or CLIENTS.
        """
        ctx = self._context_stack
        value = value_impl.to_value(value, None, ctx)
        zero = value_impl.to_value(zero, None, ctx)
        op = value_impl.to_value(op, None, ctx)

        val_type = value.type_signature
        if isinstance(val_type, computation_types.SequenceType):
            element_type = val_type.element
        else:
            py_typecheck.check_type(val_type, computation_types.FederatedType)
            py_typecheck.check_type(val_type.member,
                                    computation_types.SequenceType)
            element_type = val_type.member.element

        expected_op_type = type_factory.reduction_op(zero.type_signature,
                                                     element_type)
        if not type_analysis.is_assignable_from(expected_op_type,
                                                op.type_signature):
            raise TypeError('Expected an operator of type {}, got {}.'.format(
                expected_op_type, op.type_signature))

        value_comp = value_impl.ValueImpl.get_comp(value)
        zero_comp = value_impl.ValueImpl.get_comp(zero)
        op_comp = value_impl.ValueImpl.get_comp(op)

        if isinstance(value_comp.type_signature,
                      computation_types.SequenceType):
            reduced = building_block_factory.create_sequence_reduce(
                value_comp, zero_comp, op_comp)
            return value_impl.ValueImpl(reduced, ctx)

        # Federated case: build `(arg -> SEQUENCE_REDUCE(<arg, zero, op>))`
        # over the member sequence type and map it over `value`.
        sequence_type = computation_types.SequenceType(element_type)
        reduce_type = computation_types.FunctionType(
            (sequence_type, zero_comp.type_signature, op_comp.type_signature),
            op_comp.type_signature.result)
        reduce_intrinsic = building_blocks.Intrinsic(
            intrinsic_defs.SEQUENCE_REDUCE.uri, reduce_type)
        arg_ref = building_blocks.Reference('arg', sequence_type)
        reduce_call = building_blocks.Call(
            reduce_intrinsic,
            building_blocks.Tuple((arg_ref, zero_comp, op_comp)))
        reducer = building_blocks.Lambda(arg_ref.name, arg_ref.type_signature,
                                         reduce_call)
        reducer_value = value_impl.ValueImpl(reducer, ctx)

        supported = [placement_literals.SERVER, placement_literals.CLIENTS]
        if value_comp.type_signature.placement not in supported:
            raise TypeError('Unsupported placement {}.'.format(
                value_comp.type_signature.placement))
        return self.federated_map(reducer_value, value_comp)
Ejemplo n.º 12
0
 def __add__(self, other):
     """Returns `self + other` as a call of the `GENERIC_PLUS` intrinsic.

     The intrinsic is typed with parameter `[T, T]` and result `T`, where `T`
     is `self.type_signature`.

     Raises:
       TypeError: If `type_utils.are_equivalent_types` rejects the two
         operand types.
     """
     other = to_value(other, None, self._context_stack)
     own_type = self.type_signature
     if not type_utils.are_equivalent_types(own_type, other.type_signature):
         raise TypeError('Cannot add {} and {}.'.format(
             own_type, other.type_signature))
     plus_type = computation_types.FunctionType([own_type, own_type], own_type)
     plus = building_blocks.Intrinsic(intrinsic_defs.GENERIC_PLUS.uri,
                                      plus_type)
     operands = ValueImpl.get_comp(
         to_value([self, other], None, self._context_stack))
     return ValueImpl(building_blocks.Call(plus, operands), self._context_stack)
Ejemplo n.º 13
0
 def __add__(self, other):
     """Returns `self + other` as a call of the `GENERIC_PLUS` intrinsic.

     The intrinsic is typed with parameter `[T, T]` and result `T`, where `T`
     is `self.type_signature`.

     Raises:
       TypeError: If the operand types are not equivalent.
     """
     other = to_value(other, None, self._context_stack)
     own_type = self.type_signature
     if not own_type.is_equivalent_to(other.type_signature):
         raise TypeError('Cannot add {} and {}.'.format(
             own_type, other.type_signature))
     # TODO(b/159281959): Follow up and bind a reference here.
     plus_type = computation_types.FunctionType([own_type, own_type], own_type)
     plus = building_blocks.Intrinsic(intrinsic_defs.GENERIC_PLUS.uri,
                                      plus_type)
     operands = ValueImpl.get_comp(
         to_value([self, other], None, self._context_stack))
     return ValueImpl(building_blocks.Call(plus, operands), self._context_stack)
Ejemplo n.º 14
0
def create_whimsy_called_intrinsic(parameter_name, parameter_type=tf.int32):
    r"""Returns a call of a placeholder intrinsic on a reference.

    The returned building block has the shape:

              Call
             /    \
    intrinsic      Ref(x)

    The intrinsic is named 'intrinsic' and has type
    `(parameter_type -> parameter_type)`. `Ref(x)` is a reference named
    `parameter_name` of type `parameter_type`.

    Args:
      parameter_name: The name of the referenced parameter.
      parameter_type: The type of the reference, which is also the
        intrinsic's parameter and result type.

    Returns:
      A `building_blocks.Call`.
    """
    ref = building_blocks.Reference(parameter_name, parameter_type)
    intrinsic = building_blocks.Intrinsic(
        'intrinsic',
        computation_types.FunctionType(parameter_type, parameter_type))
    return building_blocks.Call(intrinsic, ref)
Ejemplo n.º 15
0
 def test_returns_string_for_call_with_no_arg(self):
   """Checks the compact, formatted and structural forms of `a()`."""
   ref_a = building_blocks.Reference(
       'a', computation_types.FunctionType(None, tf.int32))
   call = building_blocks.Call(ref_a)

   self.assertEqual(call.compact_representation(), 'a()')
   self.assertEqual(call.formatted_representation(), 'a()')
   structural = call.structural_representation()
   # pyformat: disable
   self.assertEqual(
       structural,
       '       Call\n'
       '      /\n'
       'Ref(a)'
   )
   # pyformat: enable
Ejemplo n.º 16
0
 def __setattr__(self, name, value):
     """Sets struct element `name` to `value` by rebinding `self._comp`.

     For a federated struct, `_comp` becomes the call produced by
     `create_federated_setattr_call`. Otherwise the setattr lambda is called
     on the current `_comp`, and that call is bound to a reference in the
     current context. `_comp` then becomes that reference.

     Args:
       name: The element name; validated with `py_typecheck.check_type`.
       value: The new element value; converted with `to_value`.
     """
     py_typecheck.check_type(name, str)
     _check_struct_or_federated_struct(self, name)
     value_comp = ValueImpl.get_comp(
         to_value(value, None, self._context_stack))
     if _is_federated_named_tuple(self):
         replacement = building_block_factory.create_federated_setattr_call(
             self._comp, name, value_comp)
     else:
         setattr_fn = building_block_factory.create_named_tuple_setattr_lambda(
             self.type_signature, name, value_comp)
         setattr_call = building_blocks.Call(setattr_fn, self._comp)
         current_context = self._context_stack.current
         replacement = current_context.bind_computation_to_reference(
             setattr_call)
     # Bypass this override; `self._comp = ...` would recurse into it.
     super().__setattr__('_comp', replacement)
Ejemplo n.º 17
0
 def __add__(self, other):
     """Returns `self + other`, bound to a reference in the current context.

     The sum is a call of the `GENERIC_PLUS` intrinsic typed with parameter
     `[T, T]` and result `T`, where `T` is `self.type_signature`. That call is
     passed to `bind_computation_to_reference` on the current context.

     Raises:
       TypeError: If the operand types are not equivalent.
     """
     other = to_value(other, None, self._context_stack)
     lhs_type = self.type_signature
     rhs_type = other.type_signature
     if not lhs_type.is_equivalent_to(rhs_type):
         raise TypeError('Cannot add {} and {}.'.format(lhs_type, rhs_type))
     plus_type = computation_types.FunctionType([lhs_type, lhs_type], lhs_type)
     plus = building_blocks.Intrinsic(intrinsic_defs.GENERIC_PLUS.uri,
                                      plus_type)
     operands = ValueImpl.get_comp(
         to_value([self, other], None, self._context_stack))
     current_context = self._context_stack.current
     summed = current_context.bind_computation_to_reference(
         building_blocks.Call(plus, operands))
     return ValueImpl(summed, self._context_stack)
Ejemplo n.º 18
0
    def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
        """A lambda over a called TF identity parses to one TF identity."""
        identity_tf = building_block_factory.create_compiled_identity(
            computation_types.TensorType(tf.int32))
        x_ref = building_blocks.Reference('x', tf.int32)
        lambda_wrapper = building_blocks.Lambda(
            'x', tf.int32, building_blocks.Call(identity_tf, x_ref))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)
        self.assertEqual(2, test_utils.run_tensorflow(parsed.proto, 2))
Ejemplo n.º 19
0
 def __call__(self, *args, **kwargs):
     """Invokes this functional value and returns the bound result.

     Positional and keyword arguments are converted with `to_value` and packed
     against the function's parameter type. The resulting call is bound via
     `_bind_computation_to_reference`.

     Raises:
       SyntaxError: If this value's type is not a function type.
     """
     fn_type = self.type_signature
     if not fn_type.is_function():
         raise SyntaxError(
             'Function-like invocation is only supported for values of functional '
             'types, but the value being invoked is of type {} that does not '
             'support invocation.'.format(fn_type))
     arg = None
     if args or kwargs:
         packed = function_utils.pack_args(
             fn_type.parameter,
             [to_value(x, None) for x in args],
             {k: to_value(v, None) for k, v in kwargs.items()})
         arg = to_value(packed, None).comp
     call = building_blocks.Call(self._comp, arg)
     return Value(_bind_computation_to_reference(call, 'calling a `tff.Value`'))
Ejemplo n.º 20
0
def construct_tensorflow_calling_lambda_on_concrete_arg(
        parameter: building_blocks.Reference,
        body: building_blocks.ComputationBuildingBlock,
        concrete_arg: building_blocks.ComputationBuildingBlock):
    """Generates TensorFlow for lambda invocation with given arg, body and param.

    That is, generates a TensorFlow block encapsulating the logic represented
    by invoking a function with parameter `parameter` and body `body`, with
    argument `concrete_arg`.

    Via the guarantee made in `compiled_computation_transforms.TupleCalledGraphs`,
    this function makes the claim that the computations which define
    `concrete_arg` will be executed exactly once in the generated TensorFlow.

    Args:
      parameter: Instance of `building_blocks.Reference` defining the parameter
        of the function to be generated and invoked, as described above. After
        calling this transformation, every instance of `parameter` in `body`
        will represent a reference to `concrete_arg`.
      body: `building_blocks.ComputationBuildingBlock` representing the body of
        the function for which we are generating TensorFlow.
      concrete_arg: `building_blocks.ComputationBuildingBlock` representing the
        argument to be passed to the resulting function. `concrete_arg` will
        then be referred to by every occurrence of `parameter` in `body`.
        Therefore `concrete_arg` must have an equivalent type signature to
        that of `parameter`.

    Returns:
      A called `building_blocks.CompiledComputation`, as specified above.

    Raises:
      TypeError: If the arguments are of the wrong types, or the type signature
        of `concrete_arg` does not match that of `parameter`.
    """
    py_typecheck.check_type(parameter, building_blocks.Reference)
    py_typecheck.check_type(body, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(concrete_arg,
                            building_blocks.ComputationBuildingBlock)
    type_analysis.check_equivalent_types(parameter.type_signature,
                                         concrete_arg.type_signature)

    # Run the lambda `(parameter -> body)` through
    # `_generate_simple_tensorflow`, then do the same for the call of that
    # result on `concrete_arg`.
    encapsulating_lambda = _generate_simple_tensorflow(
        building_blocks.Lambda(parameter.name, parameter.type_signature, body))
    return _generate_simple_tensorflow(
        building_blocks.Call(encapsulating_lambda, concrete_arg))
Ejemplo n.º 21
0
    def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
        """A lambda over a called TF identity parses to an equivalent TF block."""
        identity_tf = building_block_factory.create_compiled_identity(tf.int32)
        x_ref = building_blocks.Reference('x', tf.int32)
        lambda_wrapper = building_blocks.Lambda(
            'x', tf.int32, building_blocks.Call(identity_tf, x_ref))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        run_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        run_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        # Both forms must produce the same output for the same input.
        self.assertEqual(run_lambda(2), run_tf(2))
Ejemplo n.º 22
0
def _wrap_constant_as_value(const, context_stack):
    """Wraps the given Python constant as a `tff.Value`.

    A zero-argument function returning `tf.constant(const)` is serialized as a
    TensorFlow computation. The value wraps a no-argument call of that
    compiled computation.

    Args:
      const: Python constant to be converted to TFF value. Anything convertible
        to Tensor via `tf.constant` can be passed in.
      context_stack: The context stack to use.

    Returns:
      An instance of `value_base.Value`.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    serialized, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(const), None, context_stack)
    no_arg_call = building_blocks.Call(
        building_blocks.CompiledComputation(serialized))
    return ValueImpl(no_arg_call, context_stack)
Ejemplo n.º 23
0
 def __call__(self, *args, **kwargs):
   """Invokes this functional value and returns the call as a `ValueImpl`.

   Raises:
     SyntaxError: If the underlying computation is not of `FunctionType`.
   """
   fn_type = self._comp.type_signature
   if not isinstance(fn_type, computation_types.FunctionType):
     raise SyntaxError(
         'Function-like invocation is only supported for values of functional '
         'types, but the value being invoked is of type {} that does not '
         'support invocation.'.format(fn_type))
   arg = None
   if args or kwargs:
     stack = self._context_stack
     positional = [to_value(x, None, stack) for x in args]
     named = {k: to_value(v, None, stack) for k, v in kwargs.items()}
     packed = function_utils.pack_args(fn_type.parameter, positional, named,
                                       stack.current)
     arg = ValueImpl.get_comp(to_value(packed, None, stack))
   return ValueImpl(building_blocks.Call(self._comp, arg), self._context_stack)
Ejemplo n.º 24
0
 def test_with_structure_replacing_federated_map(self):
   """Removes the block and lambda from `let arg=<fn, a> in arg[0](arg[1])`."""
   arg_ref = building_blocks.Reference('arg', [
       computation_types.FunctionType(tf.int32, tf.int32),
       tf.int32,
   ])
   called_fn = building_blocks.Call(
       building_blocks.Selection(arg_ref, index=0),
       building_blocks.Selection(arg_ref, index=1))
   identity = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Reference('x', tf.int32))
   bound_arg = building_blocks.Tuple(
       [identity, building_blocks.Data('a', tf.int32)])
   structure = building_blocks.Block([('arg', bound_arg)], called_fn)

   transformed, modified = compiler_transformations.remove_lambdas_and_blocks(
       structure)

   self.assertTrue(modified)
   self.assertNoLambdasOrBlocks(transformed)
Ejemplo n.º 25
0
 def test_returns_single_called_graph_with_selection_in_result(self):
   """A selection from a bound TF call folds into one TF call on `var`."""
   var_ref = building_blocks.Reference('var', [tf.int32, tf.int32])
   tf_identity = building_block_factory.create_compiled_identity(
       var_ref.type_signature)
   tf_call = building_blocks.Call(tf_identity, var_ref)
   call_ref = building_blocks.Reference('call', tf_call.type_signature)
   block = building_blocks.Block([('call', tf_call)],
                                 building_blocks.Selection(call_ref, index=0))

   result, _ = compiler_transformations.create_tensorflow_representing_block(
       block)

   self.assertEqual(result.type_signature, block.type_signature)
   self.assertIsInstance(result, building_blocks.Call)
   self.assertIsInstance(result.function, building_blocks.CompiledComputation)
   # The free variable `var` is expected to become the call argument.
   self.assertIsInstance(result.argument, building_blocks.Reference)
   self.assertEqual(result.argument.name, 'var')
Ejemplo n.º 26
0
    def test_replaces_lambda_to_called_tf_block_with_replicated_lambda_arg_with_tf_block_of_same_type(
            self):
        """A lambda passing its argument twice to a TF block parses to TF."""
        sum_plus_one = _create_compiled_computation(
            lambda x: x[0] + x[1] + 1, [tf.int32, tf.int32])
        x_ref = building_blocks.Reference('x', tf.int32)
        duplicated = building_blocks.Tuple((x_ref, x_ref))
        lambda_wrapper = building_blocks.Lambda(
            'x', tf.int32, building_blocks.Call(sum_plus_one, duplicated))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        run_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        run_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        # Both forms must produce the same output for the same input.
        self.assertEqual(run_lambda(17), run_tf(17))
Ejemplo n.º 27
0
def _create_before_and_after_broadcast_for_no_broadcast(tree):
    """Creates before- and after-broadcast computations for the given `tree`.

    This function is intended to be used by
    `get_canonical_form_for_iterative_process` to create before- and
    after-broadcast computations for the given `tree` when there is no
    `intrinsic_defs.FEDERATED_BROADCAST` in `tree`.

    NOTE: This function does not assert that there is no
    `intrinsic_defs.FEDERATED_BROADCAST` in `tree`; the caller is expected to
    perform this check before calling this function.

    Args:
      tree: An instance of `building_blocks.ComputationBuildingBlock`.

    Returns:
      A pair of the form `(before, after)`, where each of `before` and `after`
      is a `tff_framework.ComputationBuildingBlock` that represents a part of
      the result as specified by
      `transformations.force_align_and_split_by_intrinsics`.
    """
    names = building_block_factory.unique_name_generator(tree)
    tree_param_type = tree.type_signature.parameter

    # `before` ignores its parameter and yields an empty tuple at SERVER.
    before_broadcast = building_blocks.Lambda(
        next(names), tree_param_type,
        building_block_factory.create_federated_value(
            building_blocks.Tuple([]), placements.SERVER))

    # `after` takes <original parameter, CLIENTS-placed broadcast value> and
    # calls `tree` on element 0 only.
    broadcast_type = computation_types.FederatedType(
        before_broadcast.type_signature.result.member, placements.CLIENTS)
    after_param = building_blocks.Reference(
        next(names),
        computation_types.NamedTupleType([tree_param_type, broadcast_type]))
    tree_call = building_blocks.Call(
        tree, building_blocks.Selection(after_param, index=0))
    after_broadcast = building_blocks.Lambda(after_param.name,
                                             after_param.type_signature,
                                             tree_call)

    return before_broadcast, after_broadcast
Ejemplo n.º 28
0
def _wrap_computation_as_value(
    proto: pb.Computation,
    context_stack: context_stack_base.ContextStack) -> value_base.Value:
  """Wraps the given computation as a `tff.Value`.

  The computation is called with no argument. That call is bound to a
  reference in `context_stack.current`, and the reference is wrapped.

  Args:
    proto: A pb.Computation.
    context_stack: The context stack to use.

  Returns:
    A `value_base.Value`.
  """
  py_typecheck.check_type(proto, pb.Computation)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  no_arg_call = building_blocks.Call(building_blocks.CompiledComputation(proto))
  current_context = context_stack.current
  bound_ref = current_context.bind_computation_to_reference(no_arg_call)
  return ValueImpl(bound_ref, context_stack)
Ejemplo n.º 29
0
def _create_dummy_before_and_after_broadcast(comp):
    """Creates before- and after-broadcast computations for the given `comp`.

    This function is intended to be used instead of
    `transformations.force_align_and_split_by_intrinsic` to generate dummy
    before and after computations when there is no
    `intrinsic_defs.FEDERATED_BROADCAST` present in `comp`.

    Note: This function does not assert that there is no
    `intrinsic_defs.FEDERATED_BROADCAST` present in `comp`; the caller is
    expected to perform this check before calling this function.

    Args:
      comp: An instance of `building_blocks.ComputationBuildingBlock` with a
        function type (its `type_signature.parameter` is read).

    Returns:
      A pair of the form `(before, after)`, where each of `before` and `after`
      is a `tff_framework.ComputationBuildingBlock` that represents a part of
      the result as specified by
      `transformations.force_align_and_split_by_intrinsic`.
    """
    name_generator = building_block_factory.unique_name_generator(comp)

    # `before` ignores its parameter and returns an empty tuple at SERVER.
    parameter_name = next(name_generator)
    empty_tuple = building_blocks.Tuple([])
    federated_value_at_server = building_block_factory.create_federated_value(
        empty_tuple, placements.SERVER)
    before_broadcast = building_blocks.Lambda(parameter_name,
                                              comp.type_signature.parameter,
                                              federated_value_at_server)

    # `after` takes <original parameter, CLIENTS-placed broadcast value> and
    # calls `comp` on element 0 only.
    parameter_name = next(name_generator)
    type_signature = computation_types.FederatedType(
        before_broadcast.type_signature.result.member, placements.CLIENTS)
    parameter_type = computation_types.NamedTupleType(
        [comp.type_signature.parameter, type_signature])
    ref = building_blocks.Reference(parameter_name, parameter_type)
    arg = building_blocks.Selection(ref, index=0)
    call = building_blocks.Call(comp, arg)
    after_broadcast = building_blocks.Lambda(ref.name, ref.type_signature,
                                             call)

    return before_broadcast, after_broadcast
    def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
        """A lambda over a called TF identity parses to an equivalent TF block."""
        identity_tf = building_block_factory.create_compiled_identity(
            computation_types.TensorType(tf.int32))
        x_ref = building_blocks.Reference('x', tf.int32)
        lambda_wrapper = building_blocks.Lambda(
            'x', tf.int32, building_blocks.Call(identity_tf, x_ref))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        run_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        run_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)
        self.assertEqual(run_lambda(2), run_tf(2))