Example #1
0
    def test_zero_type_properties_with_zeroed_count_agg_factory(
            self, value_type):
        """Checks types of a zeroing process with a custom count aggregator.

        The zeroed count is aggregated with `SumPlusOneFactory`, and the
        expected server state therefore carries a `tf.int32` entry under
        `zeroed_count_agg`.
        """
        zeroing_agg_factory = robust.zeroing_factory(
            zeroing_norm=1.0,
            inner_agg_factory=sum_factory.SumFactory(),
            norm_order=2.0,
            zeroed_count_sum_factory=aggregator_test_utils.SumPlusOneFactory())
        tff_value_type = computation_types.to_type(value_type)
        agg_process = zeroing_agg_factory.create(tff_value_type)
        self.assertIsInstance(agg_process,
                              aggregation_process.AggregationProcess)

        state_fields = collections.OrderedDict(zeroing_norm=(),
                                               inner_agg=(),
                                               zeroed_count_agg=tf.int32)
        state_at_server = computation_types.at_server(state_fields)

        # initialize: ( -> state@SERVER)
        init_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        self.assertTrue(
            agg_process.initialize.type_signature.is_equivalent_to(
                init_signature))

        # next: (<state@SERVER, value@CLIENTS> -> <state, result, measurements>)
        measurement_fields = collections.OrderedDict(
            zeroing=(),
            zeroing_norm=robust.NORM_TF_TYPE,
            zeroed_count=robust.COUNT_TF_TYPE)
        next_params = collections.OrderedDict(
            state=state_at_server,
            value=computation_types.at_clients(tff_value_type))
        next_output = measured_process.MeasuredProcessOutput(
            state=state_at_server,
            result=computation_types.at_server(tff_value_type),
            measurements=computation_types.at_server(measurement_fields))
        next_signature = computation_types.FunctionType(
            parameter=next_params, result=next_output)
        self.assertTrue(
            agg_process.next.type_signature.is_equivalent_to(next_signature))
  def test_clip(self, clip_mechanism):
    """Checks the type signatures of a `HistogramClippingSumFactory` process.

    The factory is built from `clip_mechanism` and a second positional
    argument of `1`. The resulting process must have an empty server state
    and empty measurements.
    """
    histogram_factory = clipping_factory.HistogramClippingSumFactory(
        clip_mechanism, 1)
    self.assertIsInstance(histogram_factory,
                          factory.UnweightedAggregationFactory)

    int_vector_type = computation_types.to_type((tf.int32, (2,)))
    agg_process = histogram_factory.create(int_vector_type)
    self.assertIsInstance(agg_process, aggregation_process.AggregationProcess)

    clients_value = computation_types.FederatedType(int_vector_type,
                                                    placements.CLIENTS)
    server_value = computation_types.FederatedType(int_vector_type,
                                                   placements.SERVER)
    empty_state = computation_types.at_server(())
    empty_measurements = computation_types.at_server(())

    init_signature = computation_types.FunctionType(
        parameter=None, result=empty_state)
    self.assertTrue(
        agg_process.initialize.type_signature.is_equivalent_to(init_signature))

    next_params = collections.OrderedDict(
        state=empty_state, value=clients_value)
    next_output = measured_process.MeasuredProcessOutput(
        empty_state, server_value, empty_measurements)
    next_signature = computation_types.FunctionType(
        parameter=next_params, result=next_output)
    self.assertTrue(
        agg_process.next.type_signature.is_equivalent_to(next_signature))
Example #3
0
  def test_type_properties(self, name, value_type):
    """Checks type signatures of the rotation (Hadamard or DFT) sum process.

    `name` selects the factory: 'hd' gives `_hadamard_sum()`, anything else
    gives `_dft_sum()`. `name` is also the sole key of the expected
    measurements.
    """
    if name == 'hd':
      rotation_factory = _hadamard_sum()
    else:
      rotation_factory = _dft_sum()
    tff_value_type = computation_types.to_type(value_type)
    agg_process = rotation_factory.create(tff_value_type)
    self.assertIsInstance(agg_process, aggregation_process.AggregationProcess)

    # State: an empty component paired with a `rotation.SEED_TFF_TYPE` value.
    state_at_server = computation_types.at_server(
        ((), rotation.SEED_TFF_TYPE))

    init_signature = computation_types.FunctionType(
        parameter=None, result=state_at_server)
    type_test_utils.assert_types_equivalent(
        agg_process.initialize.type_signature, init_signature)

    measurements_at_server = computation_types.at_server(
        collections.OrderedDict([(name, ())]))
    next_params = collections.OrderedDict(
        state=state_at_server,
        value=computation_types.at_clients(tff_value_type))
    next_output = measured_process.MeasuredProcessOutput(
        state=state_at_server,
        result=computation_types.at_server(tff_value_type),
        measurements=measurements_at_server)
    type_test_utils.assert_types_equivalent(
        agg_process.next.type_signature,
        computation_types.FunctionType(parameter=next_params,
                                       result=next_output))
Example #4
0
    def test_type_properties(self, value_type):
        """Checks type signatures of the discretization sum process.

        The expected server state holds float32 `scale_factor` and
        `prior_norm_bound` entries plus an empty `inner_agg_process` entry.
        The only measurement is an empty `discretize` entry.
        """
        discretization_factory = _discretization_sum()
        tff_value_type = computation_types.to_type(value_type)
        agg_process = discretization_factory.create(tff_value_type)
        self.assertIsInstance(agg_process,
                              aggregation_process.AggregationProcess)

        state_fields = collections.OrderedDict(scale_factor=tf.float32,
                                               prior_norm_bound=tf.float32,
                                               inner_agg_process=())
        state_at_server = computation_types.at_server(state_fields)

        type_test_utils.assert_types_equivalent(
            agg_process.initialize.type_signature,
            computation_types.FunctionType(parameter=None,
                                           result=state_at_server))

        measurements_at_server = computation_types.at_server(
            collections.OrderedDict(discretize=()))
        next_params = collections.OrderedDict(
            state=state_at_server,
            value=computation_types.at_clients(tff_value_type))
        next_output = measured_process.MeasuredProcessOutput(
            state=state_at_server,
            result=computation_types.at_server(tff_value_type),
            measurements=measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_params, result=next_output)
        type_test_utils.assert_types_equivalent(
            agg_process.next.type_signature, next_signature)
Example #5
0
    def test_type_properties_simple(self, value_type, estimate_stddev):
        """Checks types of the modular clipping process from `_test_factory`.

        The inner `SumFactory` has no state, so the server state is empty.
        The measurements always contain `modclip`. When `estimate_stddev` is
        truthy, they also contain a float32 `estimated_stddev` entry.
        """
        modclip_factory = _test_factory(estimate_stddev=estimate_stddev)
        converted_type = computation_types.to_type(value_type)
        agg_process = modclip_factory.create(converted_type)
        self.assertIsInstance(agg_process,
                              aggregation_process.AggregationProcess)

        # Inner SumFactory has no state.
        empty_state = computation_types.at_server(())

        init_signature = computation_types.FunctionType(
            parameter=None, result=empty_state)
        self.assertTrue(
            agg_process.initialize.type_signature.is_equivalent_to(
                init_signature))

        measurement_fields = collections.OrderedDict(modclip=())
        if estimate_stddev:
            measurement_fields.update(estimated_stddev=tf.float32)

        next_params = collections.OrderedDict(
            state=empty_state,
            value=computation_types.at_clients(value_type))
        next_output = measured_process.MeasuredProcessOutput(
            state=empty_state,
            result=computation_types.at_server(value_type),
            measurements=computation_types.at_server(measurement_fields))
        next_signature = computation_types.FunctionType(
            parameter=next_params, result=next_output)
        self.assertTrue(
            agg_process.next.type_signature.is_equivalent_to(next_signature))
    def test_type_properties(self):
        """Checks types of the apply-optimizer finalizer built with SGDM(1.0).

        `next` takes the state, the model weights and an update typed like
        the trainable weights. It returns a value of the model-weights type,
        with empty state and empty measurements.
        """
        weights_type = computation_types.to_type(
            model_utils.ModelWeights(trainable=(tf.float32, tf.float32),
                                     non_trainable=tf.float32))

        sgdm_finalizer = finalizers.build_apply_optimizer_finalizer(
            sgdm.build_sgdm(1.0), weights_type)
        self.assertIsInstance(sgdm_finalizer, finalizers.FinalizerProcess)

        weights_param = computation_types.at_server(weights_type)
        update_param = computation_types.at_server(weights_type.trainable)
        weights_result = computation_types.at_server(weights_type)
        empty_state = computation_types.at_server(())
        empty_measurements = computation_types.at_server(())

        init_signature = computation_types.FunctionType(
            parameter=None, result=empty_state)
        init_signature.check_equivalent_to(
            sgdm_finalizer.initialize.type_signature)

        next_params = collections.OrderedDict(state=empty_state,
                                              weights=weights_param,
                                              update=update_param)
        next_output = MeasuredProcessOutput(empty_state, weights_result,
                                            empty_measurements)
        next_signature = computation_types.FunctionType(
            parameter=next_params, result=next_output)
        next_signature.check_equivalent_to(sgdm_finalizer.next.type_signature)
  def test_composition_type_properties(self, last_process):
    """Chaining a doubling process with `last_process` yields combined types.

    `last_process` is a constructor called as `(state_type, 0.0, values_type)`.
    The composite's state and measurements are keyed by 'double' and
    'last_process'.
    """
    state_type = tf.float32
    values_type = tf.int32
    final_stage = last_process(state_type, 0.0, values_type)
    doubling_stage = _create_test_measured_process_double(state_type, 1.0,
                                                          values_type)
    composite_process = measured_process.chain_measured_processes(
        collections.OrderedDict(double=doubling_stage,
                                last_process=final_stage))
    self.assertIsInstance(composite_process, measured_process.MeasuredProcess)

    state_at_server = computation_types.at_server(
        collections.OrderedDict(double=state_type, last_process=state_type))
    init_signature = computation_types.FunctionType(
        parameter=None, result=state_at_server)
    self.assertTrue(
        composite_process.initialize.type_signature.is_equivalent_to(
            init_signature))

    # The final stage's measurements are read from its own `next` signature.
    final_stage_measurements = (
        final_stage.next.type_signature.result.measurements.member)
    measurements_at_server = computation_types.at_server(
        collections.OrderedDict(
            double=collections.OrderedDict(a=tf.int32),
            last_process=final_stage_measurements))
    next_params = collections.OrderedDict(
        state=state_at_server,
        values=computation_types.at_clients(values_type))
    next_output = measured_process.MeasuredProcessOutput(
        state_at_server, computation_types.at_server(values_type),
        measurements_at_server)
    next_signature = computation_types.FunctionType(
        parameter=next_params, result=next_output)
    self.assertTrue(
        composite_process.next.type_signature.is_equivalent_to(next_signature))
Example #8
0
  def test_type_properties_simple(self):
    """Checks types of a modular clipping sum over an int32 vector of size 2.

    The clip range is [-2, 2]. The inner `SumFactory` has no state, and the
    only measurement is an empty `modclip` entry.
    """
    int_vector_type = computation_types.to_type((tf.int32, (2,)))
    modclip_factory = modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=-2, clip_range_upper=2)
    agg_process = modclip_factory.create(int_vector_type)
    self.assertIsInstance(agg_process, aggregation_process.AggregationProcess)

    # Inner SumFactory has no state.
    empty_state = computation_types.at_server(())

    self.assertTrue(
        agg_process.initialize.type_signature.is_equivalent_to(
            computation_types.FunctionType(parameter=None,
                                           result=empty_state)))

    measurement_fields = collections.OrderedDict(modclip=())

    next_params = collections.OrderedDict(
        state=empty_state,
        value=computation_types.at_clients(int_vector_type))
    next_output = measured_process.MeasuredProcessOutput(
        state=empty_state,
        result=computation_types.at_server(int_vector_type),
        measurements=computation_types.at_server(measurement_fields))
    next_signature = computation_types.FunctionType(
        parameter=next_params, result=next_output)
    self.assertTrue(
        agg_process.next.type_signature.is_equivalent_to(next_signature))
 def test_fails_conflicting_binding_in_parameter_and_result(self):
     """`(int32 -> float32)` is not a concrete instance of `(T -> T)`.

     `T` would have to bind to `int32` in the parameter and to `float32` in
     the result, which conflict.
     """
     generic_fn = computation_types.FunctionType(
         parameter=computation_types.AbstractType('T'),
         result=computation_types.AbstractType('T'))
     concrete_fn = computation_types.FunctionType(
         parameter=tf.int32, result=tf.float32)
     expected_error = type_analysis.UnassignableConcreteTypesError
     with self.assertRaises(expected_error):
         type_analysis.check_concrete_instance_of(concrete_fn, generic_fn)
Example #10
0
 def test_federated_secure_select(self):
     """Checks that `federated_secure_select` is reduced only in the secure pass.

     `replace_intrinsics_with_bodies` must leave FEDERATED_SECURE_SELECT in
     place and report no modification. In contrast,
     `replace_secure_intrinsics_with_insecure_bodies` must rewrite it into a
     computation that contains FEDERATED_SELECT. Both passes must preserve
     the type signature.
     """
     uri = intrinsic_defs.FEDERATED_SECURE_SELECT.uri
     comp = building_blocks.Intrinsic(
         uri,
         computation_types.FunctionType(
             [
                 computation_types.at_clients(tf.int32),  # client_keys
                 computation_types.at_server(tf.int32),  # max_key
                 computation_types.at_server(tf.float32),  # server_state
                 computation_types.FunctionType([tf.float32, tf.int32],
                                                tf.float32)  # select_fn
             ],
             computation_types.at_clients(
                 computation_types.SequenceType(tf.float32))))
     self.assertGreater(_count_intrinsics(comp, uri), 0)
     # First without secure intrinsics shouldn't modify anything.
     reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
         comp)
     self.assertFalse(modified)
     # Inspect `reduced`, not `comp`: `comp` was already checked above, so
     # re-checking it would not show that the secure intrinsic survived.
     self.assertGreater(_count_intrinsics(reduced, uri), 0)
     self.assert_types_identical(comp.type_signature,
                                 reduced.type_signature)
     # Now replace bodies including secure intrinsics.
     reduced, modified = intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(
         comp)
     self.assertTrue(modified)
     self.assert_types_identical(comp.type_signature,
                                 reduced.type_signature)
     self.assertGreater(
         _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SELECT.uri), 0)
Example #11
0
    def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                             lower_bound, measurements_dtype):
        """Checks types of `SecureSumFactory` built with constant thresholds.

        The state is an empty server value. The measurements type comes from
        `_measurements_type(measurements_dtype)`.
        """
        secure_sum_factory = secure.SecureSumFactory(
            upper_bound_threshold=upper_bound,
            lower_bound_threshold=lower_bound)
        self.assertIsInstance(secure_sum_factory,
                              factory.UnweightedAggregationFactory)
        tff_value_type = computation_types.to_type(value_type)
        agg_process = secure_sum_factory.create(tff_value_type)
        self.assertIsInstance(agg_process,
                              aggregation_process.AggregationProcess)

        empty_state = computation_types.at_server(
            computation_types.to_type(()))
        measurements_at_server = _measurements_type(measurements_dtype)

        self.assertTrue(
            agg_process.initialize.type_signature.is_equivalent_to(
                computation_types.FunctionType(parameter=None,
                                               result=empty_state)))

        next_params = collections.OrderedDict(
            state=empty_state,
            value=computation_types.at_clients(tff_value_type))
        next_output = measured_process.MeasuredProcessOutput(
            state=empty_state,
            result=computation_types.at_server(tff_value_type),
            measurements=measurements_at_server)
        self.assertTrue(
            agg_process.next.type_signature.is_equivalent_to(
                computation_types.FunctionType(parameter=next_params,
                                               result=next_output)))
 def test_intrinsic_class_succeeds_simple_federated_map(self):
     """Checks a concrete `federated_map` intrinsic's repr, proto and round trip.

     The intrinsic is typed as
     `(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)`.
     """
     mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
     clients_arg_type = computation_types.FederatedType(
         mapped_fn_type.parameter, placements.CLIENTS)
     clients_result_type = computation_types.FederatedType(
         mapped_fn_type.result, placements.CLIENTS)
     map_signature = computation_types.FunctionType(
         computation_types.StructType((mapped_fn_type, clients_arg_type)),
         clients_result_type)
     map_intrinsic = building_blocks.Intrinsic(
         intrinsic_defs.FEDERATED_MAP.uri, map_signature)
     self.assertIsInstance(map_intrinsic, building_blocks.Intrinsic)
     self.assertEqual(
         str(map_intrinsic.type_signature),
         '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
     self.assertEqual(map_intrinsic.uri, 'federated_map')
     self.assertEqual(map_intrinsic.compact_representation(),
                      'federated_map')

     # The serialized proto must carry the same type and URI.
     map_proto = map_intrinsic.proto
     deserialized_type = type_serialization.deserialize_type(map_proto.type)
     self.assertEqual(deserialized_type, map_intrinsic.type_signature)
     self.assertEqual(map_proto.WhichOneof('computation'), 'intrinsic')
     self.assertEqual(map_proto.intrinsic.uri, map_intrinsic.uri)
     self._serialize_deserialize_roundtrip_test(map_intrinsic)
Example #13
0
  def test_type_properties_unweighted(self, value_type):
    """Checks types of the `UnweightedMeanFactory` aggregation process.

    The state is a pair of empty values. The measurements are empty
    `mean_value` and `mean_count` entries.
    """
    tff_value_type = computation_types.to_type(value_type)

    mean_factory = mean.UnweightedMeanFactory()
    self.assertIsInstance(mean_factory, factory.UnweightedAggregationFactory)
    agg_process = mean_factory.create(tff_value_type)

    self.assertIsInstance(agg_process, aggregation_process.AggregationProcess)

    clients_value = computation_types.at_clients(tff_value_type)
    server_value = computation_types.at_server(tff_value_type)

    state_pair = computation_types.at_server(((), ()))
    measurements_at_server = computation_types.at_server(
        collections.OrderedDict(mean_value=(), mean_count=()))

    init_signature = computation_types.FunctionType(
        parameter=None, result=state_pair)
    self.assertTrue(
        agg_process.initialize.type_signature.is_equivalent_to(init_signature))

    next_params = collections.OrderedDict(state=state_pair,
                                          value=clients_value)
    next_output = measured_process.MeasuredProcessOutput(
        state_pair, server_value, measurements_at_server)
    next_signature = computation_types.FunctionType(
        parameter=next_params, result=next_output)
    self.assertTrue(
        agg_process.next.type_signature.is_equivalent_to(next_signature))
 def test_repr(self):
   """`repr` of a `FunctionType` with and without a parameter type."""
   for parameter, expected_repr in (
       (tf.int32, 'FunctionType(TensorType(tf.int32), TensorType(tf.bool))'),
       (None, 'FunctionType(None, TensorType(tf.bool))'),
   ):
     function_type = computation_types.FunctionType(parameter, tf.bool)
     self.assertEqual(repr(function_type), expected_repr)
Example #15
0
    def test_roundtrip(self):
        """Broadcast form of a server-increment-then-add computation round-trips.

        The federated computation's parameter names become the broadcast form's
        data labels, so they must stay `server_number` and `client_numbers`.
        """
        add_fn = tensorflow_computation.tf_computation(lambda x, y: x + y)
        server_int = computation_types.at_server(tf.int32)
        clients_int = computation_types.at_clients(tf.int32)

        @federated_computation.federated_computation(server_int, clients_int)
        def add_server_number_plus_one(server_number, client_numbers):
            one_at_server = intrinsics.federated_value(1, placements.SERVER)
            incremented = intrinsics.federated_map(
                add_fn, (one_at_server, server_number))
            broadcast_value = intrinsics.federated_broadcast(incremented)
            return intrinsics.federated_map(add_fn,
                                            (broadcast_value, client_numbers))

        broadcast_form = form_utils.get_broadcast_form_for_computation(
            add_server_number_plus_one)
        self.assertEqual(broadcast_form.server_data_label, 'server_number')
        self.assertEqual(broadcast_form.client_data_label, 'client_numbers')

        # compute_server_context: int32 -> <int32>; 1 maps to (2,).
        type_test_utils.assert_types_equivalent(
            broadcast_form.compute_server_context.type_signature,
            computation_types.FunctionType(tf.int32, (tf.int32, )))
        self.assertEqual(2, broadcast_form.compute_server_context(1)[0])

        # client_processing: <<int32>, int32> -> int32; ((1,), 2) maps to 3.
        type_test_utils.assert_types_equivalent(
            broadcast_form.client_processing.type_signature,
            computation_types.FunctionType(((tf.int32, ), tf.int32), tf.int32))
        self.assertEqual(3, broadcast_form.client_processing((1, ), 2))

        rebuilt_comp = form_utils.get_computation_for_broadcast_form(
            broadcast_form)
        type_test_utils.assert_types_equivalent(
            rebuilt_comp.type_signature,
            add_server_number_plus_one.type_signature)
        # 2 (server data) + 1 (constant in comp) + 2 (client data) = 5 (output)
        self.assertEqual([5, 6, 7], rebuilt_comp(2, [2, 3, 4]))
Example #16
0
    def test_clip_type_properties_weighted(self, value_type, weight_type):
        """Checks types of the weighted concat-mean process.

        Both state and measurements are expected to come from the inner mean
        factory. `next` additionally takes a client-placed `weight`.
        """
        concat_mean_factory = _concat_mean()
        tff_value_type = computation_types.to_type(value_type)
        tff_weight_type = computation_types.to_type(weight_type)
        agg_process = concat_mean_factory.create(tff_value_type,
                                                 tff_weight_type)
        self.assertIsInstance(agg_process,
                              aggregation_process.AggregationProcess)

        # State comes from the inner MeanFactory.
        state_at_server = computation_types.at_server(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=()))

        type_test_utils.assert_types_equivalent(
            agg_process.initialize.type_signature,
            computation_types.FunctionType(parameter=None,
                                           result=state_at_server))

        # Measurements come from the inner mean factory.
        measurements_at_server = computation_types.at_server(
            collections.OrderedDict(mean_value=(), mean_weight=()))
        next_params = collections.OrderedDict(
            state=state_at_server,
            value=computation_types.at_clients(tff_value_type),
            weight=computation_types.at_clients(tff_weight_type))
        next_output = measured_process.MeasuredProcessOutput(
            state=state_at_server,
            result=computation_types.at_server(tff_value_type),
            measurements=measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_params, result=next_output)
        type_test_utils.assert_types_equivalent(
            agg_process.next.type_signature, next_signature)
Example #17
0
    def test_type_properties(self, modulus, value_type, symmetric_range):
        """Checks types and secure-only aggregation of `SecureModularSumFactory`.

        The state and measurements are both empty server values.
        `static_assert` must find no non-secure aggregation in `next`.
        """
        factory_ = secure.SecureModularSumFactory(
            modulus=modulus, symmetric_range=symmetric_range)
        self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = factory_.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        expected_state_type = computation_types.at_server(
            computation_types.to_type(()))
        expected_measurements_type = expected_state_type

        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=expected_state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                expected_initialize_type))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
        # Catch `Exception` rather than using a bare `except:`, so that
        # KeyboardInterrupt/SystemExit still propagate instead of being
        # reported as a non-secure-aggregation failure.
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except Exception:  # pylint: disable=broad-except
            self.fail('Factory returned an AggregationProcess containing '
                      'non-secure aggregation.')
Example #18
0
    def test_type_properties(self, value_type):
        """Checks types of the plain `SumFactory` process.

        The state and measurements are both expected to be empty server
        values.
        """
        sum_agg_factory = sum_factory.SumFactory()
        self.assertIsInstance(sum_agg_factory,
                              factory.UnweightedAggregationFactory)
        tff_value_type = computation_types.to_type(value_type)
        agg_process = sum_agg_factory.create(tff_value_type)
        self.assertIsInstance(agg_process,
                              aggregation_process.AggregationProcess)

        clients_value = computation_types.FederatedType(
            tff_value_type, placements.CLIENTS)
        server_value = computation_types.FederatedType(
            tff_value_type, placements.SERVER)
        empty_state = computation_types.FederatedType((), placements.SERVER)
        empty_measurements = computation_types.FederatedType(
            (), placements.SERVER)

        init_signature = computation_types.FunctionType(
            parameter=None, result=empty_state)
        self.assertTrue(
            agg_process.initialize.type_signature.is_equivalent_to(
                init_signature))

        next_params = collections.OrderedDict(state=empty_state,
                                              value=clients_value)
        next_output = measured_process.MeasuredProcessOutput(
            empty_state, server_value, empty_measurements)
        next_signature = computation_types.FunctionType(
            parameter=next_params, result=next_output)
        self.assertTrue(
            agg_process.next.type_signature.is_equivalent_to(next_signature))
Example #19
0
    def test_concat_type_properties_unweighted(self, value_type):
        """Checks types of the unweighted concat-sum process.

        The inner `SumFactory` contributes no state and no measurements.
        """
        concat_sum_factory = _concat_sum()
        tff_value_type = computation_types.to_type(value_type)
        agg_process = concat_sum_factory.create(tff_value_type)
        self.assertIsInstance(agg_process,
                              aggregation_process.AggregationProcess)

        # Inner SumFactory has no state.
        empty_state = computation_types.at_server(())

        type_test_utils.assert_types_equivalent(
            agg_process.initialize.type_signature,
            computation_types.FunctionType(parameter=None,
                                           result=empty_state))

        # Inner SumFactory has no measurements.
        empty_measurements = computation_types.at_server(())
        next_params = collections.OrderedDict(
            state=empty_state,
            value=computation_types.at_clients(tff_value_type))
        next_output = measured_process.MeasuredProcessOutput(
            state=empty_state,
            result=computation_types.at_server(tff_value_type),
            measurements=empty_measurements)
        next_signature = computation_types.FunctionType(
            parameter=next_params, result=next_output)
        type_test_utils.assert_types_equivalent(
            agg_process.next.type_signature, next_signature)
    def test_type_properties(self, metric_finalizers, unfinalized_metrics):
        """Checks types of the `SumThenFinalizeFactory` aggregation process.

        The state pairs an empty value with the local unfinalized-metrics
        type. The result pairs two finalized-metrics structures, and the
        measurements are empty.
        """
        sum_then_finalize = aggregation_factory.SumThenFinalizeFactory()
        self.assertIsInstance(sum_then_finalize,
                              factory.UnweightedAggregationFactory)
        local_metrics_type = type_conversions.type_from_tensors(
            unfinalized_metrics)
        agg_process = sum_then_finalize.create(metric_finalizers,
                                               local_metrics_type)
        self.assertIsInstance(agg_process,
                              aggregation_process.AggregationProcess)

        state_at_server = computation_types.FederatedType(
            ((), local_metrics_type), placements.SERVER)
        init_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        self.assertTrue(
            agg_process.initialize.type_signature.is_equivalent_to(
                init_signature))

        finalized_type = _get_finalized_metrics_type(metric_finalizers,
                                                     unfinalized_metrics)
        result_at_server = computation_types.FederatedType(
            (finalized_type, finalized_type), placements.SERVER)
        empty_measurements = computation_types.FederatedType(
            (), placements.SERVER)
        clients_metrics = computation_types.FederatedType(
            local_metrics_type, placements.CLIENTS)
        next_params = collections.OrderedDict(
            state=state_at_server, unfinalized_metrics=clients_metrics)
        next_output = measured_process.MeasuredProcessOutput(
            state_at_server, result_at_server, empty_measurements)
        next_signature = computation_types.FunctionType(
            parameter=next_params, result=next_output)
        self.assertTrue(
            agg_process.next.type_signature.is_equivalent_to(next_signature))
class VisitPreorderTest(parameterized.TestCase):
    """Tests for `type_transformations.visit_preorder`."""

    # pyformat: disable
    @parameterized.named_parameters([
        ('abstract_type', computation_types.AbstractType('T'), 1),
        ('nested_function_type',
         computation_types.FunctionType(
             computation_types.FunctionType(
                 computation_types.FunctionType(tf.int32, tf.int32), tf.int32),
             tf.int32), 7),
        ('named_tuple_type',
         computation_types.StructType(
             [tf.int32, tf.bool,
              computation_types.SequenceType(tf.int32)]), 5),
        ('placement_type', computation_types.PlacementType(), 1),
    ])
    # pyformat: enable
    def test_preorder_call_count(self, type_signature, expected_count):
        """Checks how many times the visitor is invoked on `type_signature`.

        The expected counts equal the number of type nodes: for example, the
        nested function case has 3 function types plus 4 `int32` leaves.
        """
        visits = 0

        def _record_visit(given_type, arg):
            nonlocal visits
            del given_type  # Only the number of visits matters here.
            visits += 1
            return arg

        type_transformations.visit_preorder(type_signature, _record_visit, None)
        self.assertEqual(visits, expected_count)
Example #22
0
 def test_create_xla_tff_computation_with_reordered_tensor_indexes(self):
     """Wraps one XLA `Dot` computation twice, with opposite tensor indexes.

     Tensor indexes `[0, 1]` and `[1, 0]` must each produce an `xla`
     computation proto whose deserialized type lists the parameters in the
     order given by the supplied function type.
     """
     builder = xla_client.XlaBuilder('comp')
     int32_etype = xla_client.dtype_to_etype(np.int32)
     lhs = xla_client.ops.Parameter(
         builder, 0, xla_client.Shape.array_shape(int32_etype, (10, 1)))
     rhs = xla_client.ops.Parameter(
         builder, 1, xla_client.Shape.array_shape(int32_etype, (1, 20)))
     xla_client.ops.Dot(lhs, rhs)
     xla_comp = builder.build()

     result_spec = (np.int32, (10, 20))
     # (tensor indexes, parameter type spec, expected type string)
     cases = (
         ([0, 1], ((np.int32, (10, 1)), (np.int32, (1, 20))),
          '(<int32[10,1],int32[1,20]> -> int32[10,20])'),
         ([1, 0], ((np.int32, (1, 20)), (np.int32, (10, 1))),
          '(<int32[1,20],int32[10,1]> -> int32[10,20])'),
     )
     for tensor_indexes, parameter_spec, expected_type_str in cases:
         comp_pb = xla_serialization.create_xla_tff_computation(
             xla_comp, tensor_indexes,
             computation_types.FunctionType(parameter_spec, result_spec))
         self.assertIsInstance(comp_pb, pb.Computation)
         self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
         deserialized_type = type_serialization.deserialize_type(comp_pb.type)
         self.assertEqual(str(deserialized_type), expected_type_str)
Example #23
0
    def test_clip_type_properties_simple(self, value_type):
        """Checks the `initialize`/`next` type signatures of a clipped sum.

        The process built by `_clipped_sum()` for `value_type` must have an
        empty-struct server state made of `clipping_norm`, `inner_agg` and
        `clipped_count_agg` entries, and must report clipping measurements.
        """
        factory = _clipped_sum()
        value_type = computation_types.to_type(value_type)
        process = factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        server_state_type = computation_types.at_server(
            collections.OrderedDict(clipping_norm=(),
                                    inner_agg=(),
                                    clipped_count_agg=()))
        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=server_state_type)
        # Unlike assertTrue(a.is_equivalent_to(b)), which only reports
        # "False is not true", this helper fails with a description of the
        # mismatching types.
        type_test_utils.assert_types_equivalent(
            process.initialize.type_signature, expected_initialize_type)

        expected_measurements_type = computation_types.at_server(
            collections.OrderedDict(clipping=(),
                                    clipping_norm=robust.NORM_TF_TYPE,
                                    clipped_count=robust.COUNT_TF_TYPE))
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=server_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=server_state_type,
                result=computation_types.at_server(value_type),
                measurements=expected_measurements_type))
        type_test_utils.assert_types_equivalent(
            process.next.type_signature, expected_next_type)
Example #24
0
  def test_type_properties(self, value_type):
    """Checks the type signatures of a stochastic-discretization process.

    The measurements are expected to mirror `value_type`'s structure, with
    every tensor carried as `int32` at its original shape, plus a float32
    `distortion` entry.
    """
    factory = stochastic_discretization.StochasticDiscretizationFactory(
        step_size=0.1,
        inner_agg_factory=_measurement_aggregator,
        distortion_aggregation_factory=mean.UnweightedMeanFactory())
    value_type = computation_types.to_type(value_type)
    quantize_type = type_conversions.structure_from_tensor_type_tree(
        lambda tensor_type: (tf.int32, tensor_type.shape), value_type)
    process = factory.create(value_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    server_state_type = computation_types.at_server(
        computation_types.StructType([
            ('step_size', tf.float32),
            ('inner_agg_process', ()),
        ]))
    type_test_utils.assert_types_equivalent(
        process.initialize.type_signature,
        computation_types.FunctionType(
            parameter=None, result=server_state_type))

    measurements_type = computation_types.at_server(
        computation_types.StructType([
            ('stochastic_discretization', quantize_type),
            ('distortion', tf.float32),
        ]))
    type_test_utils.assert_types_equivalent(
        process.next.type_signature,
        computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=server_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=server_state_type,
                result=computation_types.at_server(value_type),
                measurements=measurements_type)))
Example #25
0
    def test_roundtrip_no_broadcast(self):
        """Round-trips a computation with no broadcast through `BroadcastForm`.

        The computation adds five to each client value and ignores its
        server input; both the extracted form and the rebuilt computation are
        checked.
        """
        add_five = tensorflow_computation.tf_computation(lambda x: x + 5)

        @federated_computation.federated_computation(
            computation_types.at_server(()),
            computation_types.at_clients(tf.int32))
        def add_five_at_clients(naught_at_server, client_numbers):
            del naught_at_server  # The server input is intentionally unused.
            return intrinsics.federated_map(add_five, client_numbers)

        broadcast_form = form_utils.get_broadcast_form_for_computation(
            add_five_at_clients)
        # Labels come from the parameter names of the original computation.
        self.assertEqual(broadcast_form.server_data_label, 'naught_at_server')
        self.assertEqual(broadcast_form.client_data_label, 'client_numbers')
        type_test_utils.assert_types_equivalent(
            broadcast_form.compute_server_context.type_signature,
            computation_types.FunctionType((), ()))
        type_test_utils.assert_types_equivalent(
            broadcast_form.client_processing.type_signature,
            computation_types.FunctionType(((), tf.int32), tf.int32))
        self.assertEqual(6, broadcast_form.client_processing((), 1))

        rebuilt = form_utils.get_computation_for_broadcast_form(broadcast_form)
        type_test_utils.assert_types_equivalent(
            rebuilt.type_signature, add_five_at_clients.type_signature)
        self.assertEqual([10, 11, 12], rebuilt((), [5, 6, 7]))
Example #26
0
 def test_equality(self):
     """`FunctionType` equality depends on both parameter and result types."""
     int_to_bool = computation_types.FunctionType(tf.int32, tf.bool)
     int_to_bool_copy = computation_types.FunctionType(tf.int32, tf.bool)
     int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
     bool_to_bool = computation_types.FunctionType(tf.bool, tf.bool)
     self.assertEqual(int_to_bool, int_to_bool_copy)
     # Differs only in the result type.
     self.assertNotEqual(int_to_bool, int_to_int)
     # Differs only in the parameter type.
     self.assertNotEqual(int_to_bool, bool_to_bool)
Example #27
0
 def test_returns_true_for_calls_with_no_arguments(self):
     """Two separately built, identical zero-argument calls are equal trees."""

     def _build_call():
         # A fresh reference named 'a' to a no-parameter function returning
         # int32, called without an argument.
         fn_type = computation_types.FunctionType(None, tf.int32)
         return building_blocks.Call(building_blocks.Reference('a', fn_type))

     first_call = _build_call()
     second_call = _build_call()
     self.assertTrue(tree_analysis.trees_equal(first_call, second_call))
Example #28
0
class FnToBuildingBlockTest(parameterized.TestCase):
    """Tests for serializing Python functions via `_federated_computation_serializer`."""

    # pyformat: disable
    @parameterized.named_parameters(
        ('nested_fn_same',
         lambda f, x: f(f(x)),
         computation_types.StructType(
             [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]),
         '(FEDERATED_foo -> (let fc_FEDERATED_symbol_0=FEDERATED_foo.f(FEDERATED_foo.x),fc_FEDERATED_symbol_1=FEDERATED_foo.f(fc_FEDERATED_symbol_0) in fc_FEDERATED_symbol_1))'),
        ('nested_fn_different',
         lambda f, g, x: f(g(x)),
         computation_types.StructType(
             [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('g', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]),
         '(FEDERATED_foo -> (let fc_FEDERATED_symbol_0=FEDERATED_foo.g(FEDERATED_foo.x),fc_FEDERATED_symbol_1=FEDERATED_foo.f(fc_FEDERATED_symbol_0) in fc_FEDERATED_symbol_1))'),
        ('selection',
         lambda x: (x[1], x[0]),
         computation_types.StructType([tf.int32, tf.int32]),
         '(FEDERATED_foo -> <FEDERATED_foo[1],FEDERATED_foo[0]>)'),
        ('constant',
         lambda: 'stuff',
         None,
         # Only a prefix is compared (assertStartsWith); the remainder after
         # `comp#` presumably varies between runs.
         '( -> (let fc_FEDERATED_symbol_0=comp#'),
    )
    # pyformat: enable
    def test_returns_result(self, fn, parameter_type, fn_str):
        """Checks the string form of the building block serialized from `fn`.

        A parameter named 'foo' is used whenever `parameter_type` is given;
        otherwise the function is serialized with no parameter.
        """
        parameter_name = 'foo' if parameter_type is not None else None
        result, _ = _federated_computation_serializer(fn, parameter_name,
                                                      parameter_type)
        self.assertStartsWith(str(result), fn_str)

    # pyformat: disable
    @parameterized.named_parameters(
        ('tuple', lambda x:
         (x[1], x[0]), computation_types.StructType([tf.int32, tf.float32]),
         computation_types.StructWithPythonType([(None, tf.float32),
                                                 (None, tf.int32)], tuple)),
        ('list', lambda x: [x[1], x[0]],
         computation_types.StructType([tf.int32, tf.float32]),
         computation_types.StructWithPythonType([(None, tf.float32),
                                                 (None, tf.int32)], list)),
        ('odict', lambda x: collections.OrderedDict([('A', x[1]),
                                                     ('B', x[0])]),
         computation_types.StructType([tf.int32, tf.float32]),
         computation_types.StructWithPythonType([('A', tf.float32),
                                                 ('B', tf.int32)],
                                                collections.OrderedDict)),
        ('namedtuple', lambda x: TestNamedTuple(x=x[1], y=x[0]),
         computation_types.StructType([tf.int32, tf.float32]),
         computation_types.StructWithPythonType([('x', tf.float32),
                                                 ('y', tf.int32)],
                                                TestNamedTuple)),
    )
    # pyformat: enable
    def test_returns_result_with_py_container(self, fn, parameter_type,
                                              expected_result_type):
        """Checks that the result type records the Python container of `fn`'s output.

        Besides type equality, the result type's class and its
        `python_container` must match those of `expected_result_type`.
        """
        _, type_signature = _federated_computation_serializer(
            fn, 'foo', parameter_type)
        result_type = type_signature.result
        self.assertIs(type(result_type), type(expected_result_type))
        self.assertIs(result_type.python_container,
                      expected_result_type.python_container)
        self.assertEqual(result_type, expected_result_type)
    def test_construction(self, weighted):
        """Checks the type signatures of a model-delta optimizer process.

        With `weighted`, a `MeanFactory` is used and its state and
        measurements are expected to carry value/weight entries; otherwise a
        `SumFactory` is used and both are expected to be empty structs.
        """
        if weighted:
            aggregation_factory = mean.MeanFactory()
            aggregate_state = collections.OrderedDict(value_sum_process=(),
                                                      weight_sum_process=())
            aggregate_metrics = collections.OrderedDict(mean_value=(),
                                                        mean_weight=())
        else:
            aggregation_factory = sum_factory.SumFactory()
            aggregate_state = ()
            aggregate_metrics = ()

        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            model_update_aggregation_factory=aggregation_factory)

        model_weights_type = model_utils.ModelWeights(
            trainable=[
                computation_types.TensorType(tf.float32, [2, 1]),
                computation_types.TensorType(tf.float32),
            ],
            non_trainable=[computation_types.TensorType(tf.float32)])
        server_state_type = computation_types.FederatedType(
            optimizer_utils.ServerState(
                model=model_weights_type,
                optimizer_state=[tf.int64],
                delta_aggregate_state=aggregate_state,
                model_broadcast_state=()),
            placements.SERVER)
        self.assert_types_equivalent(
            computation_types.FunctionType(
                parameter=None, result=server_state_type),
            iterative_process.initialize.type_signature)

        dataset_type = computation_types.FederatedType(
            computation_types.SequenceType(
                collections.OrderedDict(
                    x=computation_types.TensorType(tf.float32, [None, 2]),
                    y=computation_types.TensorType(tf.float32, [None, 1]))),
            placements.CLIENTS)
        metrics_type = computation_types.FederatedType(
            collections.OrderedDict(
                broadcast=(),
                aggregation=aggregate_metrics,
                train=collections.OrderedDict(
                    loss=computation_types.TensorType(tf.float32),
                    num_examples=computation_types.TensorType(tf.int32)),
                stat=collections.OrderedDict(
                    num_examples=computation_types.TensorType(tf.float32))),
            placements.SERVER)
        self.assert_types_equivalent(
            computation_types.FunctionType(
                parameter=collections.OrderedDict(
                    server_state=server_state_type,
                    federated_dataset=dataset_type),
                result=(server_state_type, metrics_type)),
            iterative_process.next.type_signature)
Example #30
0
 def test_passes_with_federated_map(self):
     """A lone `federated_map` intrinsic passes the reducible-intrinsics check."""
     mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
     client_ints_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     client_floats_type = computation_types.FederatedType(
         tf.float32, placements.CLIENTS)
     intrinsic = building_blocks.Intrinsic(
         intrinsic_defs.FEDERATED_MAP.uri,
         computation_types.FunctionType([mapped_fn_type, client_ints_type],
                                        client_floats_type))
     # The test passes as long as the check below does not raise.
     tree_analysis.check_contains_only_reducible_intrinsics(intrinsic)