def _create_fn_to_append_chain_zipped_values(value):
    r"""Creates a function to append a chain of zipped values.

  Lambda(arg3)
            \
             append([Call,    Sel(1)])
                    /    \            \
        Lambda(arg2)      Sel(0)       Ref(arg3)
                  \             \
                   \             Ref(arg3)
                    \
                     append([Call,    Sel(1)])
                            /    \            \
                Lambda(arg1)      Sel(0)       Ref(arg2)
                            \           \
                             \           Ref(arg2)
                              \
                               Ref(arg1)

  Note that this function will not respect any names it is passed; names
  for tuples will be cached at a higher level than this function and added back
  in a single call to federated map or federated apply.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements. Each element type must expose a `.member` attribute
      (presumably federated types -- confirm against callers).

  Returns:
    A `computation_building_blocks.Lambda`. For two elements it is an identity
    lambda over `<member_0, member_1>`; each further element wraps the previous
    lambda in a new one taking `<previous_parameter, member_i>`, calling the
    previous lambda on `Sel(0)` and appending `Sel(1)` to that result via
    `create_computation_appending`.

  Raises:
    TypeError: If `value` is not a `ComputationBuildingBlock`, or if any of
      the types do not match.
    ValueError: If `value.type_signature` has fewer than two elements.
  """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length < 2:
        # Report the element count, not the (potentially long) element list.
        raise ValueError(
            'Expected a value with at least two elements, received {} elements.'
            .format(length))
    _, first_type_signature = named_type_signatures[0]
    _, second_type_signature = named_type_signatures[1]
    # Innermost lambda: identity over the first two zipped members.
    ref_type = computation_types.NamedTupleType((
        first_type_signature.member,
        second_type_signature.member,
    ))
    ref = computation_building_blocks.Reference('arg', ref_type)
    fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, ref)
    for _, type_signature in named_type_signatures[2:]:
        # Each new lambda takes <previous parameter, next member>, forwards the
        # first half to the previous lambda and appends the second half.
        ref_type = computation_types.NamedTupleType((
            fn.type_signature.parameter,
            type_signature.member,
        ))
        ref = computation_building_blocks.Reference('arg', ref_type)
        sel_0 = computation_building_blocks.Selection(ref, index=0)
        call = computation_building_blocks.Call(fn, sel_0)
        sel_1 = computation_building_blocks.Selection(ref, index=1)
        result = create_computation_appending(call, sel_1)
        fn = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                                result)
    return fn
示例#2
0
 def test_repr(self):
     # The repr lists every element, keeping the name of named elements.
     type_spec = computation_types.NamedTupleType([tf.int32, ('a', tf.bool)])
     expected = ("NamedTupleType([TensorType(tf.int32), "
                 "('a', TensorType(tf.bool))])")
     self.assertEqual(repr(type_spec), expected)
示例#3
0
    def test_returns_string_for_named_tuple_type_with_no_element(self):
        empty_tuple_type = computation_types.NamedTupleType([])
        # Both the compact and the formatted rendering of an empty tuple is '<>'.
        for render in (empty_tuple_type.compact_representation,
                       empty_tuple_type.formatted_representation):
            self.assertEqual(render(), '<>')
def deserialize_type(
        type_proto: Optional[pb.Type]) -> Optional[computation_types.Type]:
    """Deserializes 'type_proto' as a computation_types.Type.

  Note: Deserialization is currently implemented for tensor, sequence, named
  tuple, function, placement, and federated types. Federated types are only
  supported when their placement is given as a `value` (a placement URI).

  Args:
    type_proto: An instance of pb.Type or None.

  Returns:
    The corresponding instance of computation_types.Type, or None if the
    argument was None or its `type` oneof is not set.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which deserialization is not
      implemented.
  """
    # TODO(b/113112885): Implement deserialization of the remaining types.
    if type_proto is None:
        return None
    py_typecheck.check_type(type_proto, pb.Type)
    type_variant = type_proto.WhichOneof('type')
    if type_variant is None:
        # An empty proto (e.g. an unset function parameter) maps to None.
        return None
    elif type_variant == 'tensor':
        tensor_proto = type_proto.tensor
        return computation_types.TensorType(
            dtype=tf.dtypes.as_dtype(tensor_proto.dtype),
            shape=_to_tensor_shape(tensor_proto))
    elif type_variant == 'sequence':
        return computation_types.SequenceType(
            deserialize_type(type_proto.sequence.element))
    elif type_variant == 'tuple':
        # Named elements become (name, type) pairs; unnamed ones stay bare.
        elements = []
        for element in type_proto.tuple.element:
            element_type = deserialize_type(element.value)
            elements.append((element.name, element_type)
                            if element.name else element_type)
        return computation_types.NamedTupleType(elements)
    elif type_variant == 'function':
        return computation_types.FunctionType(
            parameter=deserialize_type(type_proto.function.parameter),
            result=deserialize_type(type_proto.function.result))
    elif type_variant == 'placement':
        return computation_types.PlacementType()
    elif type_variant == 'federated':
        placement_oneof = type_proto.federated.placement.WhichOneof(
            'placement')
        if placement_oneof == 'value':
            return computation_types.FederatedType(
                member=deserialize_type(type_proto.federated.member),
                placement=placement_literals.uri_to_placement_literal(
                    type_proto.federated.placement.value.uri),
                all_equal=type_proto.federated.all_equal)
        else:
            raise NotImplementedError(
                'Deserialization of federated types with placement spec as {} '
                'is not currently implemented yet.'.format(placement_oneof))
    else:
        raise NotImplementedError(
            'Unknown type variant {}.'.format(type_variant))
示例#5
0
 def _concretize_abstract_types(
         abstract_type_spec: computation_types.Type,
         concrete_type_spec: computation_types.Type
 ) -> computation_types.Type:
     """Recursive helper function to construct concrete type spec.

     Walks `abstract_type_spec` and `concrete_type_spec` in lockstep and
     returns a copy of the abstract type where every abstract type is replaced
     by the concrete type bound to it. The first occurrence of an abstract
     label binds it: the binding is recorded in `bound_abstract_types`, keyed
     by the label's string form. Later occurrences reuse that binding.

     NOTE(review): `bound_abstract_types` and `type_error_string` come from the
     enclosing scope, which is not visible here.

     Args:
       abstract_type_spec: The (possibly abstract) type to concretize.
       concrete_type_spec: The concrete type it is matched against.

     Returns:
       The concretized `computation_types.Type`.

     Raises:
       TypeError: If a structural mismatch is detected. This covers differing
         kinds for tuple/sequence/function/placement/federated types, tuple
         length or element-name mismatches, and one function having a
         parameter while the other does not. Also raised if
         `abstract_type_spec` is of an unexpected kind.
     """
     if abstract_type_spec.is_abstract():
         label = str(abstract_type_spec.label)
         bound_type = bound_abstract_types.get(label)
         # Compare against None explicitly: a bound type may itself be falsy
         # (e.g. an empty tuple type, if it defines `__len__`), and a
         # truthiness test would then silently rebind the label.
         if bound_type is not None:
             return bound_type
         bound_abstract_types[label] = concrete_type_spec
         return concrete_type_spec
     elif abstract_type_spec.is_tensor():
         return abstract_type_spec
     elif abstract_type_spec.is_tuple():
         if not concrete_type_spec.is_tuple():
             raise TypeError(type_error_string)
         abstract_elements = anonymous_tuple.to_elements(abstract_type_spec)
         concrete_elements = anonymous_tuple.to_elements(concrete_type_spec)
         if len(abstract_elements) != len(concrete_elements):
             raise TypeError(type_error_string)
         concretized_tuple_elements = []
         for abstract_element, concrete_element in zip(abstract_elements,
                                                       concrete_elements):
             name, abstract_member = abstract_element
             concrete_name, concrete_member = concrete_element
             if name != concrete_name:
                 raise TypeError(type_error_string)
             concretized_tuple_elements.append(
                 (name, _concretize_abstract_types(abstract_member,
                                                   concrete_member)))
         return computation_types.NamedTupleType(concretized_tuple_elements)
     elif abstract_type_spec.is_sequence():
         if not concrete_type_spec.is_sequence():
             raise TypeError(type_error_string)
         return computation_types.SequenceType(
             _concretize_abstract_types(abstract_type_spec.element,
                                        concrete_type_spec.element))
     elif abstract_type_spec.is_function():
         if not concrete_type_spec.is_function():
             raise TypeError(type_error_string)
         if abstract_type_spec.parameter is None:
             if concrete_type_spec.parameter is not None:
                 # Previously this *returned* the exception object, which
                 # silently produced a non-Type result instead of failing.
                 raise TypeError(type_error_string)
             concretized_param = None
         else:
             concretized_param = _concretize_abstract_types(
                 abstract_type_spec.parameter, concrete_type_spec.parameter)
         concretized_result = _concretize_abstract_types(
             abstract_type_spec.result, concrete_type_spec.result)
         return computation_types.FunctionType(concretized_param,
                                               concretized_result)
     elif abstract_type_spec.is_placement():
         if not concrete_type_spec.is_placement():
             raise TypeError(type_error_string)
         return abstract_type_spec
     elif abstract_type_spec.is_federated():
         if not concrete_type_spec.is_federated():
             raise TypeError(type_error_string)
         new_member = _concretize_abstract_types(abstract_type_spec.member,
                                                 concrete_type_spec.member)
         return computation_types.FederatedType(
             new_member, abstract_type_spec.placement,
             abstract_type_spec.all_equal)
     else:
         raise TypeError(
             'Unexpected abstract typespec {}.'.format(abstract_type_spec))
示例#6
0
 def test_construct_setattr_named_tuple_type_fails_on_none_value(self):
   # A `None` replacement value is rejected even though the name is valid.
   named_type = computation_types.NamedTupleType([('a', tf.int32)])
   construct = (
       computation_constructing_utils.construct_named_tuple_setattr_lambda)
   with self.assertRaises(TypeError):
     construct(named_type, 'a', None)
示例#7
0
 def test_list_structures_from_element_type_spec_with_anonymous_tuples(self):
   element_type = computation_types.NamedTupleType([('a', tf.int32)])
   elements = [anonymous_tuple.AnonymousTuple([('a', i)]) for i in (1, 2)]
   expected = "OrderedDict([('a', array([1,2],dtype=int32))])"
   self._test_list_structure(element_type, elements, expected)
 def test_with_single_abstract_type_and_tuple_type(self):
     # A one-element tuple type is expected to be a concrete instance of a
     # lone abstract type.
     abstract_type = computation_types.AbstractType('T1')
     tuple_type = computation_types.NamedTupleType([tf.int32])
     is_instance = type_analysis.is_concrete_instance_of(tuple_type,
                                                         abstract_type)
     self.assertTrue(is_instance)
 def test_raises_with_different_lengths(self):
     # Same abstract-type object used twice, matching `[...] * 2`.
     abstract = computation_types.AbstractType('T1')
     two_abstract_elements = computation_types.NamedTupleType(
         [abstract, abstract])
     one_int_element = computation_types.NamedTupleType([tf.int32])
     with self.assertRaises(TypeError):
         type_analysis.is_concrete_instance_of(one_int_element,
                                               two_abstract_elements)
 def test_returns_false_with_named_tuple_type_spec(self):
     # A plain NamedTupleType carries no Python container information.
     anon_value = anonymous_tuple.AnonymousTuple([('a', 0.0)])
     plain_type = computation_types.NamedTupleType([('a', tf.float32)])
     has_container = type_analysis.is_anon_tuple_with_py_container(
         anon_value, plain_type)
     self.assertFalse(has_container)
 def test_raises_different_structures(self):
     # A tensor cannot be matched against a tuple-shaped template.
     with self.assertRaises(TypeError):
         tensor_type = computation_types.TensorType(tf.int32)
         tuple_type = computation_types.NamedTupleType([tf.int32])
         type_analysis.is_concrete_instance_of(tensor_type, tuple_type)
class IsValidBitwidthTypeForValueType(parameterized.TestCase):
    """Tests for `type_analysis.is_valid_bitwidth_type_for_value_type`."""

    # pyformat: disable
    @parameterized.named_parameters(
        ('single int',
         computation_types.TensorType(tf.int32),
         computation_types.TensorType(tf.int32)),
        ('tuple',
         computation_types.NamedTupleType([tf.int32, tf.int32]),
         computation_types.NamedTupleType([tf.int32, tf.int32])),
        ('named tuple',
         computation_types.NamedTupleType([('x', tf.int32)]),
         computation_types.NamedTupleType([('x', tf.int32)])),
        ('single int for complex tensor',
         computation_types.TensorType(tf.int32),
         computation_types.TensorType(tf.int32, [5, 97, 204])),
        ('different kinds of ints',
         computation_types.TensorType(tf.int32),
         computation_types.TensorType(tf.int8)),
    )
    # pyformat: enable
    def test_returns_true(self, bitwidth_type, value_type):
        is_valid = type_analysis.is_valid_bitwidth_type_for_value_type(
            bitwidth_type, value_type)
        self.assertTrue(is_valid)

    # pyformat: disable
    @parameterized.named_parameters(
        ('single int_for_tuple',
         computation_types.TensorType(tf.int32),
         computation_types.NamedTupleType([tf.int32, tf.int32])),
        ('miscounted tuple',
         computation_types.NamedTupleType([tf.int32, tf.int32, tf.int32]),
         computation_types.NamedTupleType([tf.int32, tf.int32])),
        ('miscounted tuple 2',
         computation_types.NamedTupleType([tf.int32, tf.int32]),
         computation_types.NamedTupleType([tf.int32, tf.int32, tf.int32])),
        ('misnamed tuple',
         computation_types.NamedTupleType([('x', tf.int32)]),
         computation_types.NamedTupleType([('y', tf.int32)])),
    )
    # pyformat: enable
    def test_returns_false(self, bitwidth_type, value_type):
        is_valid = type_analysis.is_valid_bitwidth_type_for_value_type(
            bitwidth_type, value_type)
        self.assertFalse(is_valid)
class CheckAllAbstractTypesAreBoundTest(parameterized.TestCase):
    """Tests for `type_analysis.check_all_abstract_types_are_bound`."""

    # Types the check is expected to accept.
    # pyformat: disable
    @parameterized.named_parameters([
        ('tensor_type',
         computation_types.TensorType(tf.int32)),
        ('function_type_with_no_arg',
         computation_types.FunctionType(None, tf.int32)),
        ('function_type_with_int_arg',
         computation_types.FunctionType(tf.int32, tf.int32)),
        ('function_type_with_abstract_arg',
         computation_types.FunctionType(
             computation_types.AbstractType('T'),
             computation_types.AbstractType('T'))),
        ('tuple_tuple_function_type_with_abstract_arg',
         computation_types.NamedTupleType([
             computation_types.NamedTupleType([
                 computation_types.FunctionType(
                     computation_types.AbstractType('T'),
                     computation_types.AbstractType('T')),
             ]),
         ])),
        ('function_type_with_unbound_function_arg',
         computation_types.FunctionType(
             computation_types.FunctionType(
                 None,
                 computation_types.AbstractType('T')),
             computation_types.AbstractType('T'))),
        ('function_type_with_sequence_arg',
         computation_types.FunctionType(
             computation_types.SequenceType(
                 computation_types.AbstractType('T')),
             tf.int32)),
        ('function_type_with_two_abstract_args',
         computation_types.FunctionType(
             computation_types.NamedTupleType([
                 computation_types.AbstractType('T'),
                 computation_types.AbstractType('U'),
             ]),
             computation_types.NamedTupleType([
                 computation_types.AbstractType('T'),
                 computation_types.AbstractType('U'),
             ]))),
    ])
    # pyformat: enable
    def test_does_not_raise_type_error(self, type_spec):
        check = type_analysis.check_all_abstract_types_are_bound
        try:
            check(type_spec)
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')

    # Types the check is expected to reject.
    # pyformat: disable
    @parameterized.named_parameters([
        ('abstract_type',
         computation_types.AbstractType('T')),
        ('function_type_with_no_arg',
         computation_types.FunctionType(
             None,
             computation_types.AbstractType('T'))),
        ('function_type_with_int_arg',
         computation_types.FunctionType(
             tf.int32,
             computation_types.AbstractType('T'))),
        ('function_type_with_abstract_arg',
         computation_types.FunctionType(
             computation_types.AbstractType('T'),
             computation_types.AbstractType('U'))),
    ])
    # pyformat: enable
    def test_raises_type_error(self, type_spec):
        check = type_analysis.check_all_abstract_types_are_bound
        with self.assertRaises(TypeError):
            check(type_spec)
class CheckWellFormedTest(parameterized.TestCase):
    """Tests for `type_analysis.check_well_formed`."""

    # One well-formed example of each kind of type.
    # pyformat: disable
    @parameterized.named_parameters([
        ('abstract_type',
         computation_types.AbstractType('T')),
        ('federated_type',
         computation_types.FederatedType(
             tf.int32, placement_literals.CLIENTS)),
        ('function_type',
         computation_types.FunctionType(tf.int32, tf.int32)),
        ('named_tuple_type',
         computation_types.NamedTupleType([tf.int32] * 3)),
        ('placement_type',
         computation_types.PlacementType()),
        ('sequence_type',
         computation_types.SequenceType(tf.int32)),
        ('tensor_type',
         computation_types.TensorType(tf.int32)),
    ])
    # pyformat: enable
    def test_does_not_raise_type_error(self, type_signature):
        check = type_analysis.check_well_formed
        try:
            check(type_signature)
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')

    # Nestings the check is expected to reject.
    # pyformat: disable
    @parameterized.named_parameters([
        ('federated_function_type',
         computation_types.FederatedType(
             computation_types.FunctionType(tf.int32, tf.int32),
             placement_literals.CLIENTS)),
        ('federated_federated_type',
         computation_types.FederatedType(
             computation_types.FederatedType(
                 tf.int32, placement_literals.CLIENTS),
             placement_literals.CLIENTS)),
        ('sequence_sequence_type',
         computation_types.SequenceType(
             computation_types.SequenceType([tf.int32]))),
        ('sequence_federated_type',
         computation_types.SequenceType(
             computation_types.FederatedType(
                 tf.int32, placement_literals.CLIENTS))),
        ('tuple_federated_function_type',
         computation_types.NamedTupleType([
             computation_types.FederatedType(
                 computation_types.FunctionType(tf.int32, tf.int32),
                 placement_literals.CLIENTS)
         ])),
        ('tuple_federated_federated_type',
         computation_types.NamedTupleType([
             computation_types.FederatedType(
                 computation_types.FederatedType(
                     tf.int32, placement_literals.CLIENTS),
                 placement_literals.CLIENTS)
         ])),
        ('federated_tuple_function_type',
         computation_types.FederatedType(
             computation_types.NamedTupleType(
                 [computation_types.FunctionType(tf.int32, tf.int32)]),
             placement_literals.CLIENTS)),
    ])
    # pyformat: enable
    def test_raises_type_error(self, type_signature):
        check = type_analysis.check_well_formed
        with self.assertRaises(TypeError):
            check(type_signature)
示例#15
0
def _extract_update(after_aggregate):
    """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` only. As a result, this function
  does not assert that `after_aggregate` has the expected structure; the
  caller is expected to perform these checks before calling this function.

  The extraction proceeds in three steps:

  1. Select elements 0 and 1 of the result of `after_aggregate` and
     federated-zip them.
  2. Zip the parameter selections `[0, 0, 0]`, `[1, 0]` and `[1, 1]` into a
     single argument of the resulting lambda.
  3. Wrap that lambda so it accepts `<s1, <s3, s4>>` placed at `SERVER`,
     unpack it to `<s1, s3, s4>` with a federated map/apply, and call it.

  The result is then passed through
  `transformations.consolidate_and_extract_local_processing`.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
    # Step 1: keep result elements 0 and 1 and zip them into one federated
    # value.
    s7_elements_in_after_aggregate_result = [0, 1]
    s7_output_extracted = transformations.select_output_from_lambda(
        after_aggregate, s7_elements_in_after_aggregate_result)
    s7_output_zipped = building_blocks.Lambda(
        s7_output_extracted.parameter_name, s7_output_extracted.parameter_type,
        building_block_factory.create_federated_zip(
            s7_output_extracted.result))
    # Step 2: the parameter paths of `after_aggregate` that the result depends
    # on; they are zipped into a single argument of a lower-level lambda.
    s6_elements_in_after_aggregate_parameter = [[0, 0, 0], [1, 0], [1, 1]]
    s6_to_s7_computation = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            s7_output_zipped,
            s6_elements_in_after_aggregate_parameter).result.function)

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
    # from nested structures, therefore we need to pack the type signature
    # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
    # Names are drawn from a generator seeded with the computation so that the
    # new references do not clash with names already used inside it.
    name_generator = building_block_factory.unique_name_generator(
        s6_to_s7_computation)

    # Step 3a: local lambda `<s1, <s3, s4>> -> <s1, s3, s4>`.
    pack_ref_name = next(name_generator)
    pack_ref_type = computation_types.NamedTupleType([
        s6_to_s7_computation.parameter_type.member[0],
        computation_types.NamedTupleType([
            s6_to_s7_computation.parameter_type.member[1],
            s6_to_s7_computation.parameter_type.member[2],
        ]),
    ])
    pack_ref = building_blocks.Reference(pack_ref_name, pack_ref_type)
    sel_s1 = building_blocks.Selection(pack_ref, index=0)
    sel = building_blocks.Selection(pack_ref, index=1)
    sel_s3 = building_blocks.Selection(sel, index=0)
    sel_s4 = building_blocks.Selection(sel, index=1)
    result = building_blocks.Tuple([sel_s1, sel_s3, sel_s4])
    pack_fn = building_blocks.Lambda(pack_ref.name, pack_ref.type_signature,
                                     result)
    # Step 3b: outer lambda taking the packed value at SERVER, unpacking it
    # with a federated map/apply, and feeding it to `s6_to_s7_computation`.
    ref_name = next(name_generator)
    ref_type = computation_types.FederatedType(pack_ref_type,
                                               placements.SERVER)
    ref = building_blocks.Reference(ref_name, ref_type)
    unpacked_args = building_block_factory.create_federated_map_or_apply(
        pack_fn, ref)
    call = building_blocks.Call(s6_to_s7_computation, unpacked_args)
    fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
    return transformations.consolidate_and_extract_local_processing(fn)
 def test_passes_empty_tuples(self):
     # Two independent empty tuple types form a compatible operand pair.
     left = computation_types.NamedTupleType([])
     right = computation_types.NamedTupleType([])
     compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         left, right)
     self.assertTrue(compatible)
示例#17
0
 def test_construct_setattr_named_tuple_type_fails_on_none_name(self):
   # A `None` attribute name is rejected even with a well-typed value.
   named_type = computation_types.NamedTupleType([('a', tf.int32)])
   replacement = computation_building_blocks.Data('x', tf.int32)
   with self.assertRaises(TypeError):
     computation_constructing_utils.construct_named_tuple_setattr_lambda(
         named_type, None, replacement)
 def test_fails_named_tuple_type_and_non_scalar_tensor(self):
     # A named tuple of a 2x2 tensor is not compatible with a rank-1 tensor.
     named_matrix = computation_types.NamedTupleType(
         [('a', computation_types.TensorType(tf.int32, [2, 2]))])
     vector = computation_types.TensorType(tf.int32, [2])
     compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         named_matrix, vector)
     self.assertFalse(compatible)
示例#19
0
 def test_list_structures_from_element_type_spec_with_empty_dict_values(self):
   empty_type = computation_types.NamedTupleType([])
   three_empty_dicts = [{} for _ in range(3)]
   self._test_list_structure(empty_type, three_empty_dicts, 'OrderedDict()')
示例#20
0
def capture_result_from_graph(result, graph):
    """Captures a result stamped into a tf.Graph as a type signature and binding.

  Args:
    result: The result to capture, a Python object that is composed of tensors,
      possibly nested within Python structures such as dictionaries, lists,
      tuples, or named tuples.
    graph: The instance of tf.Graph to use.

  Returns:
    A tuple (type_spec, binding), where 'type_spec' is an instance of
    computation_types.Type that describes the type of the result, and 'binding'
    is an instance of TensorFlow.Binding that indicates how parts of the result
    type relate to the tensors and ops that appear in the result.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
  """
    # TODO(b/113112885): The emerging extensions for serializing SavedModels may
    # end up introducing similar concepts of bindings, etc., we should look here
    # into the possibility of reusing some of that code when it's available.
    # Values of the plain tensor-representation types are first turned into
    # constants in `graph`, so they fall through to the tensor branch below.
    if isinstance(result, dtype_utils.TENSOR_REPRESENTATION_TYPES):
        with graph.as_default():
            result = tf.constant(result)
    if tf.contrib.framework.is_tensor(result):
        if hasattr(result, 'read_value'):
            # We have a tf.Variable-like result, get a proper tensor to fetch.
            with graph.as_default():
                result = result.read_value()
        return (computation_types.TensorType(result.dtype.base_dtype,
                                             result.shape),
                pb.TensorFlow.Binding(tensor=pb.TensorFlow.TensorBinding(
                    tensor_name=result.name)))
    elif py_typecheck.is_named_tuple(result):
        # Special handling needed for collections.namedtuples since they do not have
        # anything in the way of a shared base class. Note we don't want to rely on
        # the fact that collections.namedtuples inherit from 'tuple' because we'd be
        # failing to retain the information about naming of tuple members.
        # pylint: disable=protected-access
        return capture_result_from_graph(result._asdict(), graph)
        # pylint: enable=protected-access
    elif isinstance(result, (dict, anonymous_tuple.AnonymousTuple)):
        # Element order: AnonymousTuple and OrderedDict keep their own order;
        # plain dicts are sorted by key so the result is deterministic.
        if isinstance(result, anonymous_tuple.AnonymousTuple):
            name_value_pairs = anonymous_tuple.to_elements(result)
        elif isinstance(result, collections.OrderedDict):
            name_value_pairs = six.iteritems(result)
        else:
            name_value_pairs = sorted(six.iteritems(result))
        # Each entry is (name, type_spec, binding).
        element_name_type_binding_triples = [
            ((k, ) + capture_result_from_graph(v, graph))
            for k, v in name_value_pairs
        ]
        # Elements with a falsy name (e.g. None or '') are left unnamed.
        return (computation_types.NamedTupleType([
            ((e[0], e[1]) if e[0] else e[1])
            for e in element_name_type_binding_triples
        ]),
                pb.TensorFlow.Binding(tuple=pb.TensorFlow.NamedTupleBinding(
                    element=[e[2]
                             for e in element_name_type_binding_triples])))
    elif isinstance(result, (list, tuple)):
        # Each entry is (type_spec, binding); list/tuple elements are unnamed.
        element_type_binding_pairs = [
            capture_result_from_graph(e, graph) for e in result
        ]
        return (computation_types.NamedTupleType(
            [e[0] for e in element_type_binding_pairs]),
                pb.TensorFlow.Binding(tuple=pb.TensorFlow.NamedTupleBinding(
                    element=[e[1] for e in element_type_binding_pairs])))
    elif isinstance(result, DATASET_REPRESENTATION_TYPES):
        # Datasets are bound by the name of a one-shot iterator's string handle
        # tensor created for them.
        element_type = type_utils.tf_dtypes_and_shapes_to_type(
            result.output_types, result.output_shapes)
        handle_name = result.make_one_shot_iterator().string_handle().name
        return (computation_types.SequenceType(element_type),
                pb.TensorFlow.Binding(sequence=pb.TensorFlow.SequenceBinding(
                    iterator_string_handle_name=handle_name)))
    else:
        raise TypeError(
            'Cannot capture a result of an unsupported type {}.'.format(
                py_typecheck.type_string(type(result))))
示例#21
0
 def test_list_structures_from_element_type_spec_with_empty_anon_tuples(self):
   empty_type = computation_types.NamedTupleType([])
   empty_values = [anonymous_tuple.AnonymousTuple([]) for _ in range(2)]
   self._test_list_structure(empty_type, empty_values, 'OrderedDict()')
示例#22
0
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
    """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed before their parent. A parent node is rebuilt only
  when at least one of its children reports a mutation. `transform_fn` is then
  applied to the (possibly rebuilt) parent.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Callable applied to each node in the type tree of
      `type_signature`. It takes a `computation_types.Type` and returns a
      2-tuple `(transformed_type, mutated)`, where `mutated` is a `bool`
      indicating whether the node was changed.

  Returns:
    A 2-tuple `(transformed_type, mutated)`. `transformed_type` is a possibly
    transformed version of `type_signature`, with each node in its tree the
    result of applying `transform_fn` to the corresponding node. `mutated` is
    True if any node in the tree was reported as mutated.

  Raises:
    TypeError: If the types don't match the specification above, or if
      `type_signature` is of a kind this function does not know how to walk.
  """
    py_typecheck.check_type(type_signature, computation_types.Type)
    py_typecheck.check_callable(transform_fn)
    if type_signature.is_federated():
        transformed_member, member_mutated = transform_type_postorder(
            type_signature.member, transform_fn)
        if member_mutated:
            type_signature = computation_types.FederatedType(
                transformed_member, type_signature.placement,
                type_signature.all_equal)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or member_mutated
    elif type_signature.is_sequence():
        transformed_element, element_mutated = transform_type_postorder(
            type_signature.element, transform_fn)
        if element_mutated:
            type_signature = computation_types.SequenceType(
                transformed_element)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or element_mutated
    elif type_signature.is_function():
        # A function without a parameter has nothing to transform there.
        if type_signature.parameter is not None:
            transformed_parameter, parameter_mutated = transform_type_postorder(
                type_signature.parameter, transform_fn)
        else:
            transformed_parameter, parameter_mutated = (None, False)
        transformed_result, result_mutated = transform_type_postorder(
            type_signature.result, transform_fn)
        if parameter_mutated or result_mutated:
            type_signature = computation_types.FunctionType(
                transformed_parameter, transformed_result)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, (type_signature_mutated or parameter_mutated
                                or result_mutated)
    elif type_signature.is_tuple():
        elements = []
        elements_mutated = False
        for element in anonymous_tuple.iter_elements(type_signature):
            transformed_element, element_mutated = transform_type_postorder(
                element[1], transform_fn)
            elements_mutated = elements_mutated or element_mutated
            elements.append((element[0], transformed_element))
        if elements_mutated:
            # Preserve the Python container, if any, when rebuilding.
            if type_signature.is_tuple_with_py_container():
                type_signature = (
                    computation_types.NamedTupleTypeWithPyContainerType(
                        elements, type_signature.python_container))
            else:
                type_signature = computation_types.NamedTupleType(elements)
        type_signature, type_signature_mutated = transform_fn(type_signature)
        return type_signature, type_signature_mutated or elements_mutated
    elif (type_signature.is_abstract() or type_signature.is_placement() or
          type_signature.is_tensor()):
        # Leaf kinds: only the node itself is transformed.
        return transform_fn(type_signature)
    else:
        # Previously this fell off the end and returned None, which made
        # callers fail later with a confusing unpacking error.
        raise TypeError(
            'Unsupported type signature {} of kind {}.'.format(
                type_signature, type(type_signature).__name__))
示例#23
0
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
    """Computes a federated weighted mean on the given `executor`.

  The result is `sum(arg[0] * arg[1]) / sum(arg[1])`, expressed as a sequence
  of intrinsic calls that are all dispatched to `executor`:

    1. `FEDERATED_ZIP_AT_CLIENTS` zips the two client-placed elements of `arg`.
    2. `FEDERATED_MAP` applies a multiply operator to each zipped pair.
    3. `FEDERATED_SUM` sums the products to the server, and separately sums
       `arg[1]` to the server.
    4. `FEDERATED_ZIP_AT_SERVER` zips the two server-placed sums.
    5. `FEDERATED_APPLY` applies a divide operator to the zipped sums.

  Args:
    executor: The executor to use.
    arg: The argument already embedded in `executor`. Its type is validated by
      `type_utils.check_valid_federated_weighted_mean_argument_tuple_type`.
      Element 0 holds the values and element 1 the weights (element 1 is the
      one that is summed and used as the divisor).

  Returns:
    The result embedded in `executor`, placed at the server.
  """
    # Validate the argument type before building any intrinsic calls.
    type_utils.check_valid_federated_weighted_mean_argument_tuple_type(
        arg.type_signature)
    # Step 1: zip <values, weights> at the clients into one federated tuple.
    zip1_type = computation_types.FunctionType(
        computation_types.NamedTupleType([
            type_factory.at_clients(arg.type_signature[0].member),
            type_factory.at_clients(arg.type_signature[1].member)
        ]),
        type_factory.at_clients(
            computation_types.NamedTupleType(
                [arg.type_signature[0].member, arg.type_signature[1].member])))
    zip1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
                                      zip1_type)
    zipped_arg = await executor.create_call(
        await executor.create_value(zip1_comp, zip1_type), arg)

    # Step 2: multiply each client's value by its weight.
    # TODO(b/134543154): Replace with something that produces a section of
    # plain TensorFlow code instead of constructing a lambda (so that this
    # can be executed directly on top of a plain TensorFlow-based executor).
    multiply_blk = building_block_factory.create_binary_operator_with_upcast(
        zipped_arg.type_signature.member, tf.multiply)

    map_type = computation_types.FunctionType(
        computation_types.NamedTupleType(
            [multiply_blk.type_signature, zipped_arg.type_signature]),
        type_factory.at_clients(multiply_blk.type_signature.result))
    map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP, map_type)
    # Arguments are awaited left to right: the map intrinsic, then the
    # multiply operator, then the <operator, zipped_arg> tuple.
    products = await executor.create_call(
        await executor.create_value(map_comp, map_type), await
        executor.create_tuple([
            await executor.create_value(multiply_blk.proto,
                                        multiply_blk.type_signature),
            zipped_arg
        ]))
    # Step 3a: sum the per-client products to the server.
    sum1_type = computation_types.FunctionType(
        type_factory.at_clients(products.type_signature.member),
        type_factory.at_server(products.type_signature.member))
    sum1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum1_type)
    sum_of_products = await executor.create_call(
        await executor.create_value(sum1_comp, sum1_type), products)
    # Step 3b: sum the weights (element 1 of `arg`) to the server. The two
    # independent executor calls are gathered so they may run concurrently.
    sum2_type = computation_types.FunctionType(
        type_factory.at_clients(arg.type_signature[1].member),
        type_factory.at_server(arg.type_signature[1].member))
    sum2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum2_type)
    total_weight = await executor.create_call(
        *(await asyncio.gather(executor.create_value(sum2_comp, sum2_type),
                               executor.create_selection(arg, index=1))))
    # Step 4: zip <sum_of_products, total_weight> at the server.
    zip2_type = computation_types.FunctionType(
        computation_types.NamedTupleType(
            [sum_of_products.type_signature, total_weight.type_signature]),
        type_factory.at_server(
            computation_types.NamedTupleType([
                sum_of_products.type_signature.member,
                total_weight.type_signature.member
            ])))
    zip2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
                                      zip2_type)
    divide_arg = await executor.create_call(*(await asyncio.gather(
        executor.create_value(zip2_comp, zip2_type),
        executor.create_tuple([sum_of_products, total_weight]))))
    # Step 5: divide the summed products by the total weight at the server.
    divide_blk = building_block_factory.create_binary_operator_with_upcast(
        divide_arg.type_signature.member, tf.divide)
    apply_type = computation_types.FunctionType(
        computation_types.NamedTupleType(
            [divide_blk.type_signature, divide_arg.type_signature]),
        type_factory.at_server(divide_blk.type_signature.result))
    apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                       apply_type)
    # The divide operator is embedded (awaited) while the gather arguments are
    # being built, i.e. before `asyncio.gather` itself is called.
    return await executor.create_call(*(await asyncio.gather(
        executor.create_value(apply_comp, apply_type),
        executor.create_tuple([
            await executor.create_value(divide_blk.proto,
                                        divide_blk.type_signature), divide_arg
        ]))))
示例#24
0
    def federated_aggregate(self, value, zero, accumulate, merge, report):
        """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

        Every argument is first converted to a TFF value. The three operators
        are converted with parameter type hints derived from `zero` and from
        the member type of `value`. They are then checked for type
        compatibility before the `federated_aggregate` building block is
        created and bound as a reference.

        Args:
          value: The value to aggregate. It is passed through
            `value_utils.ensure_federated_value` with `CLIENTS` placement.
          zero: The initial accumulator value.
          accumulate: Operator converted with a parameter hint of
            `<type of zero, member type of value>`.
          merge: Operator converted with a parameter hint of
            `<type of zero, type of zero>`.
          report: Operator converted with a parameter hint of the type of
            `zero`.

        Returns:
          A `value_impl.ValueImpl` wrapping the bound aggregation computation.

        Raises:
          TypeError: If `zero` is not assignable to the first parameter of
            `accumulate`, or if `accumulate`, `merge` or `report` does not
            have the expected function type.
        """
        stack = self._context_stack

        value = value_utils.ensure_federated_value(
            value_impl.to_value(value, None, stack),
            placement_literals.CLIENTS, 'value to be aggregated')

        zero = value_impl.to_value(zero, None, stack)
        py_typecheck.check_type(zero, value_base.Value)
        zero_type = zero.type_signature
        member_type = value.type_signature.member

        accumulate = value_impl.to_value(
            accumulate, None, stack,
            parameter_type_hint=computation_types.NamedTupleType(
                [zero_type, member_type]))
        merge = value_impl.to_value(
            merge, None, stack,
            parameter_type_hint=computation_types.NamedTupleType(
                [zero_type, zero_type]))
        report = value_impl.to_value(
            report, None, stack, parameter_type_hint=zero_type)

        operators = (
            ('accumulate', accumulate),
            ('merge', merge),
            ('report', report),
        )
        for _, operator in operators:
            py_typecheck.check_type(operator, value_base.Value)
            py_typecheck.check_type(operator.type_signature,
                                    computation_types.FunctionType)

        accumulator_param_type = accumulate.type_signature.parameter[0]
        if not accumulator_param_type.is_assignable_from(zero_type):
            raise TypeError('Expected `zero` to be assignable to type {}, '
                            'but was of incompatible type {}.'.format(
                                accumulator_param_type, zero_type))

        # Every operator's expected type is keyed off the accumulator type
        # produced by `accumulate`.
        accumulator_type = accumulate.type_signature.result
        expected_types = {
            'accumulate':
                type_factory.reduction_op(accumulator_type, member_type),
            'merge':
                type_factory.reduction_op(accumulator_type, accumulator_type),
            'report':
                computation_types.FunctionType(merge.type_signature.result,
                                               report.type_signature.result),
        }
        for op_name, operator in operators:
            type_expected = expected_types[op_name]
            if not type_expected.is_assignable_from(operator.type_signature):
                raise TypeError(
                    'Expected parameter `{}` to be of type {}, but received {} instead.'
                    .format(op_name, type_expected, operator.type_signature))

        comps = [
            value_impl.ValueImpl.get_comp(v)
            for v in (value, zero, accumulate, merge, report)
        ]
        aggregate = building_block_factory.create_federated_aggregate(*comps)
        return value_impl.ValueImpl(self._bind_comp_as_reference(aggregate),
                                    stack)
示例#25
0
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for testing executors."""

import collections

import numpy as np
import tensorflow as tf

from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import computations
from tensorflow_federated.python.core.impl import set_default_executor
from tensorflow_federated.python.core.utils import tf_computation_utils

# Type of the MNIST model parameters used in these tests: a float32
# `weights` tensor of shape (784, 10) and a float32 `bias` tensor of shape (10,).
_mnist_model_type = computation_types.NamedTupleType([
    ('weights', computation_types.TensorType(tf.float32, (784, 10))),
    ('bias', computation_types.TensorType(tf.float32, (10,))),
])

# Type of one MNIST batch with an unknown leading (batch) dimension:
# float32 `x` of shape (?, 784) and int32 `y` of shape (?,).
_mnist_batch_type = computation_types.NamedTupleType([
    ('x', computation_types.TensorType(tf.float32, (None, 784))),
    ('y', computation_types.TensorType(tf.int32, (None,))),
])

# A one-example batch compatible with `_mnist_batch_type`: an all-ones
# float32 `x` of shape (1, 784) and an all-ones int32 `y` of shape (1,).
_mnist_sample_batch = collections.OrderedDict(
    x=np.ones((1, 784), dtype=np.float32),
    y=np.ones((1,), dtype=np.int32),
)

_mnist_initial_model = collections.OrderedDict([
    ('weights', np.zeros([784, 10], dtype=np.float32)),
    ('bias', np.zeros([10], dtype=np.float32))
示例#26
0
def _create_before_and_after_broadcast_for_no_broadcast(tree):
    r"""Builds trivial before/after broadcast computations for `tree`.

  Intended for `get_canonical_form_for_iterative_process` when `tree` contains
  no `intrinsic_defs.FEDERATED_BROADCAST`. The caller must check that
  precondition and the overall structure of `tree`; neither is asserted here.

  The two ASTs returned are:

  Lambda(p)
  |
  federated_value(Tuple([]), SERVER)

       Lambda(x)
       |
       Call
      /    \
  Comp      Sel(0)
           /
     Ref(x)

  `before` accepts the parameter of `tree`, ignores it, and produces an empty
  tuple placed at `placements.SERVER` (built with
  `building_block_factory.create_federated_value`).

  `after` accepts `<parameter of tree, member of before's result @CLIENTS>`.
  `Comp` is `tree` itself, called only on element 0 of that parameter. The
  client-placed second element (`c2`, the broadcast result) is intentionally
  ignored.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair `(before, after)`, where each of `before` and `after` is a
    `tff_framework.ComputationBuildingBlock` that represents a part of the
    result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
    names = building_block_factory.unique_name_generator(tree)
    tree_parameter_type = tree.type_signature.parameter

    # Before broadcast: produce `<>` at the server, regardless of the input.
    before_param_name = next(names)
    server_empty_value = building_block_factory.create_federated_value(
        building_blocks.Tuple([]), placements.SERVER)
    before_broadcast = building_blocks.Lambda(before_param_name,
                                              tree_parameter_type,
                                              server_empty_value)

    # After broadcast: the parameter carries the original argument plus the
    # (unused) client-placed broadcast result.
    after_param_name = next(names)
    broadcast_result_type = computation_types.FederatedType(
        before_broadcast.type_signature.result.member, placements.CLIENTS)
    after_param = building_blocks.Reference(
        after_param_name,
        computation_types.NamedTupleType(
            [tree_parameter_type, broadcast_result_type]))
    call_tree = building_blocks.Call(
        tree, building_blocks.Selection(after_param, index=0))
    after_broadcast = building_blocks.Lambda(after_param.name,
                                             after_param.type_signature,
                                             call_tree)

    return before_broadcast, after_broadcast
示例#27
0
 def test_with_none_keys(self):
     # An element whose name is `None` is rendered without a name.
     unnamed_int = computation_types.NamedTupleType([(None, tf.int32)])
     self.assertEqual(str(unnamed_int), '<int32>')
示例#28
0
def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
    r"""Builds before/after aggregate computations for a `tree` without secure sum.

  Intended for `get_canonical_form_for_iterative_process` when `tree` contains
  no `intrinsic_defs.FEDERATED_SECURE_SUM`. The caller must check that
  precondition and the overall structure of `tree`; neither is asserted here.

  `tree` is first force-aligned and split by
  `intrinsic_defs.FEDERATED_AGGREGATE.uri` into `(before, after)`. Each half is
  then extended so that it has the shape required for before/after aggregate:

  Lambda
  |
  Tuple
  |
  [Comp, Tuple]
         |
         [federated_value(Tuple([]), CLIENTS), Tuple([])]

       Lambda(x)
       |
       Call
      /    \
  Comp      Tuple
            |
            [Sel(0),      Sel(0)]
            /            /
         Ref(x)    Sel(1)
                  /
            Ref(x)

  In the first AST, `Comp` is the result of the split `before` computation and
  the `Lambda` reuses its parameter. The appended tuple is an empty stand-in
  for the secure-sum argument: an empty value at `placements.CLIENTS`, plus an
  empty bitwidth.

  In the second AST, `Comp` is the split `after` computation. `x` has type
  `<after_param[0], <after_param[1], FederatedType([], SERVER)>>`, and `Comp`
  is called with `<x[0], x[1][0]>`. The server-placed secure-sum result (`s4`)
  is intentionally ignored.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair `(before, after)`, where each of `before` and `after` is a
    `tff_framework.ComputationBuildingBlock` that represents a part of the
    result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
    name_generator = building_block_factory.unique_name_generator(tree)

    split_before, split_after = (
        transformations.force_align_and_split_by_intrinsics(
            tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    # Before aggregate: append `<value, bitwidth>` for the absent secure sum.
    # The same empty tuple instance is used for both slots.
    empty = building_blocks.Tuple([])
    secure_sum_args = building_blocks.Tuple([
        building_block_factory.create_federated_value(empty,
                                                      placements.CLIENTS),
        empty,
    ])
    before_aggregate = building_blocks.Lambda(
        split_before.parameter_name, split_before.parameter_type,
        building_blocks.Tuple([split_before.result, secure_sum_args]))

    # After aggregate: accept the extra server-placed secure-sum result (s4)
    # and forward only `<original arg, s3>` to the split `after`.
    ref_name = next(name_generator)
    s4_type = computation_types.FederatedType([], placements.SERVER)
    aggregate_results_type = computation_types.NamedTupleType([
        split_after.parameter_type[1],
        s4_type,
    ])
    param = building_blocks.Reference(
        ref_name,
        computation_types.NamedTupleType(
            [split_after.parameter_type[0], aggregate_results_type]))
    original_arg = building_blocks.Selection(param, index=0)
    s3 = building_blocks.Selection(
        building_blocks.Selection(param, index=1), index=0)
    call_after = building_blocks.Call(
        split_after, building_blocks.Tuple([original_arg, s3]))
    after_aggregate = building_blocks.Lambda(param.name, param.type_signature,
                                             call_after)

    return before_aggregate, after_aggregate
示例#29
0
 def test_with_unnamed_element(self):
     # A one-element unnamed tuple type maps to a one-element Python tuple.
     single_int = computation_types.NamedTupleType([tf.int32])
     expected_specs = (tf.TensorSpec([], tf.int32),)
     test.assert_nested_struct_eq(
         type_conversions.type_to_tf_tensor_specs(single_int), expected_specs)
示例#30
0
def infer_type(arg):
  """Infers the TFF type of the argument (a `computation_types.Type` instance).

  WARNING: This function is only partially implemented.

  The kinds of arguments that are currently correctly recognized:
  - tensors, variables, and data sets,
  - things that are convertible to tensors (including numpy arrays, builtin
    types, as well as lists and tuples of any of the above, etc.),
  - nested lists, tuples, namedtuples, anonymous tuples, dict, and OrderedDicts.

  Args:
    arg: The argument, the TFF type of which to infer.

  Returns:
    Either an instance of `computation_types.Type`, or `None` if the argument is
    `None`.

  Raises:
    TypeError: If no TFF type can be inferred for `arg`.
  """
  # TODO(b/113112885): Implement the remaining cases here on the need basis.
  if arg is None:
    return None
  elif isinstance(arg, typed_object.TypedObject):
    return arg.type_signature
  elif tf.is_tensor(arg):
    return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
  elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
    return computation_types.SequenceType(
        computation_types.to_type(arg.element_spec))
  elif isinstance(arg, anonymous_tuple.AnonymousTuple):
    # Elements with a falsy name are recorded as unnamed.
    return computation_types.NamedTupleType([
        (k, infer_type(v)) if k else infer_type(v)
        for k, v in anonymous_tuple.iter_elements(arg)
    ])
  elif py_typecheck.is_attrs(arg):
    items = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False)
    return computation_types.NamedTupleTypeWithPyContainerType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif py_typecheck.is_named_tuple(arg):
    items = arg._asdict()
    return computation_types.NamedTupleTypeWithPyContainerType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif isinstance(arg, dict):
    # An OrderedDict's order is meaningful. A plain dict is sorted by key so
    # that the inferred element order does not depend on insertion order.
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      items = sorted(arg.items())
    return computation_types.NamedTupleTypeWithPyContainerType(
        [(k, infer_type(v)) for k, v in items], type(arg))
  elif isinstance(arg, (tuple, list)):
    elements = []
    all_elements_named = True
    for element in arg:
      all_elements_named &= py_typecheck.is_name_value_pair(element)
      elements.append(infer_type(element))
    # If this is a non-empty tuple of (name, value) pairs, the caller most
    # likely intended this to be a NamedTupleType, so we avoid storing the
    # Python container. An empty list/tuple is only vacuously "all named";
    # nothing signals that intent, so it keeps its container type like any
    # other list or tuple.
    if elements and all_elements_named:
      return computation_types.NamedTupleType(elements)
    else:
      return computation_types.NamedTupleTypeWithPyContainerType(
          elements, type(arg))
  elif isinstance(arg, str):
    return computation_types.TensorType(tf.string)
  elif isinstance(arg, (np.generic, np.ndarray)):
    return computation_types.TensorType(
        tf.dtypes.as_dtype(arg.dtype), arg.shape)
  else:
    # Exact-type lookup, so `bool` (a subclass of `int`) maps to `tf.bool`.
    dtype = {bool: tf.bool, int: tf.int32, float: tf.float32}.get(type(arg))
    if dtype:
      return computation_types.TensorType(dtype)
    else:
      # Now fall back onto the heavier-weight processing, as all else failed.
      # Use make_tensor_proto() to make sure to handle it consistently with
      # how TensorFlow is handling values (e.g., recognizing int as int32, as
      # opposed to int64 as in NumPy).
      try:
        # TODO(b/113112885): Find something more lightweight we could use here.
        tensor_proto = tf.make_tensor_proto(arg)
        return computation_types.TensorType(
            tf.dtypes.as_dtype(tensor_proto.dtype),
            tf.TensorShape(tensor_proto.tensor_shape))
      except TypeError as err:
        raise TypeError('Could not infer the TFF type of {}: {}'.format(
            py_typecheck.type_string(type(arg)), err)) from err