def test_scope_snapshot_called_lambdas(self):
     comp = computation_building_blocks.Tuple(
         [computation_building_blocks.Data('test', tf.int32)])
     input1 = computation_building_blocks.Reference('input1',
                                                    comp.type_signature)
     first_level_call = computation_building_blocks.Call(
         computation_building_blocks.Lambda('input1', input1.type_signature,
                                            input1), comp)
     input2 = computation_building_blocks.Reference(
         'input2', first_level_call.type_signature)
     second_level_call = computation_building_blocks.Call(
         computation_building_blocks.Lambda('input2', input2.type_signature,
                                            input2), first_level_call)
     self.assertEqual(str(second_level_call),
                      '(input2 -> input2)((input1 -> input1)(<test>))')
     global_snapshot = transformations.scope_count_snapshot(
         second_level_call)
     self.assertEqual(
         global_snapshot, {
             '(input2 -> input2)': {
                 'input2': 1
             },
             '(input1 -> input1)': {
                 'input1': 1
             }
         })
  def test_execute_with_nested_lambda(self):
    """Partially applies a curried int32 add to 10 and checks that 5 maps to 15."""
    add_proto = computation_impl.ComputationImpl.get_proto(
        computations.tf_computation(tf.add, [tf.int32, tf.int32]))
    int32_add = computation_building_blocks.ComputationBuildingBlock.from_proto(
        add_proto)

    # Curried form: x -> (y -> int32_add(<x, y>)).
    x_ref = computation_building_blocks.Reference('x', tf.int32)
    y_ref = computation_building_blocks.Reference('y', tf.int32)
    add_x_y = computation_building_blocks.Call(
        int32_add,
        computation_building_blocks.Tuple([(None, x_ref), (None, y_ref)]))
    inner_lambda = computation_building_blocks.Lambda('y', tf.int32, add_x_y)
    curried_int32_add = computation_building_blocks.Lambda(
        'x', tf.int32, inner_lambda)

    make_10_proto = computation_impl.ComputationImpl.get_proto(
        computations.tf_computation(lambda: tf.constant(10)))
    make_10 = computation_building_blocks.ComputationBuildingBlock.from_proto(
        make_10_proto)

    # Binding `x` to the result of make_10() leaves a function of `y`.
    add_10 = computation_building_blocks.Call(
        curried_int32_add, computation_building_blocks.Call(make_10))
    add_10_computation = computation_impl.ComputationImpl(
        add_10.proto, context_stack_impl.context_stack)

    self.assertEqual(add_10_computation(5), 15)
 def test_nested_lambda_block_overwrite_scope_snapshot(self):
     """Shadowed `x` bindings across blocks and lambdas get separate scopes.

     Builds `(let x=test_data in (block_in -> (let x=block_in in
     (x -> x)(x)))(x))` and checks that each of the four scopes counts one
     reference to its own bound name.
     """
     # (x -> x)(x), where the argument `x` is bound by the inner block.
     inner_lambda = computation_building_blocks.Lambda(
         'x', tf.int32, computation_building_blocks.Reference('x', tf.int32))
     called_lambda = computation_building_blocks.Call(
         inner_lambda, computation_building_blocks.Reference('x', tf.int32))
     # let x=block_in in (x -> x)(x)
     lower_block = computation_building_blocks.Block(
         [('x', computation_building_blocks.Reference('block_in', tf.int32))],
         called_lambda)
     # (block_in -> ...)(x), where this `x` is bound by the outermost block.
     second_lambda = computation_building_blocks.Lambda(
         'block_in', tf.int32, lower_block)
     second_call = computation_building_blocks.Call(
         second_lambda, computation_building_blocks.Reference('x', tf.int32))
     last_block = computation_building_blocks.Block(
         [('x', computation_building_blocks.Data('test_data', tf.int32))],
         second_call)

     snapshot = transformations.scope_count_snapshot(last_block)
     self.assertEqual(
         str(last_block),
         '(let x=test_data in (block_in -> (let x=block_in in (x -> x)(x)))(x))'
     )
     self.assertLen(snapshot, 4)
     expected_counts = [
         (inner_lambda, {'x': 1}),
         (lower_block, {'x': 1}),
         (second_lambda, {'block_in': 1}),
         (last_block, {'x': 1}),
     ]
     for scope, counts in expected_counts:
         self.assertEqual(snapshot[str(scope)], counts)
def create_dummy_called_federated_aggregate(accumulate_parameter_name,
                                            merge_parameter_name,
                                            report_parameter_name):
    r"""Builds a called `federated_aggregate` whose operands are dummy data.

                      Call
                     /    \
    federated_aggregate    Tuple
                           |
                           [data, data, Lambda(x), Lambda(x), Lambda(x)]
                                        |          |          |
                                        data       data       data

    The value is `int32@CLIENTS` data and the zero is `float32` data. Each of
    accumulate, merge and report is a lambda that ignores its parameter and
    returns a `data` placeholder.

    Args:
      accumulate_parameter_name: Parameter name for the accumulate lambda.
      merge_parameter_name: Parameter name for the merge lambda.
      report_parameter_name: Parameter name for the report lambda.

    Returns:
      Whatever `computation_constructing_utils.create_federated_aggregate`
      builds from these dummy operands.
    """

    def _constant_lambda(parameter_name, parameter_type, result_dtype):
        # A lambda whose body is a `data` placeholder of `result_dtype`.
        return computation_building_blocks.Lambda(
            parameter_name, parameter_type,
            computation_building_blocks.Data('data', result_dtype))

    value = computation_building_blocks.Data(
        'data', computation_types.FederatedType(tf.int32, placements.CLIENTS))
    zero = computation_building_blocks.Data('data', tf.float32)
    accumulate = _constant_lambda(
        accumulate_parameter_name,
        computation_types.NamedTupleType((tf.float32, tf.int32)), tf.float32)
    merge = _constant_lambda(
        merge_parameter_name,
        computation_types.NamedTupleType((tf.float32, tf.float32)), tf.float32)
    report = _constant_lambda(report_parameter_name, tf.float32, tf.bool)
    return computation_constructing_utils.create_federated_aggregate(
        value, zero, accumulate, merge, report)
    def sequence_reduce(self, value, zero, op):
        """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

        Args:
          value: As in `api/intrinsics.py`.
          zero: As in `api/intrinsics.py`.
          op: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`. Both the plain-sequence and the
          federated-sequence paths return a Value, never a raw building block.

        Raises:
          TypeError: As in `api/intrinsics.py`.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        zero = value_impl.to_value(zero, None, self._context_stack)
        op = value_impl.to_value(op, None, self._context_stack)
        if isinstance(value.type_signature, computation_types.SequenceType):
            element_type = value.type_signature.element
        else:
            py_typecheck.check_type(value.type_signature,
                                    computation_types.FederatedType)
            py_typecheck.check_type(value.type_signature.member,
                                    computation_types.SequenceType)
            element_type = value.type_signature.member.element
        op_type_expected = type_constructors.reduction_op(
            zero.type_signature, element_type)
        if not type_utils.is_assignable_from(op_type_expected,
                                             op.type_signature):
            raise TypeError('Expected an operator of type {}, got {}.'.format(
                op_type_expected, op.type_signature))

        value = value_impl.ValueImpl.get_comp(value)
        zero = value_impl.ValueImpl.get_comp(zero)
        op = value_impl.ValueImpl.get_comp(op)
        if isinstance(value.type_signature, computation_types.SequenceType):
            comp = computation_constructing_utils.create_sequence_reduce(
                value, zero, op)
            # Wrap the building block in a `ValueImpl`, as the federated branch
            # below does for its mapping function, so callers always get a
            # Value back instead of a bare building block.
            return value_impl.ValueImpl(comp, self._context_stack)
        else:
            # Federated sequence: reduce each member sequence by mapping a
            # lambda `arg -> SEQUENCE_REDUCE(<arg, zero, op>)` over the value.
            value_type = computation_types.SequenceType(element_type)
            intrinsic_type = computation_types.FunctionType((
                value_type,
                zero.type_signature,
                op.type_signature,
            ), op.type_signature.result)
            intrinsic = computation_building_blocks.Intrinsic(
                intrinsic_defs.SEQUENCE_REDUCE.uri, intrinsic_type)
            ref = computation_building_blocks.Reference('arg', value_type)
            tup = computation_building_blocks.Tuple((ref, zero, op))
            call = computation_building_blocks.Call(intrinsic, tup)
            fn = computation_building_blocks.Lambda(ref.name,
                                                    ref.type_signature, call)
            fn_impl = value_impl.ValueImpl(fn, self._context_stack)
            if value.type_signature.placement is placements.SERVER:
                return self.federated_apply(fn_impl, value)
            elif value.type_signature.placement is placements.CLIENTS:
                return self.federated_map(fn_impl, value)
            else:
                raise TypeError('Unsupported placement {}.'.format(
                    value.type_signature.placement))
 def test_propogates_dependence_up_through_lambda(self):
   """A lambda whose body is the dummy intrinsic is reported as consuming it."""
   intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                     tf.int32)
   wrapping_lambda = computation_building_blocks.Lambda('x', tf.int32,
                                                        intrinsic)
   consumers = tree_analysis.extract_nodes_consuming(wrapping_lambda,
                                                     dummy_intrinsic_predicate)
   self.assertIn(wrapping_lambda, consumers)
Beispiel #7
0
    def test_replace_chained_federated_maps_replaces_multiple_federated_maps(
            self):
        """Ten chained add-one federated maps collapse into a single map."""
        calling_arg = computation_building_blocks.Reference(
            'arg',
            computation_types.FederatedType(tf.int32, placements.CLIENTS))
        chain = calling_arg
        member_type = calling_arg.type_signature.member
        for _ in range(10):
            add_one = _create_lambda_to_add_one(member_type)
            chain = _create_call_to_federated_map(add_one, chain)
            member_type = chain.function.type_signature.result.member
        comp = computation_building_blocks.Lambda(
            calling_arg.name, calling_arg.type_signature, chain)
        uri = intrinsic_defs.FEDERATED_MAP.uri

        self.assertEqual(_get_number_of_intrinsics(comp, uri), 10)
        self.assertEqual(_to_comp(comp)([1]), [11])

        transformed_comp = (
            transformations.replace_chained_federated_maps_with_federated_map(
                comp))

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 1)
        self.assertEqual(_to_comp(transformed_comp)([1]), [11])
Beispiel #8
0
    def test_replace_chained_federated_maps_with_different_arg_types(self):
        """A cast-to-float map followed by an add-one map fuses into one map."""
        map_arg = computation_building_blocks.Reference(
            'arg_1',
            computation_types.FederatedType(tf.int32, placements.CLIENTS))
        cast_call = _create_call_to_federated_map(
            _create_lambda_to_cast(tf.int32, tf.float32), map_arg)
        add_one_call = _create_call_to_federated_map(
            _create_lambda_to_add_one(cast_call.type_signature.member),
            cast_call)
        comp = computation_building_blocks.Lambda(
            map_arg.name, map_arg.type_signature, add_one_call)
        map_uri = intrinsic_defs.FEDERATED_MAP.uri

        self.assertEqual(_get_number_of_intrinsics(comp, map_uri), 2)
        self.assertEqual(_to_comp(comp)([1]), [2.0])

        transformed_comp = (
            transformations.replace_chained_federated_maps_with_federated_map(
                comp))

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, map_uri),
                         1)
        self.assertEqual(_to_comp(transformed_comp)([1]), [2.0])
Beispiel #9
0
    def test_replace_called_lambda_replaces_called_lambda(self):
        """The called add-one lambda becomes a block with the same result."""
        arg = computation_building_blocks.Reference('arg', tf.int32)
        called_add_one = computation_building_blocks.Call(
            _create_lambda_to_add_one(arg.type_signature), arg)
        comp = computation_building_blocks.Lambda(
            arg.name, arg.type_signature, called_add_one)
        block_type = computation_building_blocks.Block

        self.assertEqual(_get_number_of_computations(comp, block_type), 0)
        self.assertEqual(_to_comp(comp)(1), 2)

        transformed_comp = transformations.replace_called_lambda_with_block(
            comp)

        # One Call and one Lambda are consumed to produce exactly one Block.
        for node_type in (computation_building_blocks.Call,
                          computation_building_blocks.Lambda):
            self.assertEqual(
                _get_number_of_computations(transformed_comp, node_type),
                _get_number_of_computations(comp, node_type) - 1)
        self.assertEqual(
            _get_number_of_computations(transformed_comp, block_type), 1)
        self.assertEqual(_to_comp(transformed_comp)(1), 2)
Beispiel #10
0
    def test_replace_intrinsic_replaces_multiple_intrinsics(self):
        """Every GENERIC_PLUS in a chain of ten add-ones is swapped out."""
        calling_arg = computation_building_blocks.Reference('arg', tf.int32)
        chain = calling_arg
        chain_type = calling_arg.type_signature
        for _ in range(10):
            chain = computation_building_blocks.Call(
                _create_lambda_to_add_one(chain_type), chain)
            chain_type = chain.function.type_signature.result
        comp = computation_building_blocks.Lambda(
            calling_arg.name, calling_arg.type_signature, chain)
        uri = intrinsic_defs.GENERIC_PLUS.uri

        def replacement_body(x):  # pylint: disable=unused-argument
            return 100

        self.assertEqual(_get_number_of_intrinsics(comp, uri), 10)
        self.assertEqual(_to_comp(comp)(1), 11)

        transformed_comp = transformations.replace_intrinsic_with_callable(
            comp, uri, replacement_body, context_stack_impl.context_stack)

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 0)
        self.assertEqual(_to_comp(transformed_comp)(1), 100)
Beispiel #11
0
def _create_lambda_to_add_one(dtype):
    r"""Builds `arg -> GENERIC_PLUS(<arg, 1>)` for a tensor dtype.

    Lambda
          \
           Call
          /    \
    Intrinsic   Tuple
               /     \
      Reference       Computation

    Args:
      dtype: A `tf.dtypes.DType`, or a `computation_types.TensorType` whose
        dtype is used.

    Returns:
      A `computation_building_blocks.Lambda` that adds 1 (cast to `dtype`) to
      its argument.
    """
    if isinstance(dtype, computation_types.TensorType):
        dtype = dtype.dtype
    py_typecheck.check_type(dtype, tf.dtypes.DType)
    plus = computation_building_blocks.Intrinsic(
        intrinsic_defs.GENERIC_PLUS.uri,
        computation_types.FunctionType([dtype, dtype], dtype))
    ref = computation_building_blocks.Reference('arg', dtype)
    one = _create_call_to_py_fn(lambda: tf.cast(tf.constant(1), dtype))
    body = computation_building_blocks.Call(
        plus, computation_building_blocks.Tuple([ref, one]))
    return computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                              body)
  def sequence_reduce(self, value, zero, op):
    """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

    A plain sequence is reduced by calling the SEQUENCE_REDUCE intrinsic
    directly. A federated sequence is reduced by mapping (CLIENTS) or applying
    (SERVER) a lambda that calls the intrinsic on each member.

    Args:
      value: As in `api/intrinsics.py`.
      zero: As in `api/intrinsics.py`.
      op: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    value, zero, op = (
        value_impl.to_value(operand, None, self._context_stack)
        for operand in (value, zero, op))
    is_plain_sequence = isinstance(value.type_signature,
                                   computation_types.SequenceType)
    if is_plain_sequence:
      element_type = value.type_signature.element
    else:
      py_typecheck.check_type(value.type_signature,
                              computation_types.FederatedType)
      py_typecheck.check_type(value.type_signature.member,
                              computation_types.SequenceType)
      element_type = value.type_signature.member.element

    expected_op_type = type_constructors.reduction_op(zero.type_signature,
                                                      element_type)
    if not type_utils.is_assignable_from(expected_op_type, op.type_signature):
      raise TypeError('Expected an operator of type {}, got {}.'.format(
          str(expected_op_type), str(op.type_signature)))

    sequence_type = computation_types.SequenceType(element_type)
    reduce_intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.SEQUENCE_REDUCE.uri,
        computation_types.FunctionType(
            [sequence_type, zero.type_signature, op.type_signature],
            zero.type_signature))
    if is_plain_sequence:
      reduce_value = value_impl.ValueImpl(reduce_intrinsic,
                                          self._context_stack)
      return reduce_value(value, zero, op)

    # Per-member reducer: arg -> SEQUENCE_REDUCE(<arg, zero, op>).
    member_ref = computation_building_blocks.Reference('arg', sequence_type)
    reduce_args = computation_building_blocks.Tuple([
        member_ref,
        value_impl.ValueImpl.get_comp(zero),
        value_impl.ValueImpl.get_comp(op),
    ])
    reducer = computation_building_blocks.Lambda(
        'arg', sequence_type,
        computation_building_blocks.Call(reduce_intrinsic, reduce_args))
    reducer_value = value_impl.ValueImpl(reducer, self._context_stack)

    placement = value.type_signature.placement
    if placement is placements.SERVER:
      return self.federated_apply(reducer_value, value)
    if placement is placements.CLIENTS:
      return self.federated_map(reducer_value, value)
    raise TypeError('Unsupported placement {}.'.format(str(placement)))
    def test_with_block(self):
        """Evaluates `(a -> let f=a.f, x=a.x in f(f(x)))` with f=add_one, x=10."""
        executor = lambda_executor.LambdaExecutor(
            eager_executor.EagerExecutor())
        run = asyncio.get_event_loop().run_until_complete

        fn_type = computation_types.FunctionType(tf.int32, tf.int32)
        param = computation_building_blocks.Reference(
            'a',
            computation_types.NamedTupleType([('f', fn_type), ('x', tf.int32)]))
        block_locals = [
            ('f', computation_building_blocks.Selection(param, name='f')),
            ('x', computation_building_blocks.Selection(param, name='x')),
        ]
        # f(f(x)), with `f` and `x` resolved against the block's locals.
        twice_applied = computation_building_blocks.Call(
            computation_building_blocks.Reference('f', fn_type),
            computation_building_blocks.Call(
                computation_building_blocks.Reference('f', fn_type),
                computation_building_blocks.Reference('x', tf.int32)))
        block = computation_building_blocks.Block(block_locals, twice_applied)
        comp = computation_building_blocks.Lambda(param.name,
                                                  param.type_signature, block)

        @computations.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        comp_value = run(
            executor.create_value(comp.proto, comp.type_signature))
        fn_value = run(executor.create_value(add_one))
        x_value = run(executor.create_value(10, tf.int32))
        arg_value = run(
            executor.create_tuple(
                anonymous_tuple.AnonymousTuple([('f', fn_value),
                                                ('x', x_value)])))
        call_value = run(executor.create_call(comp_value, arg_value))
        self.assertEqual(run(call_value.compute()).numpy(), 12)
 def test_basic_functionality_of_lambda_class(self):
     """Checks the type, accessors, reprs and proto form of a simple Lambda."""
     param_name = 'arg'
     param_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                   ('x', tf.int32)]
     param_ref = computation_building_blocks.Reference(param_name, param_type)
     select_f = computation_building_blocks.Selection(param_ref, name='f')
     select_x = computation_building_blocks.Selection(param_ref, name='x')
     body = computation_building_blocks.Call(
         select_f, computation_building_blocks.Call(select_f, select_x))
     lam = computation_building_blocks.Lambda(param_name, param_type, body)

     self.assertEqual(str(lam.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(lam.parameter_name, param_name)
     self.assertEqual(str(lam.parameter_type), '<f=(int32 -> int32),x=int32>')
     self.assertEqual(lam.result.tff_repr, 'arg.f(arg.f(arg.x))')

     type_repr = (
         'NamedTupleType(['
         '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
         '(\'x\', TensorType(tf.int32))])')
     expected_repr = (
         'Lambda(\'arg\', {0}, '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Selection(Reference(\'arg\', {0}), name=\'x\'))))').format(type_repr)
     self.assertEqual(repr(lam), expected_repr)
     self.assertEqual(lam.tff_repr, '(arg -> arg.f(arg.f(arg.x)))')

     lam_proto = lam.proto
     self.assertEqual(type_serialization.deserialize_type(lam_proto.type),
                      lam.type_signature)
     self.assertEqual(lam_proto.WhichOneof('computation'), 'lambda')
     # `lambda` is a Python keyword, so the proto field needs getattr.
     lambda_proto = getattr(lam_proto, 'lambda')
     self.assertEqual(lambda_proto.parameter_name, param_name)
     self.assertEqual(str(lambda_proto.result), str(lam.result.proto))
     self._serialize_deserialize_roundtrip_test(lam)
def _create_lambda_to_dummy_intrinsic(type_spec, uri='dummy'):
    r"""Builds `arg -> <intrinsic uri>(arg)` with type `type_spec -> type_spec`.

    Lambda
          \
           Call
          /    \
    Intrinsic   Ref(arg)

    Args:
      type_spec: The dtype of the argument, and of the intrinsic's parameter
        and result.
      uri: The URI given to the intrinsic.

    Returns:
      A `computation_building_blocks.Lambda`.

    Raises:
      TypeError: If `type_spec` is not a `tf.dtypes.DType`.
    """
    py_typecheck.check_type(type_spec, tf.dtypes.DType)
    param = computation_building_blocks.Reference('arg', type_spec)
    dummy = computation_building_blocks.Intrinsic(
        uri, computation_types.FunctionType(type_spec, type_spec))
    return computation_building_blocks.Lambda(
        param.name, param.type_signature,
        computation_building_blocks.Call(dummy, param))
Beispiel #16
0
def get_curried(func):
  """Returns a curried version of function `func` that takes a parameter tuple.

  For functions `func` of types <T1,T2,....,Tn> -> U, the result is a function
  of the form T1 -> (T2 -> (T3 -> .... (Tn -> U) ... )).

  NOTE: No attempt is made at avoiding naming conflicts in cases where `func`
  contains references. The arguments of the curried function are named `argN`
  with `N` starting at 0. If the parameter tuple is empty, no lambdas are
  added and the result is `func` called on an empty tuple.

  Args:
    func: A value of a functional TFF type whose parameter is a named tuple.

  Returns:
    A value that represents the curried form of `func`, bound to the same
    context stack as `func`.
  """
  py_typecheck.check_type(func, value_base.Value)
  py_typecheck.check_type(func.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(func.type_signature.parameter,
                          computation_types.NamedTupleType)
  refs = [
      computation_building_blocks.Reference('arg{}'.format(i), elem_type)
      for i, (_, elem_type) in enumerate(
          anonymous_tuple.to_elements(func.type_signature.parameter))
  ]
  curried = computation_building_blocks.Call(
      value_impl.ValueImpl.get_comp(func),
      computation_building_blocks.Tuple(refs))
  # Wrap from the last parameter outwards so `arg0` ends up outermost.
  for ref in reversed(refs):
    curried = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                                 curried)
  return value_impl.ValueImpl(curried,
                              value_impl.ValueImpl.get_context_stack(func))
def construct_federated_getattr_comp(comp, name):
  """Builds the function that `federated_apply` uses to implement `__getattr__`.

  The result is `x -> x.<name>`, where `x` has type
  `comp.type_signature.member`, a `computation_types.NamedTupleType`.

  Args:
    comp: Instance of `ValueImpl` or
      `computation_building_blocks.ComputationBuildingBlock` with type signature
      `computation_types.FederatedType` whose `member` attribute is of type
      `computation_types.NamedTupleType`.
    name: String name of the attribute to select.

  Returns:
    Instance of `computation_building_blocks.Lambda` that selects `name` from
    its argument.

  Raises:
    TypeError: If `comp` is not federated or its member is not a named tuple.
    ValueError: If the member tuple has no element called `name`.
  """
  py_typecheck.check_type(comp.type_signature, computation_types.FederatedType)
  member_type = comp.type_signature.member
  py_typecheck.check_type(member_type, computation_types.NamedTupleType)
  known_names = [
      elem_name for elem_name, _ in anonymous_tuple.to_elements(member_type)
  ]
  if name not in known_names:
    raise ValueError('The federated value {} has no element of name {}'.format(
        comp, name))
  param = computation_building_blocks.Reference('x', member_type)
  return computation_building_blocks.Lambda(
      'x', param.type_signature,
      computation_building_blocks.Selection(param, name=name))
 def test_returns_string_for_comp_with_right_overhang(self):
     """Compact, formatted and structural reprs of a call whose tree leans right."""
     a_ref = computation_building_blocks.Reference('a', tf.int32)
     data = computation_building_blocks.Data('data', tf.int32)
     wide_tuple = computation_building_blocks.Tuple([a_ref] + [data] * 4)
     first = computation_building_blocks.Selection(wide_tuple, index=0)
     lam = computation_building_blocks.Lambda(a_ref.name,
                                              a_ref.type_signature, first)
     comp = computation_building_blocks.Call(lam, data)

     self.assertEqual(comp.compact_representation(),
                      '(a -> <a,data,data,data,data>[0])(data)')
     # pyformat: disable
     self.assertEqual(
         comp.formatted_representation(),
         '(a -> <\n'
         '  a,\n'
         '  data,\n'
         '  data,\n'
         '  data,\n'
         '  data\n'
         '>[0])(data)')
     # pyformat: enable
     # pyformat: disable
     self.assertEqual(
         comp.structural_representation(),
         '          Call\n'
         '         /    \\\n'
         'Lambda(a)      data\n'
         '|\n'
         'Sel(0)\n'
         '|\n'
         'Tuple\n'
         '|\n'
         '[Ref(a), data, data, data, data]')
     # pyformat: enable
Beispiel #19
0
    def test_replace_chained_federated_maps_does_not_replace_unchained_federated_maps(
            self):
        """Maps separated by a tuple/selection are left as two federated maps."""
        map_arg = computation_building_blocks.Reference(
            'arg',
            computation_types.FederatedType(tf.int32, placements.CLIENTS))
        inner_call = _create_call_to_federated_map(
            _create_lambda_to_add_one(map_arg.type_signature.member), map_arg)
        # Routing the inner result through `<inner_call>[0]` means the outer map
        # no longer consumes the inner map directly.
        separated = computation_building_blocks.Selection(
            computation_building_blocks.Tuple([inner_call]), index=0)
        outer_call = _create_call_to_federated_map(
            _create_lambda_to_add_one(
                inner_call.function.type_signature.result.member), separated)
        comp = computation_building_blocks.Lambda(
            map_arg.name, map_arg.type_signature, outer_call)
        uri = intrinsic_defs.FEDERATED_MAP.uri

        self.assertEqual(_get_number_of_intrinsics(comp, uri), 2)
        self.assertEqual(_to_comp(comp)([1]), [3])

        transformed_comp = (
            transformations.replace_chained_federated_maps_with_federated_map(
                comp))

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 2)
        self.assertEqual(_to_comp(transformed_comp)([1]), [3])
Beispiel #20
0
  def test_flatten_fn_with_names(self, n):
    """Flattening appends an unnamed or named element to an n-element tuple."""
    base_elements = [(str(k), tf.int32) for k in range(n)]
    input_reference = computation_building_blocks.Reference(
        'test', base_elements)
    input_fn = computation_building_blocks.Lambda(
        'test', input_reference.type_signature, input_reference)

    # Each case: (type passed to flatten_first_index, element it should add to
    # the output tuple).
    cases = [
        ((None, computation_types.to_type(tf.int32)), tf.int32),
        (('new', tf.int32), ('new', tf.int32)),
    ]
    for type_to_add, extra_output_element in cases:
      input_type = computation_types.NamedTupleType(
          [input_reference.type_signature, type_to_add])
      desired_fn_type = computation_types.FunctionType(
          input_type,
          computation_types.to_type(base_elements + [extra_output_element]))
      flattened_fn = value_utils.flatten_first_index(
          value_impl.to_value(input_fn, None, _context_stack), type_to_add,
          _context_stack)
      self.assertEqual(
          str(flattened_fn.type_signature), str(desired_fn_type))
 def test_raises_type_error_with_nonfederated_arg(self):
     """`create_federated_map` rejects an argument of non-federated type."""
     identity_param = computation_building_blocks.Reference('x', tf.int32)
     identity_fn = computation_building_blocks.Lambda(
         identity_param.name, identity_param.type_signature, identity_param)
     plain_arg = computation_building_blocks.Data('y', tf.int32)
     with self.assertRaises(TypeError):
         computation_constructing_utils.create_federated_map(identity_fn,
                                                             plain_arg)
Beispiel #22
0
 def test_raises_type_error_with_none_accumulate(self):
     """`create_federated_aggregate` rejects `None` as the accumulate function."""
     value = computation_building_blocks.Data(
         'v',
         computation_types.FederatedType(tf.int32, placements.CLIENTS, False))
     zero = computation_building_blocks.Data('z', tf.int32)
     merge = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('m', tf.int32))
     report_param = computation_building_blocks.Reference('r', tf.int32)
     report = computation_building_blocks.Lambda(
         report_param.name, report_param.type_signature, report_param)
     with self.assertRaises(TypeError):
         computation_constructing_utils.create_federated_aggregate(
             value, zero, None, merge, report)
Beispiel #23
0
 def test_flatten_fn_comp_raises_typeerror(self):
   """`flatten_first_index` rejects a raw building block in place of a Value.

   Uses `assertRaisesRegex`: the older `assertRaisesRegexp` alias was
   deprecated in Python 3.2 and removed from `unittest` in Python 3.12.
   """
   input_reference = computation_building_blocks.Reference(
       'test', [tf.int32] * 5)
   input_fn = computation_building_blocks.Lambda(
       'test', input_reference.type_signature, input_reference)
   type_to_add = computation_types.NamedTupleType([tf.int32])
   with self.assertRaisesRegex(TypeError, '(Expected).*(Value)'):
     _ = value_utils.flatten_first_index(input_fn, type_to_add, _context_stack)
Beispiel #24
0
 def test_raises_type_error_with_none_value(self):
     """`create_sequence_reduce` rejects `None` as the value to reduce."""
     zero = computation_building_blocks.Data('z', tf.int32)
     op = computation_building_blocks.Lambda(
         'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('o', tf.int32))
     with self.assertRaises(TypeError):
         computation_constructing_utils.create_sequence_reduce(None, zero, op)
Beispiel #25
0
 def test_identity_lambda_executes_as_identity(self):
     """The computation built from `x -> x` returns each int32 input unchanged."""
     identity = computation_building_blocks.Lambda(
         'x', tf.int32, computation_building_blocks.Reference('x', tf.int32))
     to_computation = (
         computation_wrapper_instances.building_block_to_computation)
     identity_comp = to_computation(identity)
     for value in range(10):
         self.assertEqual(identity_comp(value), value)
Beispiel #26
0
 def test_converts_building_block_to_computation(self):
     """`building_block_to_computation` produces a `ComputationImpl`."""
     identity = computation_building_blocks.Lambda(
         'x', tf.int32, computation_building_blocks.Reference('x', tf.int32))
     to_computation = (
         computation_wrapper_instances.building_block_to_computation)
     self.assertIsInstance(to_computation(identity),
                           computation_impl.ComputationImpl)
Beispiel #27
0
def construct_binary_operator_with_upcast(type_signature, operator):
  """Constructs lambda upcasting its argument and applying `operator`.

  The concept of upcasting is explained further in the docstring for
  `apply_binary_operator_with_upcast`.

  Notice that since we are constructing a function here, e.g. for the body
  of an intrinsic, the function we are constructing must be reducible to
  TensorFlow. Therefore `type_signature` can only have named tuple or tensor
  type elements; that is, we cannot handle federated types here in a generic
  way.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`,
      with two elements, both of the same type or the second able to be upcast
      to the first, as explained in `apply_binary_operator_with_upcast`, and
      both containing only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `computation_building_blocks.Lambda` encapsulating a function which
    upcasts the second element of its argument and applies the binary
    operator.

  Raises:
    TypeError: If upcasting is required and the type of the first element
      contains a node that is neither a named tuple nor a tensor. Other
      argument validation is delegated to `py_typecheck.check_callable` and
      `_check_generic_operator_type`.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  ref_to_arg = computation_building_blocks.Reference('binary_operator_arg',
                                                     type_signature)

  def _pack_into_type(to_pack, type_spec):
    """Broadcasts tensor value `to_pack` into the nested structure `type_spec`.

    Named tuples are rebuilt element by element (preserving names). Each tensor
    leaf is filled by calling a TensorFlow broadcast of `to_pack` to the leaf's
    shape. Any other kind of type raises `TypeError`. Previously it returned
    `None`, which only surfaced later as an obscure failure.
    """
    if isinstance(type_spec, computation_types.NamedTupleType):
      elems = anonymous_tuple.to_elements(type_spec)
      packed_elems = [(elem_name, _pack_into_type(to_pack, elem_type))
                      for elem_name, elem_type in elems]
      return computation_building_blocks.Tuple(packed_elems)
    elif isinstance(type_spec, computation_types.TensorType):
      expand_fn = computation_constructing_utils.construct_tensorflow_to_broadcast_scalar(
          to_pack.type_signature.dtype, type_spec.shape)
      return computation_building_blocks.Call(expand_fn, to_pack)
    raise TypeError(
        'Cannot upcast into type {}; only named tuples and tensors are '
        'supported.'.format(type_spec))

  # The lambda's single argument is a 2-tuple: element 0 fixes the result
  # structure, and element 1 may need to be broadcast to match it.
  y_ref = computation_building_blocks.Selection(ref_to_arg, index=1)
  first_arg = computation_building_blocks.Selection(ref_to_arg, index=0)

  if type_utils.are_equivalent_types(first_arg.type_signature,
                                     y_ref.type_signature):
    # Types already agree: no upcasting needed.
    second_arg = y_ref
  else:
    second_arg = _pack_into_type(y_ref, first_arg.type_signature)

  fn = computation_constructing_utils.construct_tensorflow_binary_operator(
      first_arg.type_signature, operator)
  packed = computation_building_blocks.Tuple([first_arg, second_arg])
  operated = computation_building_blocks.Call(fn, packed)
  lambda_encapsulating_op = computation_building_blocks.Lambda(
      ref_to_arg.name, ref_to_arg.type_signature, operated)
  return lambda_encapsulating_op
Beispiel #28
0
 def test_returns_sequence_map(self):
     """`create_sequence_map` wraps an identity fn and a sequence argument."""
     x_ref = computation_building_blocks.Reference('x', tf.int32)
     identity_fn = computation_building_blocks.Lambda(
         x_ref.name, x_ref.type_signature, x_ref)
     sequence_arg = computation_building_blocks.Data(
         'y', computation_types.SequenceType(tf.int32))
     result = computation_constructing_utils.create_sequence_map(
         identity_fn, sequence_arg)
     expected_repr = 'sequence_map(<(x -> x),y>)'
     expected_type = 'int32*'
     self.assertEqual(result.tff_repr, expected_repr)
     self.assertEqual(str(result.type_signature), expected_type)
Beispiel #29
0
 def test_raises_type_error_with_none_value(self):
     """`create_federated_aggregate` must reject `None` as the value."""
     pair_type = computation_types.NamedTupleType((tf.int32, tf.int32))
     zero = computation_building_blocks.Data('z', tf.int32)
     accumulate = computation_building_blocks.Lambda(
         'x', pair_type, computation_building_blocks.Data('a', tf.int32))
     merge = computation_building_blocks.Lambda(
         'x',
         computation_types.NamedTupleType((tf.int32, tf.int32)),
         computation_building_blocks.Data('m', tf.int32))
     r_ref = computation_building_blocks.Reference('r', tf.int32)
     report = computation_building_blocks.Lambda(
         r_ref.name, r_ref.type_signature, r_ref)
     with self.assertRaises(TypeError):
         computation_constructing_utils.create_federated_aggregate(
             None, zero, accumulate, merge, report)
Beispiel #30
0
def zero_or_one_arg_func_to_building_block(func,
                                           parameter_name,
                                           parameter_type,
                                           context_stack,
                                           suggested_name=None):
    """Converts a zero- or one-argument `func` into a computation building block.

    Args:
      func: A function with 0 or 1 arguments that contains orchestration logic,
        i.e., that expects zero or one `values_base.Value` and returns a result
        convertible to the same.
      parameter_name: The name of the parameter, or `None` if there isn't any.
        It is ignored when `parameter_type` is `None`.
      parameter_type: The TFF type of the parameter, or `None` if there's none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts. If not `None`, it must be a string.

    Returns:
      An instance of `computation_building_blocks.ComputationBuildingBlock`
      that contains the logic from `func`. When `parameter_type` is `None`
      this is the body itself; otherwise it is a `Lambda` over the parameter.

    Raises:
      ValueError: if `func` is incompatible with `parameter_type`.
    """
    py_typecheck.check_callable(func)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, six.string_types)
    parameter_type = computation_types.to_type(parameter_type)
    # Nest under the current context only if it is itself a federated
    # computation context (e.g. when tracing a computation defined inside
    # another one).
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, six.string_types)
        # Prefix with the context name; presumably this keeps parameter names
        # distinct across nested contexts.
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        # Trace `func` inside the new context, passing a symbolic reference to
        # the parameter when there is one.
        if parameter_type is not None:
            result = func(
                value_impl.ValueImpl(
                    computation_building_blocks.Reference(
                        parameter_name, parameter_type), context_stack))
        else:
            result = func()
        result = value_impl.to_value(result, None, context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        if parameter_type is None:
            return result_comp
        else:
            return computation_building_blocks.Lambda(parameter_name,
                                                      parameter_type,
                                                      result_comp)