Exemplo n.º 1
0
    def test_federated_evaluation(self):
        """Checks the federated evaluation signature and its aggregated result."""
        evaluate = federated_evaluation.build_federated_evaluation(TestModel)
        weights_type = model_utils.weights_type_from_model(TestModel)
        client_batch_type = StructType(
            [('temp', TensorType(dtype=tf.float32, shape=[None]))])
        expected_signature = FunctionType(
            parameter=StructType([
                ('server_model_weights',
                 computation_types.at_server(weights_type)),
                ('federated_dataset',
                 computation_types.at_clients(
                     SequenceType(client_batch_type))),
            ]),
            result=computation_types.at_server(
                collections.OrderedDict(
                    eval=collections.OrderedDict(num_over=tf.float32))))
        type_test_utils.assert_types_equivalent(evaluate.type_signature,
                                                expected_signature)

        def _temp_dict(temps):
            return {'temp': np.array(temps, dtype=np.float32)}

        server_weights = collections.OrderedDict(trainable=[5.0],
                                                 non_trainable=[])
        client_datasets = [
            [_temp_dict([1.0, 10.0, 2.0, 7.0]), _temp_dict([6.0, 11.0])],
            [_temp_dict([9.0, 12.0, 13.0])],
            [_temp_dict([1.0]), _temp_dict([22.0, 23.0])],
        ]
        result = evaluate(server_weights, client_datasets)
        # Presumably TestModel counts temps above the trainable weight (5.0);
        # nine of the supplied temps exceed it.
        expected = collections.OrderedDict(
            eval=collections.OrderedDict(num_over=9.0))
        self.assertEqual(result, expected)
Exemplo n.º 2
0
    def test_multiple_nested_named_element_selection(self):
        """Selects two nested client-placed elements and checks the signature."""
        clients_int = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
        server_int = computation_types.FederatedType(tf.int32,
                                                     placements.SERVER)
        param_type = computation_types.StructType([
            ('a', [('a', clients_int)]),
            ('b', server_int),
            ('c', [('c', clients_int)]),
        ])

        def _select_nested(name):
            # Builds `x.<name>.<name>` against the lambda parameter `x`.
            ref = building_blocks.Reference('x', param_type)
            outer = building_blocks.Selection(ref, name=name)
            return building_blocks.Selection(outer, name=name)

        lam = building_blocks.Lambda(
            'x', param_type,
            building_blocks.Struct([_select_nested('a'), _select_nested('c')]))

        new_lam = form_utils._as_function_of_some_federated_subparameters(
            lam, [(0, 0), (2, 0)])

        # Only the two selected client values remain as the new parameter.
        expected_signature = computation_types.FunctionType(
            computation_types.at_clients((tf.int32, tf.int32)),
            lam.result.type_signature)
        type_test_utils.assert_types_equivalent(new_lam.type_signature,
                                                expected_signature)
Exemplo n.º 3
0
 def test_create_selection(self):
     """Creates a struct of two tensor refs and selects its first element.

     Refs are expected to be handed out in creation order starting at 0.
     """
     executor = executor_bindings.create_reference_resolving_executor(
         executor_bindings.create_tensorflow_executor())
     expected_type_spec = TensorType(shape=[3], dtype=tf.int64)
     value_pb, _ = value_serialization.serialize_value(
         tf.constant([1, 2, 3]), expected_type_spec)
     value = executor.create_value(value_pb)
     self.assertEqual(value.ref, 0)
     # 1. Create a struct from duplicated values.
     struct_value = executor.create_struct([value.ref, value.ref])
     self.assertEqual(struct_value.ref, 1)
     materialized_value = executor.materialize(struct_value.ref)
     deserialized_value, type_spec = value_serialization.deserialize_value(
         materialized_value)
     struct_type_spec = computation_types.to_type(
         [expected_type_spec, expected_type_spec])
     type_test_utils.assert_types_equivalent(type_spec, struct_type_spec)
     deserialized_value = type_conversions.type_to_py_container(
         deserialized_value, struct_type_spec)
     self.assertAllClose([(1, 2, 3), (1, 2, 3)], deserialized_value)
     # 2. Select the first value out of the struct.
     new_value = executor.create_selection(struct_value.ref, 0)
     materialized_value = executor.materialize(new_value.ref)
     deserialized_value, type_spec = value_serialization.deserialize_value(
         materialized_value)
     type_test_utils.assert_types_equivalent(type_spec, expected_type_spec)
     # The selected element is a single tensor, so convert it with the tensor
     # type rather than the enclosing two-element struct type.
     deserialized_value = type_conversions.type_to_py_container(
         deserialized_value, expected_type_spec)
     self.assertAllClose((1, 2, 3), deserialized_value)
Exemplo n.º 4
0
  def test_type_properties(self, value_type):
    """Checks initialize/next signatures of StochasticDiscretizationFactory."""
    factory = stochastic_discretization.StochasticDiscretizationFactory(
        step_size=0.1,
        inner_agg_factory=_measurement_aggregator,
        distortion_aggregation_factory=mean.UnweightedMeanFactory())
    value_spec = computation_types.to_type(value_type)
    # The discretization measurement mirrors the value structure, with every
    # leaf replaced by an int32 tensor of the same shape.
    quantize_type = type_conversions.structure_from_tensor_type_tree(
        lambda x: (tf.int32, x.shape), value_spec)
    process = factory.create(value_spec)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    state_type = computation_types.at_server(
        computation_types.StructType([
            ('step_size', tf.float32),
            ('inner_agg_process', ()),
        ]))
    type_test_utils.assert_types_equivalent(
        process.initialize.type_signature,
        computation_types.FunctionType(parameter=None, result=state_type))

    measurements_type = computation_types.at_server(
        computation_types.StructType([
            ('stochastic_discretization', quantize_type),
            ('distortion', tf.float32),
        ]))
    next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=state_type, value=computation_types.at_clients(value_spec)),
        result=measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(value_spec),
            measurements=measurements_type))
    type_test_utils.assert_types_equivalent(process.next.type_signature,
                                            next_type)
Exemplo n.º 5
0
  def test_type_properties(self, name, value_type):
    """Checks process signatures for the Hadamard ('hd') or DFT sum factory."""
    if name == 'hd':
      factory = _hadamard_sum()
    else:
      factory = _dft_sum()
    value_spec = computation_types.to_type(value_type)
    process = factory.create(value_spec)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # State is an empty struct followed by the rotation seed type.
    state_type = computation_types.at_server(((), rotation.SEED_TFF_TYPE))
    type_test_utils.assert_types_equivalent(
        process.initialize.type_signature,
        computation_types.FunctionType(parameter=None, result=state_type))

    # Measurements hold a single empty entry keyed by `name`.
    measurements_type = computation_types.at_server(
        collections.OrderedDict([(name, ())]))
    next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=state_type, value=computation_types.at_clients(value_spec)),
        result=measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(value_spec),
            measurements=measurements_type))
    type_test_utils.assert_types_equivalent(process.next.type_signature,
                                            next_type)
Exemplo n.º 6
0
    def test_type_properties(self, value_type):
        """Checks initialize/next signatures of the discretization sum process."""
        factory = _discretization_sum()
        value_spec = computation_types.to_type(value_type)
        process = factory.create(value_spec)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        state_type = computation_types.at_server(
            collections.OrderedDict(
                scale_factor=tf.float32,
                prior_norm_bound=tf.float32,
                inner_agg_process=()))
        type_test_utils.assert_types_equivalent(
            process.initialize.type_signature,
            computation_types.FunctionType(parameter=None, result=state_type))

        measurements_type = computation_types.at_server(
            collections.OrderedDict(discretize=()))
        next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=state_type,
                value=computation_types.at_clients(value_spec)),
            result=measured_process.MeasuredProcessOutput(
                state=state_type,
                result=computation_types.at_server(value_spec),
                measurements=measurements_type))
        type_test_utils.assert_types_equivalent(process.next.type_signature,
                                                next_type)
Exemplo n.º 7
0
 def test_federated_secure_modular_sum(self, value_dtype, modulus_type):
     """Secure modular sum is rewritten only by the secure-intrinsic pass."""
     modular_sum_uri = intrinsic_defs.FEDERATED_SECURE_MODULAR_SUM.uri
     intrinsic_type = computation_types.FunctionType(
         parameter=[
             computation_types.at_clients(value_dtype),
             computation_types.to_type(modulus_type),
         ],
         result=computation_types.at_server(value_dtype))
     comp = building_blocks.Intrinsic(modular_sum_uri, intrinsic_type)

     # The non-secure replacement pass must leave the computation untouched.
     reduced, modified = tree_transformations.replace_intrinsics_with_bodies(
         comp)
     self.assertFalse(modified)
     self.assertGreater(_count_intrinsics(comp, modular_sum_uri), 0)
     type_test_utils.assert_types_identical(comp.type_signature,
                                            reduced.type_signature)

     # The secure pass replaces the intrinsic with a body using federated_sum.
     insecure_pass = (
         tree_transformations.replace_secure_intrinsics_with_insecure_bodies)
     reduced, modified = insecure_pass(comp)
     self.assertTrue(modified)
     # Inserting TensorFlow does not currently preserve Python containers, so
     # only equivalence (not identity) of the signatures is checked.
     type_test_utils.assert_types_equivalent(comp.type_signature,
                                             reduced.type_signature)
     self.assertGreater(
         _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SUM.uri), 0)
Exemplo n.º 8
0
    def test_concat_type_properties_unweighted(self, value_type):
        """Checks signatures of the unweighted concat-sum aggregation process."""
        factory = _concat_sum()
        value_spec = computation_types.to_type(value_type)
        process = factory.create(value_spec)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The inner SumFactory has neither state nor measurements, so both are
        # the empty struct placed at the server.
        empty_at_server = computation_types.at_server(())

        type_test_utils.assert_types_equivalent(
            process.initialize.type_signature,
            computation_types.FunctionType(parameter=None,
                                           result=empty_at_server))

        next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=empty_at_server,
                value=computation_types.at_clients(value_spec)),
            result=measured_process.MeasuredProcessOutput(
                state=empty_at_server,
                result=computation_types.at_server(value_spec),
                measurements=empty_at_server))
        type_test_utils.assert_types_equivalent(process.next.type_signature,
                                                next_type)
Exemplo n.º 9
0
    def test_clip_type_properties_weighted(self, value_type, weight_type):
        """Checks signatures of the weighted concat-mean aggregation process.

        NOTE(review): the name mentions "clip", but the factory under test is
        `_concat_mean()`; consider renaming to match.
        """
        factory = _concat_mean()
        value_spec = computation_types.to_type(value_type)
        weight_spec = computation_types.to_type(weight_type)
        process = factory.create(value_spec, weight_spec)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Both state and measurements come from the inner MeanFactory.
        state_type = computation_types.at_server(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=()))
        measurements_type = computation_types.at_server(
            collections.OrderedDict(mean_value=(), mean_weight=()))

        type_test_utils.assert_types_equivalent(
            process.initialize.type_signature,
            computation_types.FunctionType(parameter=None, result=state_type))

        next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=state_type,
                value=computation_types.at_clients(value_spec),
                weight=computation_types.at_clients(weight_spec)),
            result=measured_process.MeasuredProcessOutput(
                state=state_type,
                result=computation_types.at_server(value_spec),
                measurements=measurements_type))
        type_test_utils.assert_types_equivalent(process.next.type_signature,
                                                next_type)
Exemplo n.º 10
0
    def test_local_evaluation(self):
        """Checks the client-side evaluation signature and a single result."""
        model_weights_type = model_utils.weights_type_from_model(TestModel)
        batch_type = computation_types.to_type(TestModel().input_spec)
        client_evaluate = federated_evaluation.build_local_evaluation(
            TestModel, model_weights_type, batch_type)

        temp_batch_type = StructType(
            [('temp', TensorType(dtype=tf.float32, shape=[None]))])
        expected_signature = FunctionType(
            parameter=StructType([
                ('incoming_model_weights', model_weights_type),
                ('dataset', SequenceType(temp_batch_type)),
            ]),
            result=collections.OrderedDict(
                local_outputs=collections.OrderedDict(num_over=tf.float32),
                num_examples=tf.int64))
        type_test_utils.assert_types_equivalent(client_evaluate.type_signature,
                                                expected_signature)

        def _temp_dict(temps):
            return {'temp': np.array(temps, dtype=np.float32)}

        batches = [_temp_dict([1.0, 10.0, 2.0, 8.0]), _temp_dict([6.0, 11.0])]
        client_result = client_evaluate(
            collections.OrderedDict(trainable=[5.0], non_trainable=[]),
            batches)
        # Presumably TestModel counts temps above the weight (5.0): four of
        # the six supplied temps exceed it.
        self.assertEqual(
            client_result,
            collections.OrderedDict(
                local_outputs=collections.OrderedDict(num_over=4.0),
                num_examples=6))
 def test_serialize_deserialize_nested_tuple_value_without_names(self):
   """Round-trips an unnamed pair of int32 values through serialization."""
   tuple_type = computation_types.to_type((tf.int32, tf.int32))
   original = (10, 20)
   proto, serialized_type = value_serialization.serialize_value(
       original, tuple_type)
   type_test_utils.assert_types_identical(serialized_type, tuple_type)
   roundtripped, roundtripped_type = value_serialization.deserialize_value(
       proto)
   type_test_utils.assert_types_equivalent(roundtripped_type, tuple_type)
   self.assertEqual(roundtripped, structure.from_container((10, 20)))
Exemplo n.º 12
0
 def assert_selected_param_to_result_type(self, old_lam, new_lam, index):
     """Asserts `new_lam` maps `old_lam`'s parameter element `index` to its result.

     Both lambdas must have function type signatures.
     """
     old_fn_type = old_lam.type_signature
     new_fn_type = new_lam.type_signature
     old_fn_type.check_function()
     new_fn_type.check_function()
     expected = computation_types.FunctionType(
         old_fn_type.parameter[index], old_fn_type.result)
     type_test_utils.assert_types_equivalent(new_fn_type, expected)
Exemplo n.º 13
0
    def test_construction(self, weighted):
        """Checks initialize/next signatures of the model-delta optimizer process."""
        # The weighted (MeanFactory) case has named, empty state and
        # measurement entries; the unweighted (SumFactory) case has none.
        if weighted:
            aggregation_factory = mean.MeanFactory()
            aggregate_state = collections.OrderedDict(value_sum_process=(),
                                                      weight_sum_process=())
            aggregate_metrics = collections.OrderedDict(mean_value=(),
                                                        mean_weight=())
        else:
            aggregation_factory = sum_factory.SumFactory()
            aggregate_state = ()
            aggregate_metrics = ()

        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            model_update_aggregation_factory=aggregation_factory)

        model_weights_type = model_utils.ModelWeights(
            trainable=[
                computation_types.TensorType(tf.float32, [2, 1]),
                computation_types.TensorType(tf.float32),
            ],
            non_trainable=[computation_types.TensorType(tf.float32)])
        server_state_type = computation_types.FederatedType(
            optimizer_utils.ServerState(
                model=model_weights_type,
                optimizer_state=[tf.int64],
                delta_aggregate_state=aggregate_state,
                model_broadcast_state=()),
            placements.SERVER)
        type_test_utils.assert_types_equivalent(
            computation_types.FunctionType(parameter=None,
                                           result=server_state_type),
            iterative_process.initialize.type_signature)

        client_batch_type = collections.OrderedDict(
            x=computation_types.TensorType(tf.float32, [None, 2]),
            y=computation_types.TensorType(tf.float32, [None, 1]))
        dataset_type = computation_types.FederatedType(
            computation_types.SequenceType(client_batch_type),
            placements.CLIENTS)
        metrics_type = computation_types.FederatedType(
            collections.OrderedDict(
                broadcast=(),
                aggregation=aggregate_metrics,
                train=collections.OrderedDict(
                    loss=computation_types.TensorType(tf.float32),
                    num_examples=computation_types.TensorType(tf.int32))),
            placements.SERVER)
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                server_state=server_state_type,
                federated_dataset=dataset_type),
            result=(server_state_type, metrics_type))
        type_test_utils.assert_types_equivalent(
            expected_next_type, iterative_process.next.type_signature)
Exemplo n.º 14
0
 def assert_compiles_to_tensorflow(
         self, comp: building_blocks.ComputationBuildingBlock):
     """Asserts `comp` compiles to TensorFlow with an equivalent type.

     A function must compile to a compiled computation; anything else must
     compile to a call whose function is a compiled computation.
     """
     compiled = compiler.compile_local_computation_to_tensorflow(comp)
     if not comp.type_signature.is_function():
         compiled.check_call()
         compiled.function.check_compiled_computation()
     else:
         compiled.check_compiled_computation()
     type_test_utils.assert_types_equivalent(comp.type_signature,
                                             compiled.type_signature)
Exemplo n.º 15
0
 def assert_serializes(self, fn, parameter_type, expected_fn_type_str):
   """Serializes `fn` as a TF computation and checks the resulting proto.

   Returns:
     A pair of the serialized computation's `tensorflow` field and the type
     spec reported by the serializer.
   """
   serializer = tensorflow_serialization.tf_computation_serializer(
       parameter_type, context_stack_impl.context_stack)
   # The serializer first yields the argument to trace `fn` with, then
   # receives `fn`'s result via send().
   traced_result = fn(next(serializer))
   comp_proto, type_spec = serializer.send(traced_result)
   deserialized_type = type_serialization.deserialize_type(comp_proto.type)
   type_test_utils.assert_types_equivalent(deserialized_type, type_spec)
   self.assertEqual(deserialized_type.compact_representation(),
                    expected_fn_type_str)
   self.assertEqual(comp_proto.WhichOneof('computation'), 'tensorflow')
   return comp_proto.tensorflow, type_spec
 def test_serialize_deserialize_sequence_of_ragged_tensors(self, dataset_fn):
   """Round-trips a dataset of ragged string tensors (currently skipped)."""
   # skipTest raises, so the body below does not run until b/235492749 is
   # resolved.
   self.skipTest('b/235492749')
   ragged_ds = tf.data.Dataset.from_tensor_slices(
       tf.strings.split(['a b c', 'd e']))
   value_proto, value_type = value_serialization.serialize_value(
       dataset_fn(ragged_ds),
       computation_types.SequenceType(element=ragged_ds.element_spec))
   expected_type = computation_types.SequenceType(ragged_ds.element_spec)
   type_test_utils.assert_types_identical(value_type, expected_type)
   _, roundtripped_type = value_serialization.deserialize_value(value_proto)
   # The Python container is lost in deserialization, so only equivalence
   # is checked here.
   type_test_utils.assert_types_equivalent(roundtripped_type, expected_type)
 def test_serialize_deserialize_nested_tuple_value_with_names(self):
   """Round-trips a nested, named structure of int32 values."""
   value = collections.OrderedDict(
       a=10, b=[20, 30], c=collections.OrderedDict(d=40))
   value_type_spec = computation_types.to_type(
       collections.OrderedDict(
           a=tf.int32,
           b=[tf.int32, tf.int32],
           c=collections.OrderedDict(d=tf.int32)))
   proto, serialized_type = value_serialization.serialize_value(
       value, value_type_spec)
   type_test_utils.assert_types_identical(serialized_type, value_type_spec)
   roundtripped, roundtripped_type = value_serialization.deserialize_value(
       proto)
   # Python containers are lost in serialization, so only type equivalence
   # is checked, and the value is compared against a Struct.
   type_test_utils.assert_types_equivalent(roundtripped_type, value_type_spec)
   expected = structure.from_container(value, recursive=True)
   self.assertEqual(roundtripped, expected)
Exemplo n.º 18
0
    def assert_splits_on(self, comp, calls):
        """Asserts that `force_align_and_split_by_intrinsics` removes intrinsics.

        Args:
          comp: The computation to split; it must contain the intrinsics
            called in `calls`.
          calls: A single intrinsic call, or a list of them, to split on.
        """
        if not isinstance(calls, list):
            calls = [calls]
        uris = [call.function.uri for call in calls]
        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, calls)

        # Ensure that the resulting computations no longer contain the split
        # intrinsics.
        self.assertFalse(tree_analysis.contains_called_intrinsic(before, uris))
        self.assertFalse(tree_analysis.contains_called_intrinsic(after, uris))
        # Removal isn't interesting to test for if it wasn't there to begin with.
        self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uris))

        # `before` takes the same parameter as `comp` (or none if `comp` has
        # none).
        if comp.parameter_type is not None:
            type_test_utils.assert_types_equivalent(comp.parameter_type,
                                                    before.parameter_type)
        else:
            self.assertIsNone(before.parameter_type)
        # `before` must return one element (that intrinsic's argument) for
        # each intrinsic in `calls`.
        before.type_signature.result.check_struct()
        self.assertLen(before.type_signature.result, len(calls))

        # Check that `after`'s parameter is a structure like:
        # {
        #   'original_arg': comp.parameter_type, (if present)
        #   'intrinsic_results': [...],
        # }
        after.parameter_type.check_struct()
        if comp.parameter_type is not None:
            self.assertLen(after.parameter_type, 2)
            type_test_utils.assert_types_equivalent(
                comp.parameter_type, after.parameter_type.original_arg)
        else:
            self.assertLen(after.parameter_type, 1)
        # There must be one result for each intrinsic in `calls`.
        self.assertLen(after.parameter_type.intrinsic_results, len(calls))

        # Check that each pair of (param, result) is a valid type substitution
        # for the intrinsic in question.
        for i, call in enumerate(calls):
            concrete_signature = computation_types.FunctionType(
                before.type_signature.result[i],
                after.parameter_type.intrinsic_results[i])
            abstract_signature = call.function.intrinsic_def().type_signature
            type_analysis.check_concrete_instance_of(concrete_signature,
                                                     abstract_signature)
 def test_serialize_deserialize_sequence_of_tuples(self, dataset_fn):
   """Round-trips a dataset of (int64, int32, float32) tuples."""
   element_dtypes = (tf.int64, tf.int32, tf.float32)
   ds = tf.data.Dataset.range(5).map(
       lambda x: (x * 2, tf.cast(x, tf.int32), tf.cast(x - 1, tf.float32)))
   value_proto, value_type = value_serialization.serialize_value(
       dataset_fn(ds), computation_types.SequenceType(element=element_dtypes))
   expected_type = computation_types.SequenceType(element_dtypes)
   type_test_utils.assert_types_identical(value_type, expected_type)
   y, roundtripped_type = value_serialization.deserialize_value(value_proto)
   # The Python container is lost in deserialization, so only equivalence
   # is checked here.
   type_test_utils.assert_types_equivalent(roundtripped_type, expected_type)
   expected_elements = [(i * 2, i, i - 1.) for i in range(5)]
   self.assertAllEqual(list(y), expected_elements)
Exemplo n.º 20
0
    def test_unflatten_tf_function(self, result, result_type_spec,
                                   python_container_hint,
                                   expected_python_container):
        """Checks `_unflatten_fn` repacks a flattened tf.function output.

        The packed output must match `result_type_spec` and be an instance of
        `expected_python_container`.
        """
        serialized_type = type_serialization.serialize_type(
            result_type_spec).SerializeToString(deterministic=True)
        type_spec_var = tf.Variable(serialized_type)

        @tf.function
        def fn():
            return tf.nest.flatten(result)

        packed_fn = serialization._unflatten_fn(fn, type_spec_var,
                                                python_container_hint)
        packed_output = packed_fn()
        type_test_utils.assert_types_equivalent(
            type_conversions.type_from_tensors(packed_output),
            result_type_spec)
        self.assertIsInstance(packed_output, expected_python_container)
  def test_serialize_deserialize_sequence_of_nested_structures(
      self, dataset_fn):
    """Round-trips a dataset whose elements are nested OrderedDicts/tuples.

    Args:
      dataset_fn: Callable applied to the `tf.data.Dataset` to produce the
        value passed to `serialize_value`.
    """

    def _make_nested_tf_structure(x):
      return collections.OrderedDict(
          b=tf.cast(x, tf.int32),
          a=tuple([
              x,
              collections.OrderedDict(u=x * 2, v=x * 3),
              collections.OrderedDict(x=x**2, y=x**3)
          ]))

    ds = tf.data.Dataset.range(5).map(_make_nested_tf_structure)
    ds_repr = dataset_fn(ds)
    element_type = computation_types.to_type(
        collections.OrderedDict(
            b=tf.int32,
            a=tuple([
                tf.int64,
                collections.OrderedDict(u=tf.int64, v=tf.int64),
                collections.OrderedDict(x=tf.int64, y=tf.int64),
            ])))
    sequence_type = computation_types.SequenceType(element=element_type)
    value_proto, value_type = value_serialization.serialize_value(
        ds_repr, sequence_type)
    type_test_utils.assert_types_identical(value_type, sequence_type)
    y, type_spec = value_serialization.deserialize_value(value_proto)
    # These aren't identical because ser/de destroys the Python container.
    type_test_utils.assert_types_equivalent(type_spec, sequence_type)

    def _build_expected_structure(x):
      return collections.OrderedDict(
          b=x,
          a=tuple([
              x,
              collections.OrderedDict(u=x * 2, v=x * 3),
              collections.OrderedDict(x=x**2, y=x**3)
          ]))

    actual_values = list(y)
    expected_values = [_build_expected_structure(x) for x in range(5)]
    # zip() stops at the shorter input, so check the element count first;
    # otherwise missing or extra dataset elements would go unnoticed.
    self.assertLen(actual_values, len(expected_values))
    for actual, expected in zip(actual_values, expected_values):
      self.assertEqual(type(actual), type(expected))
      self.assertAllClose(actual, expected)
Exemplo n.º 22
0
    def test_single_element_selection(self):
        """Selecting element 0 yields a function of that client value only."""
        clients_int = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
        server_int = computation_types.FederatedType(tf.int32,
                                                     placements.SERVER)
        param_type = computation_types.StructType([clients_int, server_int])
        first_element = building_blocks.Selection(
            building_blocks.Reference('x', param_type), index=0)
        lam = building_blocks.Lambda('x', param_type, first_element)

        new_lam = form_utils._as_function_of_some_federated_subparameters(
            lam, [(0,)])
        expected_signature = computation_types.FunctionType(
            computation_types.at_clients((tf.int32,)),
            lam.result.type_signature)
        type_test_utils.assert_types_equivalent(new_lam.type_signature,
                                                expected_signature)
Exemplo n.º 23
0
    def test_roundtrip_no_broadcast(self):
        """Round-trips a computation whose server data is never broadcast."""
        add_five = tensorflow_computation.tf_computation(lambda x: x + 5)

        @federated_computation.federated_computation(
            computation_types.at_server(()),
            computation_types.at_clients(tf.int32))
        def add_five_at_clients(naught_at_server, client_numbers):
            del naught_at_server
            return intrinsics.federated_map(add_five, client_numbers)

        broadcast_form = form_utils.get_broadcast_form_for_computation(
            add_five_at_clients)
        # The data labels match the computation's parameter names.
        self.assertEqual(broadcast_form.server_data_label, 'naught_at_server')
        self.assertEqual(broadcast_form.client_data_label, 'client_numbers')
        type_test_utils.assert_types_equivalent(
            broadcast_form.compute_server_context.type_signature,
            computation_types.FunctionType((), ()))
        type_test_utils.assert_types_equivalent(
            broadcast_form.client_processing.type_signature,
            computation_types.FunctionType(((), tf.int32), tf.int32))
        self.assertEqual(6, broadcast_form.client_processing((), 1))

        round_trip_comp = form_utils.get_computation_for_broadcast_form(
            broadcast_form)
        type_test_utils.assert_types_equivalent(
            round_trip_comp.type_signature, add_five_at_clients.type_signature)
        self.assertEqual([10, 11, 12], round_trip_comp((), [5, 6, 7]))
Exemplo n.º 24
0
    def test_roundtrip(self):
        """Round-trips a computation that broadcasts a server-derived value."""
        add = tensorflow_computation.tf_computation(lambda x, y: x + y)

        @federated_computation.federated_computation(
            computation_types.at_server(tf.int32),
            computation_types.at_clients(tf.int32))
        def add_server_number_plus_one(server_number, client_numbers):
            one = intrinsics.federated_value(1, placements.SERVER)
            server_context = intrinsics.federated_map(add,
                                                      (one, server_number))
            client_context = intrinsics.federated_broadcast(server_context)
            return intrinsics.federated_map(add,
                                            (client_context, client_numbers))

        broadcast_form = form_utils.get_broadcast_form_for_computation(
            add_server_number_plus_one)
        self.assertEqual(broadcast_form.server_data_label, 'server_number')
        self.assertEqual(broadcast_form.client_data_label, 'client_numbers')

        # The server context is the server number plus one, as a 1-tuple.
        type_test_utils.assert_types_equivalent(
            broadcast_form.compute_server_context.type_signature,
            computation_types.FunctionType(tf.int32, (tf.int32,)))
        self.assertEqual(2, broadcast_form.compute_server_context(1)[0])

        # Client processing adds the broadcast context to each client number.
        type_test_utils.assert_types_equivalent(
            broadcast_form.client_processing.type_signature,
            computation_types.FunctionType(((tf.int32,), tf.int32), tf.int32))
        self.assertEqual(3, broadcast_form.client_processing((1,), 2))

        round_trip_comp = form_utils.get_computation_for_broadcast_form(
            broadcast_form)
        type_test_utils.assert_types_equivalent(
            round_trip_comp.type_signature,
            add_server_number_plus_one.type_signature)
        # 2 (server data) + 1 (constant in comp) + 2 (client data) = 5 (output)
        self.assertEqual([5, 6, 7], round_trip_comp(2, [2, 3, 4]))
Exemplo n.º 25
0
    def test_federated_evaluation_quantized_conservatively(self):
        # Broadcast the model weights through a uniform-quantization encoder
        # built with parameter 12 (presumably its bit width -- confirm
        # against _build_simple_quant_encoder).
        quant_encoder = _build_simple_quant_encoder(12)
        broadcaster = (
            encoding_utils.build_encoded_broadcast_process_from_model(
                TestModelQuant, quant_encoder))
        type_test_utils.assert_types_equivalent(
            broadcaster.next.type_signature,
            _build_expected_broadcaster_next_signature())

        evaluate = federated_evaluation.build_federated_evaluation(
            TestModelQuant, broadcast_process=broadcaster)
        # The evaluation computation must have exactly the expected type.
        type_test_utils.assert_types_identical(
            evaluate.type_signature,
            _build_expected_test_quant_model_eval_signature())

        def _batch(values):
            return {'temp': np.array(values, dtype=np.float32)}

        initial_weights = collections.OrderedDict(
            trainable=[[5.0, 10.0, 5.0, 7.0]], non_trainable=[])
        # One inner list per client, each holding that client's batches.
        client_datasets = [
            [_batch([1.0, 10.0, 2.0, 7.0]), _batch([6.0, 11.0, 5.0, 8.0])],
            [_batch([9.0, 12.0, 13.0, 7.0])],
            [_batch([1.0, 22.0, 23.0, 24.0]), _batch([5.0, 10.0, 5.0, 7.0])],
        ]
        result = evaluate(initial_weights, client_datasets)

        # A conservative quantization should be close to lossless. Comparing
        # each example against the trainable weights position by position
        # gives 2 + 1 (client 1) + 1 (client 2) + 0 + 4 (client 3) = 8
        # index/value matches.
        expected = collections.OrderedDict(
            eval=collections.OrderedDict(num_same=8.0))
        self.assertEqual(result, expected)
Exemplo n.º 26
0
    def test_federated_evaluation_quantized_aggressively(self):
        # Broadcast the model weights through a uniform-quantization encoder
        # built with parameter 2, versus 12 in the conservative variant of
        # this test (presumably the bit width -- confirm against
        # _build_simple_quant_encoder).
        broadcaster = (
            encoding_utils.build_encoded_broadcast_process_from_model(
                TestModelQuant, _build_simple_quant_encoder(2)))
        type_test_utils.assert_types_equivalent(
            broadcaster.next.type_signature,
            _build_expected_broadcaster_next_signature())

        evaluate = federated_evaluation.build_federated_evaluation(
            TestModelQuant, broadcast_process=broadcaster)
        # The evaluation computation must have exactly the expected type.
        type_test_utils.assert_types_identical(
            evaluate.type_signature,
            _build_expected_test_quant_model_eval_signature())

        def _batch(values):
            return {'temp': np.array(values, dtype=np.float32)}

        initial_weights = collections.OrderedDict(
            trainable=[[5.0, 10.0, 5.0, 7.0]], non_trainable=[])
        # One inner list per client, each holding that client's batches.
        client_datasets = [
            [_batch([1.0, 10.0, 2.0, 7.0]), _batch([6.0, 11.0, 5.0, 8.0])],
            [_batch([9.0, 12.0, 13.0, 7.0])],
            [_batch([1.0, 22.0, 23.0, 24.0]), _batch([5.0, 10.0, 5.0, 7.0])],
        ]
        result = evaluate(initial_weights, client_datasets)

        # This aggressive quantization should be lossy enough that some
        # weights change during encoding. The number of index/value matches
        # should therefore fall below the 8 that an unquantized broadcast of
        # this data yields.
        self.assertEqual(list(result.keys()), ['eval'])
        # Assert the exact metric keys. This replaces a reversed-argument
        # assertContainsSubset call, which only checked that the keys were a
        # subset of ['num_same'].
        self.assertEqual(list(result['eval'].keys()), ['num_same'])
        self.assertLess(result['eval']['num_same'], 8.0)
  def test_serialize_deserialize_sequence_of_namedtuples_alphabetical_order(
      self, dataset_fn):
    """Round-trips a sequence of namedtuples through value serialization.

    The namedtuple's fields (`a`, `b`, `c`) are declared in alphabetical
    order.

    Args:
      dataset_fn: Callable applied to the `tf.data.Dataset` to produce the
        value that is serialized. It is presumably supplied by a test
        parameterization decorator not visible here.
    """
    test_tuple_type = collections.namedtuple('TestTuple', ['a', 'b', 'c'])

    def make_test_tuple(x):
      return test_tuple_type(
          a=x * 2, b=tf.cast(x, tf.int32), c=tf.cast(x - 1, tf.float32))

    ds = tf.data.Dataset.range(5).map(make_test_tuple)
    ds_repr = dataset_fn(ds)
    element_type = computation_types.to_type(
        test_tuple_type(tf.int64, tf.int32, tf.float32))
    sequence_type = computation_types.SequenceType(element=element_type)
    value_proto, value_type = value_serialization.serialize_value(
        ds_repr, sequence_type)
    self.assertEqual(value_type, sequence_type)
    y, type_spec = value_serialization.deserialize_value(value_proto)
    type_test_utils.assert_types_equivalent(type_spec, sequence_type)
    actual_values = list(y)
    expected_values = [
        test_tuple_type(a=x * 2, b=x, c=x - 1.) for x in range(5)
    ]
    # zip() stops at the shorter input, so check the element count first.
    # Otherwise a round trip that drops elements would still pass.
    self.assertLen(actual_values, len(expected_values))
    for actual, expected in zip(actual_values, expected_values):
      self.assertAllClose(actual, expected)