예제 #1
0
def _check_value_type(value_type):
    """Validate that `value_type` is an accepted type holding only floats.

    Args:
        value_type: The type to validate. It is first checked against the
            member types of `factory.ValueType`, then against
            `type_analysis.is_structure_of_floats`.

    Raises:
        TypeError: If `type_analysis.is_structure_of_floats` rejects
            `value_type`. `py_typecheck.check_type` may also raise for a
            type outside `factory.ValueType`.
    """
    allowed_types = typing.get_args(factory.ValueType)
    py_typecheck.check_type(value_type, allowed_types)

    if type_analysis.is_structure_of_floats(value_type):
        return

    message = ('All values in provided value_type must be of floating '
               f'dtype. Provided value_type: {value_type}')
    raise TypeError(message)
예제 #2
0
def _check_value_type(value_type):
  """Validates the structure and component dtypes of `value_type`.

  Args:
    value_type: The type to validate. It must either report `is_tensor()`,
      or report `is_struct_with_python()` and satisfy
      `type_analysis.is_structure_of_tensors`. It must additionally satisfy
      `type_analysis.is_structure_of_floats` or
      `type_analysis.is_structure_of_integers`.

  Raises:
    TypeError: If either the structural check or the dtype check fails.
  """
  if not value_type.is_tensor():
    # Not a bare tensor: the only other accepted form is a Python-container
    # struct made up exclusively of tensors.
    is_tensor_struct = (
        value_type.is_struct_with_python() and
        type_analysis.is_structure_of_tensors(value_type))
    if not is_tensor_struct:
      raise TypeError(
          'Expected `value_type` to be `TensorType` or '
          '`StructWithPythonType` containing only `TensorType`. '
          f'Found type: {repr(value_type)}')

  # Either dtype family is acceptable; floats are probed first.
  if type_analysis.is_structure_of_floats(value_type):
    return
  if type_analysis.is_structure_of_integers(value_type):
    return

  raise TypeError(
      'Component dtypes of `value_type` must be all integers or '
      f'all floats. Found {value_type}.')
예제 #3
0
  def _check_value_type_compatible_with_config_mode(self, value_type):
    """Checks that `value_type` dtypes match the configured mode.

    Args:
      value_type: The type to validate against `self._config_mode`.

    Raises:
      TypeError: If the mode is `_Config.INT` and `value_type` is not a
        structure of integers, or the mode is `_Config.FLOAT` and it is not
        a structure of floats.
      ValueError: If `self._config_mode` is neither `_Config.INT` nor
        `_Config.FLOAT`.
    """
    py_typecheck.check_type(value_type, factory.ValueType.__args__)

    mode = self._config_mode

    if mode == _Config.INT:
      if type_analysis.is_structure_of_integers(value_type):
        return
      raise TypeError(
          'The `SecureSumFactory` was configured to work with integer '
          'dtypes. All values in provided `value_type` hence must be of '
          f'integer dtype. \nProvided value_type: {value_type}')

    if mode == _Config.FLOAT:
      if type_analysis.is_structure_of_floats(value_type):
        return
      raise TypeError(
          'The `SecureSumFactory` was configured to work with floating '
          'point dtypes. All values in provided `value_type` hence must be '
          f'of floating point dtype. \nProvided value_type: {value_type}')

    # Any other mode indicates an inconsistent internal configuration.
    raise ValueError(f'Unexpected internal config type: {mode}')
예제 #4
0
    def create(self, value_type):
        """Create the aggregation process for `value_type`.

        Records the total number of tensor elements of `value_type` in
        `self._client_dim`, builds the inner aggregation process via
        `self._build_aggregation_factory()`, and wraps its `next` so that
        measurements are derived by `self._derive_measurements` and, when
        `self._auto_l2_clip` is set, the state is passed through
        `self._autotune_component_states`.

        Args:
            value_type: A tensor type, or a struct-with-Python-container type
                containing only tensors, whose components satisfy
                `is_structure_of_floats` or `is_structure_of_integers`.

        Returns:
            An `aggregation_process.AggregationProcess` using the inner
            process's `initialize` and the wrapped `next_fn`.

        Raises:
            TypeError: If `value_type` has an unsupported structure or
                unsupported component dtypes.
        """
        # Work out the client data dimension while validating the structure.
        is_tensor_struct = (
            value_type.is_struct_with_python()
            and type_analysis.is_structure_of_tensors(value_type))
        if is_tensor_struct:
            element_counts = type_conversions.structure_from_tensor_type_tree(
                lambda tensor_type: tensor_type.shape.num_elements(),
                value_type)
            self._client_dim = sum(tf.nest.flatten(element_counts))
        elif value_type.is_tensor():
            self._client_dim = value_type.shape.num_elements()
        else:
            raise TypeError(
                'Expected `value_type` to be `TensorType` or '
                '`StructWithPythonType` containing only `TensorType`. '
                f'Found type: {repr(value_type)}')

        # Component dtypes must pass the float check or the integer check.
        passes_float_check = type_analysis.is_structure_of_floats(value_type)
        if (not passes_float_check
                and not type_analysis.is_structure_of_integers(value_type)):
            raise TypeError(
                'Component dtypes of `value_type` must all be integers '
                f'or floats. Found {repr(value_type)}.')

        inner_process = self._build_aggregation_factory().create(value_type)
        initialize = inner_process.initialize

        @federated_computation.federated_computation(
            initialize.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
            inner_output = inner_process.next(state, value)
            derived_measurements = self._derive_measurements(
                inner_output.state, inner_output.measurements)

            # Only re-tune the component states when auto clipping is on.
            if self._auto_l2_clip:
                updated_state = self._autotune_component_states(
                    inner_output.state)
            else:
                updated_state = inner_output.state

            return measured_process.MeasuredProcessOutput(
                state=updated_state,
                result=inner_output.result,
                measurements=derived_measurements)

        return aggregation_process.AggregationProcess(initialize, next_fn)
예제 #5
0
    def _check_value_type_compatible_with_config_mode(self, value_type):
        """Check that `value_type` dtypes match the configured mode.

        Args:
            value_type: The type to validate against `self._config_mode`.

        Raises:
            TypeError: If `_is_structure_of_single_dtype` rejects
                `value_type`, or if its dtypes do not match the configured
                `_Config.INT` / `_Config.FLOAT` mode.
            ValueError: If `self._config_mode` is neither `_Config.INT` nor
                `_Config.FLOAT`.
        """
        allowed_types = typing.get_args(factory.ValueType)
        py_typecheck.check_type(value_type, allowed_types)

        if not _is_structure_of_single_dtype(value_type):
            message = (
                'Expected a type which is a structure containing the same dtypes, '
                f'found {value_type}.')
            raise TypeError(message)

        mode = self._config_mode

        if mode == _Config.INT:
            if type_analysis.is_structure_of_integers(value_type):
                return
            raise TypeError(
                'The `SecureSumFactory` was configured to work with integer '
                'dtypes. All values in provided `value_type` hence must be of '
                f'integer dtype. \nProvided value_type: {value_type}')

        if mode == _Config.FLOAT:
            if type_analysis.is_structure_of_floats(value_type):
                return
            raise TypeError(
                'The `SecureSumFactory` was configured to work with floating '
                'point dtypes. All values in provided `value_type` hence must be '
                f'of floating point dtype. \nProvided value_type: {value_type}')

        # Any other mode indicates an inconsistent internal configuration.
        raise ValueError(f'Unexpected internal config type: {mode}')
예제 #6
0
    def create(self, value_type):
        """Build an aggregation process that discretizes values before aggregating.

        Client values are passed through the computation returned by
        `_build_discretize_fn`, aggregated by a process created from
        `self._inner_agg_factory`, and the aggregate is mapped back through
        `_undiscretize_struct` using the original dtype(s) of `value_type`.

        Args:
            value_type: A tensor type, or a struct-with-Python-container type
                containing only tensors. Its components must satisfy
                `type_analysis.is_structure_of_floats`. A struct type is
                rejected when `self._prior_norm_bound` is truthy.

        Returns:
            An `aggregation_process.AggregationProcess` whose state is an
            ordered dict with keys `scale_factor`, `prior_norm_bound` and
            `inner_agg_process`, and whose measurements hold the inner
            process's measurements under the key `discretize`.

        Raises:
            TypeError: If `value_type` has an unsupported structure, is a
                struct while `prior_norm_bound` is specified, or has
                non-float component dtypes.
        """
        # Validate input args and value_type and parse out the TF dtypes.
        # `tf_dtype` is a single dtype for a tensor type, or a structure of
        # dtypes mirroring `value_type` for a struct type.
        if value_type.is_tensor():
            tf_dtype = value_type.dtype
        elif (value_type.is_struct_with_python()
              and type_analysis.is_structure_of_tensors(value_type)):
            # A prior norm bound is only accepted together with a tensor type.
            if self._prior_norm_bound:
                raise TypeError(
                    'If `prior_norm_bound` is specified, `value_type` must '
                    f'be `TensorType`. Found type: {repr(value_type)}.')
            tf_dtype = type_conversions.structure_from_tensor_type_tree(
                lambda x: x.dtype, value_type)
        else:
            raise TypeError(
                'Expected `value_type` to be `TensorType` or '
                '`StructWithPythonType` containing only `TensorType`. '
                f'Found type: {repr(value_type)}')

        # Check that all values are floats.
        if not type_analysis.is_structure_of_floats(value_type):
            raise TypeError(
                'Component dtypes of `value_type` must all be floats. '
                f'Found {repr(value_type)}.')

        # Client-side computation turning a value into its discretized form.
        discretize_fn = _build_discretize_fn(value_type, self._stochastic,
                                             self._beta)

        # Inverse mapping: takes a discretized value plus a float32 scale
        # factor and restores the dtype(s) captured in `tf_dtype`.
        @tensorflow_computation.tf_computation(
            discretize_fn.type_signature.result, tf.float32)
        def undiscretize_fn(value, scale_factor):
            return _undiscretize_struct(value, scale_factor, tf_dtype)

        # The inner aggregator operates on the discretized type, not on
        # `value_type` itself.
        inner_value_type = discretize_fn.type_signature.result
        inner_agg_process = self._inner_agg_factory.create(inner_value_type)

        @federated_computation.federated_computation()
        def init_fn():
            # Initial state: the configured scale factor and prior norm bound
            # placed at the server, plus the inner process's initial state.
            state = collections.OrderedDict(
                scale_factor=intrinsics.federated_value(
                    self._scale_factor, placements.SERVER),
                prior_norm_bound=intrinsics.federated_value(
                    self._prior_norm_bound, placements.SERVER),
                inner_agg_process=inner_agg_process.initialize())
            return intrinsics.federated_zip(state)

        @federated_computation.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
            # Broadcast the server-held parameters so clients can discretize.
            server_scale_factor = state['scale_factor']
            client_scale_factor = intrinsics.federated_broadcast(
                server_scale_factor)
            server_prior_norm_bound = state['prior_norm_bound']
            prior_norm_bound = intrinsics.federated_broadcast(
                server_prior_norm_bound)

            discretized_value = intrinsics.federated_map(
                discretize_fn, (value, client_scale_factor, prior_norm_bound))

            # Delegate the actual aggregation to the inner process.
            inner_state = state['inner_agg_process']
            inner_agg_output = inner_agg_process.next(inner_state,
                                                      discretized_value)

            # Map the inner aggregate back using the server-side scale factor.
            undiscretized_agg_value = intrinsics.federated_map(
                undiscretize_fn,
                (inner_agg_output.result, server_scale_factor))

            # Scale factor and prior norm bound are carried over unchanged;
            # only the inner process state is updated.
            new_state = collections.OrderedDict(
                scale_factor=server_scale_factor,
                prior_norm_bound=server_prior_norm_bound,
                inner_agg_process=inner_agg_output.state)
            measurements = collections.OrderedDict(
                discretize=inner_agg_output.measurements)

            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(new_state),
                result=undiscretized_agg_value,
                measurements=intrinsics.federated_zip(measurements))

        return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #7
0
  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    """Builds an aggregation process that discretizes values by a step size.

    Client values are discretized with `_discretize_struct`, aggregated by a
    process created from `self._inner_agg_factory`, and the aggregate is mapped
    back with `_undiscretize_struct` using the original dtype(s) of
    `value_type`. When `self._distortion_aggregation_factory` is set, a
    per-client mean squared reconstruction error is also aggregated and
    reported under the `distortion` measurement key.

    Args:
      value_type: A tensor type, or a struct-with-Python-container type
        containing only tensors. Its components must satisfy
        `type_analysis.is_structure_of_floats`.

    Returns:
      An `aggregation_process.AggregationProcess` whose state is an ordered
      dict with keys `step_size` and `inner_agg_process`, and whose
      measurements hold the inner process's measurements under the key
      `deterministic_discretization`.

    Raises:
      TypeError: If `value_type` has an unsupported structure or non-float
        component dtypes.
    """
    # Validate input args and value_type and parse out the TF dtypes.
    # `tf_dtype` is a single dtype for a tensor type, or a structure of dtypes
    # mirroring `value_type` for a struct type.
    if value_type.is_tensor():
      tf_dtype = value_type.dtype
    elif (value_type.is_struct_with_python() and
          type_analysis.is_structure_of_tensors(value_type)):
      tf_dtype = type_conversions.structure_from_tensor_type_tree(
          lambda x: x.dtype, value_type)
    else:
      raise TypeError('Expected `value_type` to be `TensorType` or '
                      '`StructWithPythonType` containing only `TensorType`. '
                      f'Found type: {repr(value_type)}')

    # Check that all values are floats.
    if not type_analysis.is_structure_of_floats(value_type):
      raise TypeError('Component dtypes of `value_type` must all be floats. '
                      f'Found {repr(value_type)}.')

    # The optional distortion aggregator works on a single float32 value per
    # client. The name is only bound when the factory is configured; `next_fn`
    # guards its use with the same condition.
    if self._distortion_aggregation_factory is not None:
      distortion_aggregation_process = self._distortion_aggregation_factory.create(
          computation_types.to_type(tf.float32))

    # Maps a value and a float32 step size to the discretized value.
    @tensorflow_computation.tf_computation(value_type, tf.float32)
    def discretize_fn(value, step_size):
      return _discretize_struct(value, step_size)

    # Inverse mapping: restores the dtype(s) captured in `tf_dtype`.
    @tensorflow_computation.tf_computation(discretize_fn.type_signature.result,
                                           tf.float32)
    def undiscretize_fn(value, step_size):
      return _undiscretize_struct(value, step_size, tf_dtype)

    # Mean squared error between a value and its discretize/undiscretize
    # round trip, computed over all elements of all tensors as float32.
    @tensorflow_computation.tf_computation(value_type, tf.float32)
    def distortion_measurement_fn(value, step_size):
      reconstructed_value = undiscretize_fn(
          discretize_fn(value, step_size), step_size)
      err = tf.nest.map_structure(tf.subtract, reconstructed_value, value)
      squared_err = tf.nest.map_structure(tf.square, err)
      # Flatten every tensor to 1-D and cast to float32 so they can be
      # concatenated into a single vector.
      flat_squared_errs = [
          tf.cast(tf.reshape(t, [-1]), tf.float32)
          for t in tf.nest.flatten(squared_err)
      ]
      all_squared_errs = tf.concat(flat_squared_errs, axis=0)
      mean_squared_err = tf.reduce_mean(all_squared_errs)
      return mean_squared_err

    # The inner aggregator operates on the discretized type, not on
    # `value_type` itself.
    inner_agg_process = self._inner_agg_factory.create(
        discretize_fn.type_signature.result)

    @federated_computation.federated_computation()
    def init_fn():
      # Initial state: the configured step size placed at the server, plus the
      # inner process's initial state.
      state = collections.OrderedDict(
          step_size=intrinsics.federated_value(self._step_size,
                                               placements.SERVER),
          inner_agg_process=inner_agg_process.initialize())
      return intrinsics.federated_zip(state)

    @federated_computation.federated_computation(
        init_fn.type_signature.result, computation_types.at_clients(value_type))
    def next_fn(state, value):
      # Broadcast the server-held step size so clients can discretize.
      server_step_size = state['step_size']
      client_step_size = intrinsics.federated_broadcast(server_step_size)

      discretized_value = intrinsics.federated_map(discretize_fn,
                                                   (value, client_step_size))

      # Delegate the actual aggregation to the inner process.
      inner_state = state['inner_agg_process']
      inner_agg_output = inner_agg_process.next(inner_state, discretized_value)

      # Map the inner aggregate back using the server-side step size.
      undiscretized_agg_value = intrinsics.federated_map(
          undiscretize_fn, (inner_agg_output.result, server_step_size))

      # The step size is carried over unchanged; only the inner process state
      # is updated.
      new_state = collections.OrderedDict(
          step_size=server_step_size, inner_agg_process=inner_agg_output.state)
      measurements = collections.OrderedDict(
          deterministic_discretization=inner_agg_output.measurements)

      if self._distortion_aggregation_factory is not None:
        distortions = intrinsics.federated_map(distortion_measurement_fn,
                                               (value, client_step_size))
        # NOTE(review): the distortion process is initialized afresh on every
        # call and only its `.result` is kept; its state is not stored in
        # `new_state`. Presumably a stateless aggregator is expected here --
        # confirm.
        aggregate_distortion = distortion_aggregation_process.next(
            distortion_aggregation_process.initialize(), distortions).result
        measurements['distortion'] = aggregate_distortion

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(new_state),
          result=undiscretized_agg_value,
          measurements=intrinsics.federated_zip(measurements))

    return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #8
0
 def test_returns_false(self, type_spec):
     """Asserts that `is_structure_of_floats` is falsy for `type_spec`."""
     outcome = type_analysis.is_structure_of_floats(type_spec)
     self.assertFalse(outcome)