def test_propogates_dependence_up_through_call(self):
   """The Call node wrapping the matched intrinsic is reported as a consumer."""
   intrinsic_arg = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
   identity_fn = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Reference('x', tf.int32))
   call_node = building_blocks.Call(identity_fn, intrinsic_arg)
   consumers = tree_analysis.extract_nodes_consuming(
       call_node, dummy_intrinsic_predicate)
   self.assertIn(call_node, consumers)
 def test_returns_string_for_intrinsic(self):
     """Compact, formatted and structural renderings all equal 'intrinsic'."""
     comp = building_blocks.Intrinsic('intrinsic', tf.int32)
     renderers = (
         comp.compact_representation,
         comp.formatted_representation,
         comp.structural_representation,
     )
     # Checked in the same order as the three representations are listed.
     for render in renderers:
         self.assertEqual(render(), 'intrinsic')
 def test_passes_with_federated_map(self):
   """A federated_map intrinsic passes the reduction whitelist check."""
   clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
   clients_float = computation_types.FederatedType(tf.float32,
                                                   placements.CLIENTS)
   mapper_type = computation_types.FunctionType(tf.int32, tf.float32)
   map_type = computation_types.FunctionType([mapper_type, clients_int],
                                             clients_float)
   intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                         map_type)
   # Passing means this call does not raise.
   tree_analysis.check_intrinsics_whitelisted_for_reduction(intrinsic)
 def test_propogates_dependence_up_through_lambda(self):
     """A Lambda whose body is the matched intrinsic is reported as a consumer."""
     whimsy_intrinsic = building_blocks.Intrinsic(
         'whimsy_intrinsic', computation_types.TensorType(tf.int32))
     lambda_node = building_blocks.Lambda('x', tf.int32, whimsy_intrinsic)
     self.assertIn(
         lambda_node,
         tree_analysis.extract_nodes_consuming(lambda_node,
                                               whimsy_intrinsic_predicate))
 def test_propogates_dependence_up_through_selection(self):
     """A Selection out of the matched intrinsic is reported as a consumer."""
     struct_type = computation_types.StructType([tf.int32])
     source = building_blocks.Intrinsic('whimsy_intrinsic', struct_type)
     selected = building_blocks.Selection(source, index=0)
     found = tree_analysis.extract_nodes_consuming(selected,
                                                   whimsy_intrinsic_predicate)
     self.assertIn(selected, found)
# Example #6
# 0
 def test_propogates_dependence_up_through_tuple(self):
     """A Tuple holding the matched intrinsic is reported as a consumer."""
     intrinsic = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
     int_ref = building_blocks.Reference('int', tf.int32)
     tuple_node = building_blocks.Tuple([int_ref, intrinsic])
     self.assertIn(
         tuple_node,
         tree_analysis.extract_nodes_consuming(tuple_node,
                                               dummy_intrinsic_predicate))
# Example #7
# 0
  def sequence_reduce(self, value, zero, op):
    """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

    `value` may be either a local sequence or a federated value whose member
    is a sequence. In the federated case, the reduction is wrapped in a lambda
    over the member sequence type. That lambda is applied with
    `federated_apply` for SERVER placement or `federated_map` for CLIENTS
    placement.

    Args:
      value: As in `api/intrinsics.py`.
      zero: As in `api/intrinsics.py`.
      op: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`. Also raised here when `op` is not
        assignable to the expected reduction operator type, or when a
        federated `value` has a placement other than SERVER or CLIENTS.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    zero = value_impl.to_value(zero, None, self._context_stack)
    op = value_impl.to_value(op, None, self._context_stack)
    # Find the sequence element type. Both a bare sequence and a federated
    # sequence are accepted; anything else fails py_typecheck.check_type.
    if isinstance(value.type_signature, computation_types.SequenceType):
      element_type = value.type_signature.element
    else:
      py_typecheck.check_type(value.type_signature,
                              computation_types.FederatedType)
      py_typecheck.check_type(value.type_signature.member,
                              computation_types.SequenceType)
      element_type = value.type_signature.member.element
    # Build the operator type expected for reducing `element_type` into
    # `zero`'s type (see type_constructors.reduction_op), then check `op`
    # against it.
    op_type_expected = type_constructors.reduction_op(zero.type_signature,
                                                      element_type)
    if not type_utils.is_assignable_from(op_type_expected, op.type_signature):
      raise TypeError('Expected an operator of type {}, got {}.'.format(
          op_type_expected, op.type_signature))

    # From here on, work with the underlying building blocks.
    value = value_impl.ValueImpl.get_comp(value)
    zero = value_impl.ValueImpl.get_comp(zero)
    op = value_impl.ValueImpl.get_comp(op)
    if isinstance(value.type_signature, computation_types.SequenceType):
      return computation_constructing_utils.create_sequence_reduce(
          value, zero, op)
    else:
      # Federated case: build `(arg -> SEQUENCE_REDUCE(<arg, zero, op>))` over
      # the member sequence type, then apply it at the value's placement.
      value_type = computation_types.SequenceType(element_type)
      intrinsic_type = computation_types.FunctionType((
          value_type,
          zero.type_signature,
          op.type_signature,
      ), op.type_signature.result)
      intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_REDUCE.uri,
                                            intrinsic_type)
      ref = building_blocks.Reference('arg', value_type)
      tup = building_blocks.Tuple((ref, zero, op))
      call = building_blocks.Call(intrinsic, tup)
      fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
      fn_impl = value_impl.ValueImpl(fn, self._context_stack)
      if value.type_signature.placement is placements.SERVER:
        return self.federated_apply(fn_impl, value)
      elif value.type_signature.placement is placements.CLIENTS:
        return self.federated_map(fn_impl, value)
      else:
        raise TypeError('Unsupported placement {}.'.format(
            value.type_signature.placement))
 def test_propogates_dependence_up_through_block_locals(self):
   """A Block that binds the matched intrinsic as a local is a consumer."""
   bound_intrinsic = building_blocks.Intrinsic(
       'dummy_intrinsic', computation_types.TensorType(tf.int32))
   block_result = building_blocks.Reference('int', tf.int32)
   block = building_blocks.Block([('x', bound_intrinsic)], block_result)
   self.assertIn(
       block,
       tree_analysis.extract_nodes_consuming(block, dummy_intrinsic_predicate))
 def test_propogates_dependence_up_through_block_result(self):
     """A Block whose result is the matched intrinsic is a consumer."""
     result_intrinsic = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
     local_ref = building_blocks.Reference('int', tf.int32)
     block = building_blocks.Block([('x', local_ref)], result_intrinsic)
     consumers = tree_analysis.extract_nodes_consuming(
         block, dummy_intrinsic_predicate)
     self.assertIn(block, consumers)
# Example #10
# 0
 def test_propogates_dependence_up_through_tuple(self):
     """A Struct holding the matched intrinsic is reported as a consumer."""
     intrinsic = building_blocks.Intrinsic(
         'whimsy_intrinsic', computation_types.TensorType(tf.int32))
     struct_node = building_blocks.Struct(
         [building_blocks.Reference('int', tf.int32), intrinsic])
     consumers = tree_analysis.extract_nodes_consuming(
         struct_node, whimsy_intrinsic_predicate)
     self.assertIn(struct_node, consumers)
  def test_raises_with_federated_mean(self):
    """federated_mean is rejected, and the error mentions its compact form."""
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
    mean = building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_MEAN.uri,
        computation_types.FunctionType(clients_int, server_int))
    expected_message = mean.compact_representation()

    with self.assertRaisesRegex(ValueError, expected_message):
      tree_analysis.check_contains_only_reducible_intrinsics(mean)
# Example #12
# 0
 def test_propogates_dependence_up_through_call(self):
     """A Call whose argument is the matched intrinsic is a consumer."""
     argument = building_blocks.Intrinsic(
         'whimsy_intrinsic', computation_types.TensorType(tf.int32))
     identity = building_blocks.Lambda('x', tf.int32,
                                       building_blocks.Reference('x', tf.int32))
     invocation = building_blocks.Call(identity, argument)
     self.assertIn(
         invocation,
         tree_analysis.extract_nodes_consuming(invocation,
                                               whimsy_intrinsic_predicate))
# Example #13
# 0
 def test_passes_noarg_lambda(self):
     """strip_placement accepts federated_eval_at_server of a no-arg lambda."""
     server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
     thunk = building_blocks.Lambda(None, None,
                                    building_blocks.Data('a', tf.int32))
     eval_type = computation_types.FunctionType(thunk.type_signature, server_int)
     eval_at_server = building_blocks.Intrinsic(
         intrinsic_defs.FEDERATED_EVAL_AT_SERVER.uri, eval_type)
     # Passing means this call does not raise.
     tree_transformations.strip_placement(
         building_blocks.Call(eval_at_server, thunk))
# Example #14
# 0
        def computation_returning_sum(x):
            """Returns a `federated_sum` Intrinsic taken out of a Python list.

            NOTE(review): `client_val_type` is a free variable from an
            enclosing scope that is not visible here. Its use of `.member`
            suggests a `FederatedType` (presumably CLIENTS-placed); confirm.
            """
            # The intrinsic maps `client_val_type` to a SERVER-placed federated
            # value with the same member type. It is put in a list next to `x`
            # and then pulled back out; `x` is otherwise unused. This
            # presumably exercises returning an element of a Python container;
            # confirm against the enclosing test.
            tuple_containing_intrinsic = [
                building_blocks.Intrinsic(
                    'federated_sum',
                    computation_types.FunctionType(
                        client_val_type,
                        computation_types.FederatedType(
                            client_val_type.member, placements.SERVER))), x
            ]

            return tuple_containing_intrinsic[0]
 def test_passes_with_federated_map(self):
     """check_contains_only_reducible_intrinsics accepts federated_map."""
     mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
     clients_arg_type = computation_types.FederatedType(
         tf.int32, placement_literals.CLIENTS)
     clients_result_type = computation_types.FederatedType(
         tf.float32, placement_literals.CLIENTS)
     intrinsic = building_blocks.Intrinsic(
         intrinsic_defs.FEDERATED_MAP.uri,
         computation_types.FunctionType([mapped_fn_type, clients_arg_type],
                                        clients_result_type))
     # Passing means this call does not raise.
     tree_analysis.check_contains_only_reducible_intrinsics(intrinsic)
 def test_basic_intrinsic_functionality_plus_canonical_typecheck(self):
   """Checks a binary intrinsic's type, URI, rendering and proto round trip."""
   plus_type = computation_types.FunctionType([tf.int32, tf.int32], tf.int32)
   plus = building_blocks.Intrinsic('generic_plus', plus_type)
   self.assertEqual(str(plus.type_signature), '(<int32,int32> -> int32)')
   self.assertEqual(plus.uri, 'generic_plus')
   self.assertEqual(plus.compact_representation(), 'generic_plus')
   proto = plus.proto
   # The serialized type must be assignable back to the original signature.
   plus.type_signature.check_assignable_from(
       type_serialization.deserialize_type(proto.type))
   self.assertEqual(proto.WhichOneof('computation'), 'intrinsic')
   self.assertEqual(proto.intrinsic.uri, plus.uri)
   self._serialize_deserialize_roundtrip_test(plus)
  def test_propogates_dependence_into_binding_to_reference(self):
    """A Reference to a local bound to the matched intrinsic is a consumer."""
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    ref_to_x = building_blocks.Reference('x', clients_int)
    zero_intrinsic = building_blocks.Intrinsic(intrinsic_defs.GENERIC_ZERO.uri,
                                               clients_int)

    def is_generic_zero(comp):
      if not comp.is_intrinsic():
        return False
      return comp.uri == intrinsic_defs.GENERIC_ZERO.uri

    block = building_blocks.Block([('x', zero_intrinsic)], ref_to_x)
    consumers = tree_analysis.extract_nodes_consuming(block, is_generic_zero)
    self.assertIn(ref_to_x, consumers)
# Example #18
# 0
 def test_raises_disallowed_intrinsic(self):
     """strip_placement raises ValueError on a federated_broadcast call."""
     server_ref = building_blocks.Reference(
         'x', computation_types.FederatedType(tf.int32, placements.SERVER))
     clients_type = computation_types.FederatedType(
         server_ref.type_signature.member, placements.CLIENTS, all_equal=True)
     broadcast = building_blocks.Intrinsic(
         intrinsic_defs.FEDERATED_BROADCAST.uri,
         computation_types.FunctionType(server_ref.type_signature,
                                        clients_type))
     broadcast_call = building_blocks.Call(broadcast, server_ref)
     with self.assertRaises(ValueError):
         tree_transformations.strip_placement(broadcast_call)
# Example #19
# 0
 def __add__(self, other):
     """Returns a `Value` for `self + other` built on a GENERIC_PLUS call.

     `other` is first converted with `to_value`. The resulting call is bound
     to a reference via `_bind_computation_to_reference`.

     Raises:
       TypeError: if the converted `other` does not have a type equivalent
         to `self.type_signature`.
     """
     other = to_value(other, None)
     lhs_type = self.type_signature
     if not lhs_type.is_equivalent_to(other.type_signature):
         raise TypeError('Cannot add {} and {}.'.format(
             lhs_type, other.type_signature))
     plus_type = computation_types.FunctionType([lhs_type, lhs_type], lhs_type)
     plus = building_blocks.Intrinsic(intrinsic_defs.GENERIC_PLUS.uri,
                                      plus_type)
     operands = to_value([self, other], None).comp
     call = building_blocks.Call(plus, operands)
     bound = _bind_computation_to_reference(call, 'adding a tff.Value')
     return Value(bound)
# Example #20
# 0
 def __add__(self, other):
   """Returns a `ValueImpl` for `self + other` built on a GENERIC_PLUS call.

   The call is bound to a reference in the current context of
   `self._context_stack`.

   Raises:
     TypeError: if `other`, after `to_value`, does not have a type
       equivalent to `self.type_signature`.
   """
   stack = self._context_stack
   other = to_value(other, None, stack)
   own_type = self.type_signature
   if not own_type.is_equivalent_to(other.type_signature):
     raise TypeError('Cannot add {} and {}.'.format(own_type,
                                                    other.type_signature))
   plus = building_blocks.Intrinsic(
       intrinsic_defs.GENERIC_PLUS.uri,
       computation_types.FunctionType([own_type, own_type], own_type))
   operands = ValueImpl.get_comp(to_value([self, other], None, stack))
   call = building_blocks.Call(plus, operands)
   ref = stack.current.bind_computation_to_reference(call)
   return ValueImpl(ref, stack)
 def test_basic_functionality_of_intrinsic_class(self):
   """Checks a unary intrinsic's type, URI, repr, rendering and proto."""
   add_one = building_blocks.Intrinsic(
       'add_one', computation_types.FunctionType(tf.int32, tf.int32))
   self.assertEqual(str(add_one.type_signature), '(int32 -> int32)')
   self.assertEqual(add_one.uri, 'add_one')
   expected_repr = ('Intrinsic(\'add_one\', '
                    'FunctionType(TensorType(tf.int32), TensorType(tf.int32)))')
   self.assertEqual(repr(add_one), expected_repr)
   self.assertEqual(add_one.compact_representation(), 'add_one')
   proto = add_one.proto
   self.assertEqual(
       type_serialization.deserialize_type(proto.type),
       add_one.type_signature)
   self.assertEqual(proto.WhichOneof('computation'), 'intrinsic')
   self.assertEqual(proto.intrinsic.uri, add_one.uri)
   self._serialize_deserialize_roundtrip_test(add_one)
    def sequence_sum(self, value):
        """Implements `sequence_sum` as defined in `api/intrinsics.py`.

        `value` may be a local sequence or a federated value whose member is
        a sequence. In the federated case, a SEQUENCE_SUM intrinsic over the
        member type is applied with `federated_apply` for SERVER placement or
        `federated_map` for CLIENTS placement.

        Args:
          value: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`. Also raised here when the
            element type is not sum-compatible, or when a federated `value`
            has a placement other than SERVER or CLIENTS.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        if isinstance(value.type_signature, computation_types.SequenceType):
            element_type = value.type_signature.element
        else:
            py_typecheck.check_type(value.type_signature,
                                    computation_types.FederatedType)
            py_typecheck.check_type(value.type_signature.member,
                                    computation_types.SequenceType)
            element_type = value.type_signature.member.element
        if not type_utils.is_sum_compatible(element_type):
            # Report the full type. A bare SequenceType has no `.member`, so
            # formatting `value.type_signature.member` would raise an
            # AttributeError instead of this TypeError.
            raise TypeError(
                'The value type {} is not compatible with the sum operator.'.
                format(value.type_signature))

        if isinstance(value.type_signature, computation_types.SequenceType):
            value = value_impl.ValueImpl.get_comp(value)
            return building_block_factory.create_sequence_sum(value)
        elif isinstance(value.type_signature, computation_types.FederatedType):
            # Sum intrinsic typed (member sequence -> element), applied
            # pointwise at the value's placement.
            intrinsic_type = computation_types.FunctionType(
                value.type_signature.member,
                value.type_signature.member.element)
            intrinsic = building_blocks.Intrinsic(
                intrinsic_defs.SEQUENCE_SUM.uri, intrinsic_type)
            intrinsic_impl = value_impl.ValueImpl(intrinsic,
                                                  self._context_stack)
            if value.type_signature.placement is placements.SERVER:
                return self.federated_apply(intrinsic_impl, value)
            elif value.type_signature.placement is placements.CLIENTS:
                return self.federated_map(intrinsic_impl, value)
            else:
                raise TypeError('Unsupported placement {}.'.format(
                    value.type_signature.placement))
        else:
            # Defensive: non-sequence values were already required to be
            # FederatedType above, so this branch appears unreachable.
            raise TypeError(
                'Cannot apply `tff.sequence_sum()` to a value of type {}.'.
                format(value.type_signature))
# Example #23
# 0
 def _normalize_intrinsic_bit(comp):
     """Rewrites a federated_map_all_equal intrinsic as a federated_map.

     Args:
       comp: The node to inspect; presumably an `Intrinsic` building block,
         since its `uri` and `type_signature` are read.

     Returns:
       A `(node, modified)` pair. If `comp` is not FEDERATED_MAP_ALL_EQUAL,
       it is returned unchanged with `False`. Otherwise, a FEDERATED_MAP
       intrinsic is returned with `True`. It keeps the same mapped function
       type, and its federated argument and result are placed at CLIENTS.
     """
     all_equal_uri = intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri
     if comp.uri == all_equal_uri:
         signature = comp.type_signature
         mapped_fn_type = signature.parameter[0]
         clients_arg = computation_types.FederatedType(
             signature.parameter[1].member, placements.CLIENTS)
         clients_result = computation_types.FederatedType(
             signature.result.member, placements.CLIENTS)
         rewritten = building_blocks.Intrinsic(
             intrinsic_defs.FEDERATED_MAP.uri,
             computation_types.FunctionType([mapped_fn_type, clients_arg],
                                            clients_result))
         return rewritten, True
     return comp, False
# Example #24
# 0
  def test_generic_plus_reduces(self):
    """GENERIC_PLUS is fully replaced by its body; the type is unchanged."""
    plus_uri = intrinsic_defs.GENERIC_PLUS.uri
    plus_type = computation_types.FunctionType([tf.float32, tf.float32],
                                               tf.float32)
    comp = building_blocks.Intrinsic(plus_uri, plus_type)

    before = _count_intrinsics(comp, plus_uri)
    reduced, modified = (
        intrinsic_reductions.replace_intrinsics_with_bodies(comp))
    after = _count_intrinsics(reduced, plus_uri)

    self.assertTrue(modified)
    self.assert_types_identical(comp.type_signature, reduced.type_signature)
    self.assertGreater(before, 0)
    self.assertEqual(after, 0)
    tree_analysis.check_contains_only_reducible_intrinsics(reduced)
# Example #25
# 0
  def test_generic_divide_reduces(self):
    """GENERIC_DIVIDE is fully replaced by its body."""
    divide_uri = intrinsic_defs.GENERIC_DIVIDE.uri
    stack = context_stack_impl.context_stack
    comp = building_blocks.Intrinsic(
        divide_uri,
        computation_types.FunctionType([tf.float32, tf.float32], tf.float32))

    before = _count_intrinsics(comp, divide_uri)
    reduced, modified = (
        value_transformations.replace_all_intrinsics_with_bodies(comp, stack))
    after = _count_intrinsics(reduced, divide_uri)

    self.assertGreater(before, 0)
    self.assertEqual(after, 0)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(reduced)
    self.assertTrue(modified)
# Example #26
# 0
 def __add__(self, other):
     """Returns a `ValueImpl` wrapping a GENERIC_PLUS call on `<self, other>`.

     Raises:
       TypeError: if `other`, after `to_value`, does not have a type
         equivalent to `self.type_signature`.
     """
     stack = self._context_stack
     other = to_value(other, None, stack)
     if not type_utils.are_equivalent_types(self.type_signature,
                                            other.type_signature):
         raise TypeError('Cannot add {} and {}.'.format(
             self.type_signature, other.type_signature))
     operand_type = self.type_signature
     plus = building_blocks.Intrinsic(
         intrinsic_defs.GENERIC_PLUS.uri,
         computation_types.FunctionType([operand_type, operand_type],
                                        operand_type))
     operands = ValueImpl.get_comp(to_value([self, other], None, stack))
     return ValueImpl(building_blocks.Call(plus, operands), stack)
# Example #27
# 0
def create_whimsy_called_intrinsic(parameter_name, parameter_type=tf.int32):
    r"""Builds a call of a `(T -> T)` intrinsic on a reference.

               Call
              /    \
     intrinsic      Ref(parameter_name)

    Args:
      parameter_name: Name of the `Reference` passed as the call argument.
      parameter_type: Type `T` of that reference. The intrinsic is typed
        `(T -> T)`.

    Returns:
      A `building_blocks.Call` that applies an `Intrinsic` with URI
      'intrinsic' to `Reference(parameter_name, parameter_type)`.
    """
    argument = building_blocks.Reference(parameter_name, parameter_type)
    fn_type = computation_types.FunctionType(parameter_type, parameter_type)
    return building_blocks.Call(
        building_blocks.Intrinsic('intrinsic', fn_type), argument)
# Example #28
# 0
 def __add__(self, other):
     """Returns a `ValueImpl` for the GENERIC_PLUS call on `<self, other>`.

     Raises:
       TypeError: if `other`, after `to_value`, does not have a type
         equivalent to `self.type_signature`.
     """
     other = to_value(other, None, self._context_stack)
     self_type = self.type_signature
     if not self_type.is_equivalent_to(other.type_signature):
         raise TypeError('Cannot add {} and {}.'.format(
             self_type, other.type_signature))
     generic_plus = building_blocks.Intrinsic(
         intrinsic_defs.GENERIC_PLUS.uri,
         computation_types.FunctionType([self_type, self_type], self_type))
     pair = ValueImpl.get_comp(
         to_value([self, other], None, self._context_stack))
     # TODO(b/159281959): Follow up and bind a reference here.
     return ValueImpl(building_blocks.Call(generic_plus, pair),
                      self._context_stack)
# Example #29
# 0
    def sequence_map(self, fn, arg):
        """Implements `sequence_map` as defined in `api/intrinsics.py`.

        A local sequence `arg` is mapped directly. For a federated `arg`, a
        SEQUENCE_MAP intrinsic is curried with `fn`. The result is applied
        with `federated_map` at SERVER or CLIENTS placement.

        Args:
          fn: As in `api/intrinsics.py`.
          arg: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`.
        """
        stack = self._context_stack
        fn = value_impl.to_value(fn, None, stack)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        arg = value_impl.to_value(arg, None, stack)
        arg_type = arg.type_signature

        if isinstance(arg_type, computation_types.SequenceType):
            return building_block_factory.create_sequence_map(
                value_impl.ValueImpl.get_comp(fn),
                value_impl.ValueImpl.get_comp(arg))

        if not isinstance(arg_type, computation_types.FederatedType):
            raise TypeError(
                'Cannot apply `tff.sequence_map()` to a value of type {}.'.
                format(arg.type_signature))

        # Federated case: the intrinsic has type
        # (<fn, seq<param>> -> seq<result>). Currying it with `fn` leaves a
        # function over the member sequence.
        fn_type = fn.type_signature
        intrinsic = building_blocks.Intrinsic(
            intrinsic_defs.SEQUENCE_MAP.uri,
            computation_types.FunctionType(
                (fn_type, computation_types.SequenceType(fn_type.parameter)),
                computation_types.SequenceType(fn_type.result)))
        intrinsic_impl = value_impl.ValueImpl(intrinsic, stack)
        local_fn = value_utils.get_curried(intrinsic_impl)(fn)

        if arg_type.placement not in (placements.SERVER, placements.CLIENTS):
            raise TypeError('Unsupported placement {}.'.format(
                arg.type_signature.placement))
        return self.federated_map(local_fn, arg)
# Example #30
# 0
def sequence_sum(value):
    """Computes a sum of elements in a sequence.

    Args:
      value: A value of a TFF type that is either a sequence or a federated
        sequence.

    Returns:
      The sum of elements in the sequence. If `value` has a federated type,
      the result also has a federated type, with the sum computed locally
      and independently at each location (see also the discussion of
      `sequence_map` and `sequence_reduce`).

    Raises:
      TypeError: If the arguments are of wrong or unsupported types.
    """
    value = value_impl.to_value(value, None)
    value_type = value.type_signature
    is_local_sequence = value_type.is_sequence()
    if is_local_sequence:
        element_type = value_type.element
    else:
        py_typecheck.check_type(value_type, computation_types.FederatedType)
        py_typecheck.check_type(value_type.member,
                                computation_types.SequenceType)
        element_type = value_type.member.element
    type_analysis.check_is_sum_compatible(element_type)

    if is_local_sequence:
        summed = building_block_factory.create_sequence_sum(value.comp)
        return value_impl.Value(_bind_comp_as_reference(summed))
    if value_type.is_federated():
        # Sum intrinsic typed (member sequence -> element), mapped pointwise.
        sum_type = computation_types.FunctionType(value_type.member,
                                                  value_type.member.element)
        sum_intrinsic = building_blocks.Intrinsic(
            intrinsic_defs.SEQUENCE_SUM.uri, sum_type)
        return federated_map(value_impl.Value(sum_intrinsic), value)
    raise TypeError(
        'Cannot apply `tff.sequence_sum()` to a value of type {}.'.format(
            value.type_signature))