Esempio n. 1
0
    def test_type_properties(self, value_type, factory_cons):
        """Verifies the type signatures of the process built by a factory.

        Args:
            value_type: Spec of the client value type; it is passed through
                `computation_types.to_type` before use.
            factory_cons: Zero-argument callable returning the factory under
                test.
        """
        agg_factory = factory_cons()
        member_type = computation_types.to_type(value_type)
        process = agg_factory.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Expected state: an empty struct paired with two named empty structs.
        sum_states = collections.OrderedDict(value_sum_process=(),
                                             weight_sum_process=())
        state_type = computation_types.at_server(((), sum_states))

        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        initialize_matches = (
            process.initialize.type_signature.is_equivalent_to(
                initialize_type))
        self.assertTrue(initialize_matches)

        sum_measurements = collections.OrderedDict(value_sum_process=(),
                                                   weight_sum_process=())
        measurements_type = computation_types.at_server(sum_measurements)
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type),
            weight=computation_types.at_clients(tf.float32))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(next_type)
        self.assertTrue(next_matches)
Esempio n. 2
0
    def test_concat_type_properties_unweighted(self, value_type):
        """Verifies the signatures of the process created by `_concat_sum()`.

        Args:
            value_type: Spec of the client value type; it is passed through
                `computation_types.to_type` before use.
        """
        concat_factory = _concat_sum()
        member_type = computation_types.to_type(value_type)
        process = concat_factory.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The inner SumFactory is stateless, so the state is an empty struct.
        state_type = computation_types.at_server(())

        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        type_test_utils.assert_types_equivalent(
            process.initialize.type_signature, initialize_type)

        # The inner SumFactory reports no measurements either.
        measurements_type = computation_types.at_server(())
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        type_test_utils.assert_types_equivalent(process.next.type_signature,
                                                next_type)
Esempio n. 3
0
        def next_fn(state, value):
            """Clips client values, aggregates them, then clips the aggregate.

            Args:
                state: State handed to `inner_agg_next`.
                value: Client values to be clipped and aggregated.

            Returns:
                A `MeasuredProcessOutput` holding the inner aggregator's new
                state, the re-clipped aggregate, and the zipped measurements.
            """
            lower_at_server = intrinsics.federated_value(
                self._clip_range_lower, placements.SERVER)
            upper_at_server = intrinsics.federated_value(
                self._clip_range_upper, placements.SERVER)

            # Apply the modular clip to every value before it is aggregated.
            lower_at_clients = intrinsics.federated_broadcast(lower_at_server)
            upper_at_clients = intrinsics.federated_broadcast(upper_at_server)
            client_clip_args = (value, lower_at_clients, upper_at_clients)
            clipped_value = intrinsics.federated_map(modular_clip_by_value_fn,
                                                     client_clip_args)

            inner_output = inner_agg_next(state, clipped_value)

            # The aggregate is clipped into the same range once more; the
            # individual summands are not considered at this point.
            server_clip_args = (inner_output.result, lower_at_server,
                                upper_at_server)
            clipped_aggregate = intrinsics.federated_map(
                modular_clip_by_value_fn, server_clip_args)

            measurements = collections.OrderedDict(
                modclip=inner_output.measurements)

            if self._estimate_stddev:
                estimator_args = (clipped_aggregate, lower_at_server,
                                  upper_at_server)
                measurements['estimated_stddev'] = intrinsics.federated_map(
                    estimator_fn, estimator_args)

            zipped_measurements = intrinsics.federated_zip(measurements)
            return measured_process.MeasuredProcessOutput(
                state=inner_output.state,
                result=clipped_aggregate,
                measurements=zipped_measurements)
Esempio n. 4
0
    def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                             lower_bound, measurements_dtype):
        """Verifies the signatures of a `SecureSumFactory` process.

        Args:
            value_type: Spec of the client value type; it is passed through
                `computation_types.to_type` before use.
            upper_bound: Forwarded as `upper_bound_threshold`.
            lower_bound: Forwarded as `lower_bound_threshold`.
            measurements_dtype: Passed to `_measurements_type` to build the
                expected measurements type.
        """
        secure_sum_f = secure_factory.SecureSumFactory(
            upper_bound_threshold=upper_bound,
            lower_bound_threshold=lower_bound)
        self.assertIsInstance(secure_sum_f,
                              factory.UnweightedAggregationFactory)

        member_type = computation_types.to_type(value_type)
        process = secure_sum_f.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        empty_struct = computation_types.to_type(())
        state_type = computation_types.at_server(empty_struct)
        measurements_type = _measurements_type(measurements_dtype)

        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        initialize_matches = (
            process.initialize.type_signature.is_equivalent_to(
                initialize_type))
        self.assertTrue(initialize_matches)

        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(next_type)
        self.assertTrue(next_matches)
Esempio n. 5
0
        def next_fn(state, value):
            """Discretizes client values, aggregates, then undiscretizes.

            Args:
                state: Structure indexed by `scale_factor`,
                    `prior_norm_bound` and `inner_agg_process`.
                value: Client values to aggregate.

            Returns:
                A `MeasuredProcessOutput` with the zipped new state, the
                undiscretized aggregate, and the zipped measurements.
            """
            scale_at_server = state['scale_factor']
            scale_at_clients = intrinsics.federated_broadcast(scale_at_server)
            norm_bound_at_server = state['prior_norm_bound']
            norm_bound_at_clients = intrinsics.federated_broadcast(
                norm_bound_at_server)

            discretize_args = (value, scale_at_clients, norm_bound_at_clients)
            discretized = intrinsics.federated_map(discretize_fn,
                                                   discretize_args)

            inner_output = inner_agg_process.next(state['inner_agg_process'],
                                                  discretized)

            undiscretize_args = (inner_output.result, scale_at_server)
            aggregate = intrinsics.federated_map(undiscretize_fn,
                                                 undiscretize_args)

            # The scale factor and norm bound are carried over unchanged;
            # only the inner aggregator's state advances.
            next_state = collections.OrderedDict(
                scale_factor=scale_at_server,
                prior_norm_bound=norm_bound_at_server,
                inner_agg_process=inner_output.state)
            measurements = collections.OrderedDict(
                discretize=inner_output.measurements)

            zipped_state = intrinsics.federated_zip(next_state)
            zipped_measurements = intrinsics.federated_zip(measurements)
            return measured_process.MeasuredProcessOutput(
                state=zipped_state,
                result=aggregate,
                measurements=zipped_measurements)
Esempio n. 6
0
    def next_fn(state, value):
      """Discretizes client values by step size, aggregates, undiscretizes.

      Args:
        state: Structure indexed by `step_size` and `inner_agg_process`.
        value: Client values to aggregate.

      Returns:
        A `MeasuredProcessOutput` with the zipped new state, the undiscretized
        aggregate, and the zipped measurements. The measurements include a
        `distortion` entry only when a distortion aggregation factory is set.
      """
      step_at_server = state['step_size']
      step_at_clients = intrinsics.federated_broadcast(step_at_server)

      discretized = intrinsics.federated_map(discretize_fn,
                                             (value, step_at_clients))

      inner_output = inner_agg_process.next(state['inner_agg_process'],
                                            discretized)

      aggregate = intrinsics.federated_map(
          undiscretize_fn, (inner_output.result, step_at_server))

      # The step size is carried over unchanged between rounds.
      next_state = collections.OrderedDict(
          step_size=step_at_server, inner_agg_process=inner_output.state)
      measurements = collections.OrderedDict(
          deterministic_discretization=inner_output.measurements)

      if self._distortion_aggregation_factory is not None:
        per_client_distortion = intrinsics.federated_map(
            distortion_measurement_fn, (value, step_at_clients))
        # The distortion aggregator is re-initialized on every call; only its
        # result is kept.
        distortion_output = distortion_aggregation_process.next(
            distortion_aggregation_process.initialize(), per_client_distortion)
        measurements['distortion'] = distortion_output.result

      zipped_state = intrinsics.federated_zip(next_state)
      zipped_measurements = intrinsics.federated_zip(measurements)
      return measured_process.MeasuredProcessOutput(
          state=zipped_state,
          result=aggregate,
          measurements=zipped_measurements)
Esempio n. 7
0
    def test_type_properties(self, value_type):
        """Verifies the signatures of a `SumPlusOneFactory` process.

        Args:
            value_type: Spec of the client value type; it is passed through
                `computation_types.to_type` before use.
        """
        sum_f = test_utils.SumPlusOneFactory()
        self.assertIsInstance(sum_f, factory.AggregationProcessFactory)

        member_type = computation_types.to_type(value_type)
        process = sum_f.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        value_at_clients = computation_types.FederatedType(
            member_type, placements.CLIENTS)
        value_at_server = computation_types.FederatedType(
            member_type, placements.SERVER)
        # Both the state and the measurements are expected to be int32
        # placed at the server.
        state_type = computation_types.FederatedType(tf.int32,
                                                     placements.SERVER)
        measurements_type = computation_types.FederatedType(
            tf.int32, placements.SERVER)

        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        initialize_matches = (
            process.initialize.type_signature.is_equivalent_to(
                initialize_type))
        self.assertTrue(initialize_matches)

        next_parameter = collections.OrderedDict(state=state_type,
                                                 value=value_at_clients)
        next_result = measured_process.MeasuredProcessOutput(
            state_type, value_at_server, measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(next_type)
        self.assertTrue(next_matches)
Esempio n. 8
0
    def test_type_properties(self, value_type, weight_type):
        """Verifies the signatures of a weighted `MeanFactory` process.

        Args:
            value_type: Spec of the client value type; it is passed through
                `computation_types.to_type` before use.
            weight_type: Spec of the client weight type; converted the same
                way.
        """
        mean_f = mean_factory.MeanFactory()
        self.assertIsInstance(mean_f, factory.WeightedAggregationFactory)

        member_type = computation_types.to_type(value_type)
        weight_member_type = computation_types.to_type(weight_type)
        process = mean_f.create_weighted(member_type, weight_member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        value_at_clients = computation_types.FederatedType(
            member_type, placements.CLIENTS)
        value_at_server = computation_types.FederatedType(
            member_type, placements.SERVER)
        # State and measurements share the same shape: two named empty
        # structs, one per inner sum process.
        state_struct = collections.OrderedDict(value_sum_process=(),
                                               weight_sum_process=())
        state_type = computation_types.at_server(state_struct)
        measurements_struct = collections.OrderedDict(value_sum_process=(),
                                                      weight_sum_process=())
        measurements_type = computation_types.at_server(measurements_struct)

        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        initialize_matches = (
            process.initialize.type_signature.is_equivalent_to(
                initialize_type))
        self.assertTrue(initialize_matches)

        next_parameter = collections.OrderedDict(
            state=state_type,
            value=value_at_clients,
            weight=computation_types.at_clients(weight_member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state_type, value_at_server, measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(next_type)
        self.assertTrue(next_matches)
Esempio n. 9
0
 def stateful_broadcast(state, value):
     """Broadcasts `value`, leaving `state` untouched.

     Args:
         state: Returned unchanged as the output state.
         value: Value passed to `intrinsics.federated_broadcast`.

     Returns:
         A `MeasuredProcessOutput` whose measurements are the constant `3.0`
         placed at the server.
     """
     constant_metric = intrinsics.federated_value(3.0, placements.SERVER)
     broadcast_value = intrinsics.federated_broadcast(value)
     return measured_process_lib.MeasuredProcessOutput(
         state=state, result=broadcast_value, measurements=constant_metric)
 def test_non_server_placed_next_result_raises(self):
     """Expects `TypeError` when `next` yields its second argument as result.

     The `next` computation takes `(SERVER_INT, CLIENTS_INT)` and returns the
     `CLIENTS_INT` argument in the `result` slot.
     """
     init_fn = computations.federated_computation(
         lambda: intrinsics.federated_value(0, placements.SERVER))
     wrap_next = computations.federated_computation(SERVER_INT, CLIENTS_INT)
     next_fn = wrap_next(
         lambda x, y: measured_process.MeasuredProcessOutput(x, y, x))
     self.assertRaises(TypeError, aggregation_process.AggregationProcess,
                       init_fn, next_fn)
Esempio n. 11
0
  def test_type_properties(self, value_type):
    """Verifies the signatures of a `StochasticDiscretizationFactory` process.

    The factory is built with `step_size=0.1`, `_measurement_aggregator` as
    the inner factory and an `UnweightedMeanFactory` for distortion.

    Args:
      value_type: Spec of the client value type; it is passed through
        `computation_types.to_type` before use.
    """
    agg_factory = stochastic_discretization.StochasticDiscretizationFactory(
        step_size=0.1,
        inner_agg_factory=_measurement_aggregator,
        distortion_aggregation_factory=mean.UnweightedMeanFactory())
    member_type = computation_types.to_type(value_type)
    # Same structure as the value type, with every tensor dtype set to int32.
    quantized_type = type_conversions.structure_from_tensor_type_tree(
        lambda x: (tf.int32, x.shape), member_type)
    process = agg_factory.create(member_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    state_type = computation_types.at_server(
        computation_types.StructType([
            ('step_size', tf.float32),
            ('inner_agg_process', ()),
        ]))
    initialize_type = computation_types.FunctionType(
        parameter=None, result=state_type)
    type_test_utils.assert_types_equivalent(process.initialize.type_signature,
                                            initialize_type)

    measurements_type = computation_types.at_server(
        computation_types.StructType([
            ('stochastic_discretization', quantized_type),
            ('distortion', tf.float32),
        ]))
    next_parameter = collections.OrderedDict(
        state=state_type, value=computation_types.at_clients(member_type))
    next_result = measured_process.MeasuredProcessOutput(
        state=state_type,
        result=computation_types.at_server(member_type),
        measurements=measurements_type)
    next_type = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    type_test_utils.assert_types_equivalent(process.next.type_signature,
                                            next_type)
 def test_single_param_next_raises(self):
     """Expects `TypeError` when `next` accepts only a single argument."""
     init_fn = computations.federated_computation(
         lambda: intrinsics.federated_value(0, placements.SERVER))
     wrap_next = computations.federated_computation(SERVER_INT)
     next_fn = wrap_next(
         lambda x: measured_process.MeasuredProcessOutput(x, x, x))
     self.assertRaises(TypeError, aggregation_process.AggregationProcess,
                       init_fn, next_fn)
Esempio n. 13
0
    def test_type_properties_simple(self, value_type, estimate_stddev):
        """Verifies the signatures of the process from `_test_factory`.

        Args:
            value_type: Spec of the client value type.
            estimate_stddev: Forwarded to `_test_factory`; when truthy an
                `estimated_stddev` float32 measurement is expected as well.
        """
        agg_factory = _test_factory(estimate_stddev=estimate_stddev)
        process = agg_factory.create(computation_types.to_type(value_type))
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The inner SumFactory is stateless, so the state is an empty struct.
        state_type = computation_types.at_server(())

        init_type = computation_types.FunctionType(parameter=None,
                                                   result=state_type)
        init_matches = process.initialize.type_signature.is_equivalent_to(
            init_type)
        self.assertTrue(init_matches)

        measurements_struct = collections.OrderedDict(modclip=())
        if estimate_stddev:
            measurements_struct['estimated_stddev'] = tf.float32

        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(value_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(value_type),
            measurements=computation_types.at_server(measurements_struct))
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(next_type)
        self.assertTrue(next_matches)
Esempio n. 14
0
    def test_clip_type_properties_with_clipped_count_agg_factory(
            self, value_type):
        """Verifies clipping signatures with a custom clipped-count factory.

        A `SumPlusOneFactory` is supplied as `clipped_count_sum_factory`, and
        the expected state carries an int32 `clipped_count_agg` entry.

        Args:
            value_type: Spec of the client value type; it is passed through
                `computation_types.to_type` before use.
        """
        count_factory = aggregator_test_utils.SumPlusOneFactory()
        clip_factory = robust.clipping_factory(
            clipping_norm=1.0,
            inner_agg_factory=sum_factory.SumFactory(),
            clipped_count_sum_factory=count_factory)
        member_type = computation_types.to_type(value_type)
        process = clip_factory.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        state_struct = collections.OrderedDict(clipping_norm=(),
                                               inner_agg=(),
                                               clipped_count_agg=tf.int32)
        state_type = computation_types.at_server(state_struct)
        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        initialize_matches = (
            process.initialize.type_signature.is_equivalent_to(
                initialize_type))
        self.assertTrue(initialize_matches)

        measurements_struct = collections.OrderedDict(
            clipping=(),
            clipping_norm=robust.NORM_TF_TYPE,
            clipped_count=robust.COUNT_TF_TYPE)
        measurements_type = computation_types.at_server(measurements_struct)
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(next_type)
        self.assertTrue(next_matches)
    def test_type_properties(self, metric_finalizers, unfinalized_metrics):
        """Verifies the signatures of a `SumThenFinalizeFactory` process.

        Args:
            metric_finalizers: Passed to the factory's `create` and to
                `_get_finalized_metrics_type`.
            unfinalized_metrics: Tensors from which the local unfinalized
                metrics type is derived via
                `type_conversions.type_from_tensors`.
        """
        aggregate_factory = aggregation_factory.SumThenFinalizeFactory()
        self.assertIsInstance(aggregate_factory,
                              factory.UnweightedAggregationFactory)

        local_metrics_type = type_conversions.type_from_tensors(
            unfinalized_metrics)
        process = aggregate_factory.create(metric_finalizers,
                                           local_metrics_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # State: an empty struct paired with the local unfinalized metrics.
        state_type = computation_types.FederatedType(
            ((), local_metrics_type), placements.SERVER)
        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        initialize_matches = (
            process.initialize.type_signature.is_equivalent_to(
                initialize_type))
        self.assertTrue(initialize_matches)

        finalized_type = _get_finalized_metrics_type(metric_finalizers,
                                                     unfinalized_metrics)
        # The result is a pair of two finalized-metrics structures.
        result_type = computation_types.FederatedType(
            (finalized_type, finalized_type), placements.SERVER)
        measurements_type = computation_types.FederatedType(
            (), placements.SERVER)
        metrics_at_clients = computation_types.FederatedType(
            local_metrics_type, placements.CLIENTS)
        next_parameter = collections.OrderedDict(
            state=state_type, unfinalized_metrics=metrics_at_clients)
        next_result = measured_process.MeasuredProcessOutput(
            state_type, result_type, measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(next_type)
        self.assertTrue(next_matches)
Esempio n. 16
0
 def next_comp(state):
     """Returns a zipped `MeasuredProcessOutput` inside a one-element tuple.

     The zipped output carries `state` unchanged, a server-placed `0` as the
     result, and a server-placed empty tuple as the measurements.
     """
     # NOTE(review): the trailing comma after the call below makes this
     # function return a 1-tuple rather than the zipped value itself --
     # confirm this is intentional and not a leftover.
     return intrinsics.federated_zip(
         measured_process.MeasuredProcessOutput(
             state=state,
             result=intrinsics.federated_value(0, placements.SERVER),
             measurements=intrinsics.federated_value(
                 (), placements.SERVER))),
Esempio n. 17
0
    def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
        """Clips client values, aggregates them, and counts clipped clients.

        Args:
            state: Three-element structure unpacked as the clipping-norm
                state, the inner aggregator state and the clipped-count
                state.
            value: Client values to clip and aggregate.
            clip_fn: Computation mapped over `(value, clipping_norm)`; its
                output is unpacked as the clipped value, the global norm and
                a was-clipped indicator.
            inner_agg_process: Process whose `next` aggregates the clipped
                values.
            weight: Optional client weights; when not `None` they are passed
                to `inner_agg_process.next` as a third argument.

        Returns:
            A `MeasuredProcessOutput` with the zipped new state, the inner
            aggregation result, and the zipped measurements.
        """
        norm_state, inner_state, count_state = state

        clipping_norm = clipping_norm_process.report(norm_state)

        norm_at_clients = intrinsics.federated_broadcast(clipping_norm)
        clipped_value, global_norm, was_clipped = intrinsics.federated_map(
            clip_fn, (value, norm_at_clients))

        next_norm_state = clipping_norm_process.next(norm_state, global_norm)

        # The weight is forwarded only when the caller supplied one.
        agg_args = (inner_state, clipped_value)
        if weight is not None:
            agg_args += (weight,)
        agg_output = inner_agg_process.next(*agg_args)

        count_output = clipped_count_agg_process.next(count_state, was_clipped)

        next_state = collections.OrderedDict()
        next_state[prefix('ing_norm')] = next_norm_state
        next_state['inner_agg'] = agg_output.state
        next_state[prefix('ed_count_agg')] = count_output.state

        measurements = collections.OrderedDict()
        measurements[prefix('ing')] = agg_output.measurements
        measurements[prefix('ing_norm')] = clipping_norm
        measurements[prefix('ed_count')] = count_output.result

        zipped_state = intrinsics.federated_zip(next_state)
        zipped_measurements = intrinsics.federated_zip(measurements)
        return measured_process.MeasuredProcessOutput(
            state=zipped_state,
            result=agg_output.result,
            measurements=zipped_measurements)
Esempio n. 18
0
    def next_fn(state,
                unfinalized_metrics) -> measured_process.MeasuredProcessOutput:
      """Sums this round's metrics, accumulates them, and finalizes both.

      Args:
        state: Pair unpacked as the inner summation state and the running
          unfinalized-metrics accumulators.
        unfinalized_metrics: This round's unfinalized metrics, passed to the
          inner summation process.

      Returns:
        A `MeasuredProcessOutput` whose state is the zipped pair of the new
        inner summation state and the updated accumulators, whose result is
        the zipped pair (current-round metrics, total-rounds metrics), and
        whose measurements are those of the inner summation process.
      """
      inner_summation_state, unfinalized_metrics_accumulators = state

      inner_summation_output = inner_summation_process.next(
          inner_summation_state, unfinalized_metrics)
      summed_unfinalized_metrics = inner_summation_output.result
      inner_summation_state = inner_summation_output.state

      # Element-wise addition over two structures of the local unfinalized
      # metrics type.
      @tensorflow_computation.tf_computation(local_unfinalized_metrics_type,
                                             local_unfinalized_metrics_type)
      def add_unfinalized_metrics(unfinalized_metrics,
                                  summed_unfinalized_metrics):
        return tf.nest.map_structure(tf.add, unfinalized_metrics,
                                     summed_unfinalized_metrics)

      # Fold this round's sum into the running accumulators.
      unfinalized_metrics_accumulators = intrinsics.federated_map(
          add_unfinalized_metrics,
          (unfinalized_metrics_accumulators, summed_unfinalized_metrics))

      finalizer_computation = _build_finalizer_computation(
          metric_finalizers, local_unfinalized_metrics_type)

      # The same finalizer is applied to this round's sum and to the
      # accumulated total.
      current_round_metrics = intrinsics.federated_map(
          finalizer_computation, summed_unfinalized_metrics)
      total_rounds_metrics = intrinsics.federated_map(
          finalizer_computation, unfinalized_metrics_accumulators)

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(
              (inner_summation_state, unfinalized_metrics_accumulators)),
          result=intrinsics.federated_zip(
              (current_round_metrics, total_rounds_metrics)),
          measurements=inner_summation_output.measurements)
Esempio n. 19
0
    def test_zero_type_properties_weighted(self, value_type, weight_type):
        """Verifies the signatures of the process from `_zeroed_mean()`.

        Args:
            value_type: Spec of the client value type; it is passed through
                `computation_types.to_type` before use.
            weight_type: Spec of the client weight type; converted the same
                way.
        """
        zeroing_factory = _zeroed_mean()
        member_type = computation_types.to_type(value_type)
        weight_member_type = computation_types.to_type(weight_type)
        process = zeroing_factory.create(member_type, weight_member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        mean_state_struct = collections.OrderedDict(value_sum_process=(),
                                                    weight_sum_process=())
        state_struct = collections.OrderedDict(zeroing_norm=(),
                                               inner_agg=mean_state_struct,
                                               zeroed_count_agg=())
        state_type = computation_types.at_server(state_struct)
        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        initialize_matches = (
            process.initialize.type_signature.is_equivalent_to(
                initialize_type))
        self.assertTrue(initialize_matches)

        mean_measurements = collections.OrderedDict(mean_value=(),
                                                    mean_weight=())
        measurements_struct = collections.OrderedDict(
            zeroing=mean_measurements,
            zeroing_norm=robust_factory.NORM_TF_TYPE,
            zeroed_count=robust_factory.COUNT_TF_TYPE)
        measurements_type = computation_types.at_server(measurements_struct)
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type),
            weight=computation_types.at_clients(weight_member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(next_type)
        self.assertTrue(next_matches)
Esempio n. 20
0
  def test_type_properties_simple(self):
    """Verifies the signatures of a `ModularClippingSumFactory` process.

    The factory is built with the clip range [-2, 2] and applied to an int32
    value type of shape (2,).
    """
    member_type = computation_types.to_type((tf.int32, (2,)))
    agg_factory = modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=-2, clip_range_upper=2)
    process = agg_factory.create(member_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # The inner SumFactory is stateless, so the state is an empty struct.
    state_type = computation_types.at_server(())

    init_type = computation_types.FunctionType(
        parameter=None, result=state_type)
    init_matches = process.initialize.type_signature.is_equivalent_to(init_type)
    self.assertTrue(init_matches)

    measurements_struct = collections.OrderedDict(modclip=())

    next_parameter = collections.OrderedDict(
        state=state_type, value=computation_types.at_clients(member_type))
    next_result = measured_process.MeasuredProcessOutput(
        state=state_type,
        result=computation_types.at_server(member_type),
        measurements=computation_types.at_server(measurements_struct))
    next_type = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_matches = process.next.type_signature.is_equivalent_to(next_type)
    self.assertTrue(next_matches)
Esempio n. 21
0
    def next_fn(state, value, weight):
      """Computes a weighted mean from two inner sum processes.

      Args:
        state: Structure indexed by `value_sum_process` and
          `weight_sum_process`.
        value: Client values.
        weight: Client weights.

      Returns:
        A `MeasuredProcessOutput` with the zipped new inner states, the
        weighted mean, and the zipped inner measurements.
      """
      # On clients: combine each value with its weight via `_mul`.
      weighted_value = intrinsics.federated_map(_mul, (value, weight))

      # Run the two inner aggregations.
      value_output = value_sum_process.next(state['value_sum_process'],
                                            weighted_value)
      weight_output = weight_sum_process.next(state['weight_sum_process'],
                                              weight)

      # On the server: divide the summed values by the summed weights, with
      # the division variant chosen by the `_no_nan_division` setting.
      if self._no_nan_division:
        divide_fn = _div_no_nan
      else:
        divide_fn = _div
      weighted_mean_value = intrinsics.federated_map(
          divide_fn, (value_output.result, weight_output.result))

      # Assemble the output.
      next_state = collections.OrderedDict(
          value_sum_process=value_output.state,
          weight_sum_process=weight_output.state)
      measurements = collections.OrderedDict(
          value_sum_process=value_output.measurements,
          weight_sum_process=weight_output.measurements)
      zipped_state = intrinsics.federated_zip(next_state)
      zipped_measurements = intrinsics.federated_zip(measurements)
      return measured_process.MeasuredProcessOutput(
          zipped_state, weighted_mean_value, zipped_measurements)
Esempio n. 22
0
    def next_fn(global_state, value, weight):
        """Defines next_fn for MeasuredProcess.

        Args:
            global_state: State from which sample parameters are derived and
                which is passed to `post_process`.
            value: Client records to preprocess and aggregate.
            weight: Ignored; it is deleted immediately because weighted
                aggregation is not supported.

        Returns:
            A `MeasuredProcessOutput` holding the updated state, the
            post-processed result, and metrics derived from the updated
            state.
        """
        # Weighted aggregation is not supported.
        # TODO(b/140236959): Add an assertion that weight is None here, so the
        # contract of this method is better established. Will likely cause some
        # downstream breaks.
        del weight

        # Derive the sample parameters from the state and broadcast them.
        sample_params = intrinsics.federated_map(derive_sample_params,
                                                 global_state)
        client_sample_params = intrinsics.federated_broadcast(sample_params)
        preprocessed_record = intrinsics.federated_map(
            preprocess_record, (client_sample_params, value))
        agg_result = intrinsics.federated_aggregate(preprocessed_record,
                                                    zero(), accumulate, merge,
                                                    report)

        # `post_process` yields both the next state and the round's result.
        updated_state, result = intrinsics.federated_map(
            post_process, (agg_result, global_state))

        metrics = intrinsics.federated_map(derive_metrics, updated_state)

        return measured_process.MeasuredProcessOutput(state=updated_state,
                                                      result=result,
                                                      measurements=metrics)
  def test_clip_type_properties_simple(self, value_type):
    """Verifies the signatures of the unweighted `_clipped_sum()` process.

    Args:
      value_type: Spec of the client value type; it is passed through
        `computation_types.to_type` before use.
    """
    clip_factory = _clipped_sum()
    member_type = computation_types.to_type(value_type)
    process = clip_factory.create_unweighted(member_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    state_struct = collections.OrderedDict(
        clipping_norm=(), inner_agg=(), clipped_count_agg=())
    state_type = computation_types.at_server(state_struct)
    initialize_type = computation_types.FunctionType(
        parameter=None, result=state_type)
    initialize_matches = process.initialize.type_signature.is_equivalent_to(
        initialize_type)
    self.assertTrue(initialize_matches)

    measurements_struct = collections.OrderedDict(
        agg_process=(),
        clipping_norm=clipping_factory.NORM_TF_TYPE,
        clipped_count=clipping_factory.COUNT_TF_TYPE)
    measurements_type = computation_types.at_server(measurements_struct)
    next_parameter = collections.OrderedDict(
        state=state_type, value=computation_types.at_clients(member_type))
    next_result = measured_process.MeasuredProcessOutput(
        state=state_type,
        result=computation_types.at_server(member_type),
        measurements=measurements_type)
    next_type = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_matches = process.next.type_signature.is_equivalent_to(next_type)
    self.assertTrue(next_matches)
Esempio n. 24
0
  def test_type_properties_unweighted(self, value_type):
    """Verifies the signatures of an `UnweightedMeanFactory` process.

    Args:
      value_type: Spec of the client value type; it is passed through
        `computation_types.to_type` before use.
    """
    member_type = computation_types.to_type(value_type)

    factory_ = mean.UnweightedMeanFactory()
    self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
    process = factory_.create(member_type)

    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    value_at_clients = computation_types.at_clients(member_type)
    value_at_server = computation_types.at_server(member_type)

    # The state is an empty struct; measurements hold one empty `mean_value`.
    state_type = computation_types.at_server(())
    measurements_struct = collections.OrderedDict(mean_value=())
    measurements_type = computation_types.at_server(measurements_struct)

    initialize_type = computation_types.FunctionType(
        parameter=None, result=state_type)
    initialize_matches = process.initialize.type_signature.is_equivalent_to(
        initialize_type)
    self.assertTrue(initialize_matches)

    next_parameter = collections.OrderedDict(
        state=state_type, value=value_at_clients)
    next_result = measured_process.MeasuredProcessOutput(
        state_type, value_at_server, measurements_type)
    next_type = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_matches = process.next.type_signature.is_equivalent_to(next_type)
    self.assertTrue(next_matches)
Esempio n. 25
0
    def next_fn(state, value):
        """Encodes `value`, aggregates the encodings and decodes the result.

        Args:
            state: Value from which `get_params_fn` derives the encode and
                decode parameters, and which `update_state_fn` updates.
            value: Value that is encoded before aggregation.

        Returns:
            A `measured_process.MeasuredProcessOutput` holding the updated
            state, the decoded aggregate as the result, and an empty
            measurements value placed at the server.
        """
        params = intrinsics.federated_map(get_params_fn, state)
        encode_params, pre_sum_decode_params, post_sum_decode_params = params

        # Only the encode and decode-before-sum parameters are broadcast;
        # the decode-after-sum parameters are used unchanged further down.
        client_encode_params = intrinsics.federated_broadcast(encode_params)
        client_pre_sum_params = intrinsics.federated_broadcast(
            pre_sum_decode_params)

        encode_args = [value, client_encode_params, client_pre_sum_params]
        encoded = intrinsics.federated_map(encode_fn, encode_args)

        aggregated = intrinsics.federated_aggregate(
            encoded, zero_fn(), accumulate_fn, merge_fn, report_fn)

        decoded = intrinsics.federated_map(
            decode_after_sum_fn, [aggregated.values, post_sum_decode_params])
        new_state = intrinsics.federated_map(
            update_state_fn, [state, aggregated.state_update_tensors])

        return measured_process.MeasuredProcessOutput(
            state=new_state,
            result=decoded,
            measurements=intrinsics.federated_value((), placements.SERVER))
Esempio n. 26
0
        def next_fn_impl(state, value, weight=None):
            """Zeroes values via `zero_fn`, then runs the inner aggregation.

            Args:
                state: Structure unpacked into three parts: the zeroing-norm
                    state, the inner aggregation state and the zeroed-count
                    aggregation state.
                value: Value passed to `zero_fn` together with the broadcast
                    zeroing norm.
                weight: Optional weight; forwarded to `inner_agg_next` only
                    when it is not `None`.

            Returns:
                A `measured_process.MeasuredProcessOutput` whose state and
                measurements are zipped `OrderedDict`s and whose result is
                the inner aggregation's result.
            """
            norm_state, inner_state, count_state = state

            threshold = self._zeroing_norm_process.report(norm_state)
            zeroed_value, value_norm, zeroed_flag = intrinsics.federated_map(
                zero_fn, (value, intrinsics.federated_broadcast(threshold)))
            next_norm_state = self._zeroing_norm_process.next(
                norm_state, value_norm)

            # Forward the weight only when the caller supplied one.
            inner_args = (inner_state, zeroed_value)
            if weight is not None:
                inner_args += (weight,)
            inner_output = inner_agg_next(*inner_args)

            count_output = zeroed_count_agg_next(count_state, zeroed_flag)

            next_state = collections.OrderedDict([
                ('zeroing_norm', next_norm_state),
                ('inner_agg', inner_output.state),
                ('zeroed_count_agg', count_output.state),
            ])
            measurements = collections.OrderedDict([
                ('agg_process', inner_output.measurements),
                ('zeroing_norm', threshold),
                ('zeroed_count', count_output.result),
            ])
            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(next_state),
                result=inner_output.result,
                measurements=intrinsics.federated_zip(measurements))
Esempio n. 27
0
    def test_clip_type_properties_weighted(self, value_type, weight_type):
        """Checks type signatures of the weighted process from `_concat_mean`.

        Args:
            value_type: Type spec of the aggregated value; converted with
                `computation_types.to_type` before use.
            weight_type: Type spec of the client weight; converted with
                `computation_types.to_type` before use.
        """
        concat_factory = _concat_mean()
        value_type = computation_types.to_type(value_type)
        weight_type = computation_types.to_type(weight_type)
        concat_process = concat_factory.create(value_type, weight_type)
        self.assertIsInstance(concat_process,
                              aggregation_process.AggregationProcess)

        # The state is the one of the inner MeanFactory.
        state_type = computation_types.at_server(
            collections.OrderedDict(
                value_sum_process=(), weight_sum_process=()))
        initialize_type = computation_types.FunctionType(
            parameter=None, result=state_type)
        type_test_utils.assert_types_equivalent(
            concat_process.initialize.type_signature, initialize_type)

        # The measurements are the ones of the inner mean factory.
        measurements_type = computation_types.at_server(
            collections.OrderedDict(mean_value=(), mean_weight=()))
        next_parameter_type = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(value_type),
            weight=computation_types.at_clients(weight_type))
        next_result_type = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(value_type),
            measurements=measurements_type)
        next_type = computation_types.FunctionType(
            parameter=next_parameter_type, result=next_result_type)
        type_test_utils.assert_types_equivalent(
            concat_process.next.type_signature, next_type)
Esempio n. 28
0
    def test_type_properties(self, modulus, value_type, symmetric_range):
        """Checks type signatures of a `SecureModularSumFactory` process.

        Also verifies that the process's `next` computation passes
        `static_assert.assert_not_contains_unsecure_aggregation`.

        Args:
            modulus: `modulus` argument passed to the factory.
            value_type: Type spec of the aggregated value; converted with
                `computation_types.to_type` before use.
            symmetric_range: `symmetric_range` argument passed to the factory.
        """
        factory_ = secure.SecureModularSumFactory(
            modulus=modulus, symmetric_range=symmetric_range)
        self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = factory_.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # State and measurements are expected to share the same empty,
        # server-placed type.
        expected_state_type = computation_types.at_server(
            computation_types.to_type(()))
        expected_measurements_type = expected_state_type

        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=expected_state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                expected_initialize_type))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except Exception:  # pylint: disable=broad-except
            # Catch `Exception` rather than using a bare `except` so that
            # KeyboardInterrupt and SystemExit still propagate instead of
            # being reported as a test failure. The caught error remains
            # attached to the failure as its implicit exception context.
            self.fail('Factory returned an AggregationProcess containing '
                      'non-secure aggregation.')
Esempio n. 29
0
 def next_comp(state, value, weight):
     """Builds the output from a mapped state, a mean and a summed count.

     Args:
         state: Value mapped through `_add_one` to produce the new state.
         value: Value passed to `intrinsics.federated_mean`.
         weight: Weight passed to `intrinsics.federated_mean`.

     Returns:
         A `measured_process.MeasuredProcessOutput` whose measurements hold
         `num_clients`, the federated sum of the constant 1 placed at clients.
     """
     new_state = intrinsics.federated_map(_add_one, state)
     weighted_mean = intrinsics.federated_mean(value, weight)
     ones_at_clients = intrinsics.federated_value(1, placements.CLIENTS)
     client_count = intrinsics.federated_sum(ones_at_clients)
     measurements = intrinsics.federated_zip(
         collections.OrderedDict(num_clients=client_count))
     return measured_process.MeasuredProcessOutput(
         state=new_state, result=weighted_mean, measurements=measurements)
  def test_clip(self, clip_mechanism):
    """Checks type signatures of a `HistogramClippingSumFactory` process.

    Args:
      clip_mechanism: First argument passed to `HistogramClippingSumFactory`.
    """
    clip_factory = clipping_factory.HistogramClippingSumFactory(
        clip_mechanism, 1)
    self.assertIsInstance(clip_factory, factory.UnweightedAggregationFactory)

    value_type = computation_types.to_type((tf.int32, (2,)))
    process = clip_factory.create(value_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # State and measurements are both expected to be empty at the server.
    expected_state_type = computation_types.at_server(())
    expected_measurements_type = computation_types.at_server(())

    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    initialize_matches = process.initialize.type_signature.is_equivalent_to(
        expected_initialize_type)
    self.assertTrue(initialize_matches)

    # `next` takes client-placed values and returns a server-placed result.
    param_value_type = computation_types.FederatedType(
        value_type, placements.CLIENTS)
    result_value_type = computation_types.FederatedType(
        value_type, placements.SERVER)
    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=expected_state_type, value=param_value_type),
        result=measured_process.MeasuredProcessOutput(
            state=expected_state_type,
            result=result_value_type,
            measurements=expected_measurements_type))
    next_matches = process.next.type_signature.is_equivalent_to(
        expected_next_type)
    self.assertTrue(next_matches)