Example 1
    def create_scalar_multiply_operator(
        self, operand_type: computation_types.Type,
        scalar_type: computation_types.TensorType
    ) -> local_computation_factory_base.ComputationProtoAndType:
        py_typecheck.check_type(operand_type, computation_types.Type)
        py_typecheck.check_type(scalar_type, computation_types.TensorType)
        if not type_analysis.is_structure_of_tensors(operand_type):
            raise ValueError(
                'Not a tensor or a structure of tensors: {}'.format(
                    str(operand_type)))

        operand_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
            operand_type)
        scalar_shape = _xla_tensor_shape_from_tff_tensor_type(scalar_type)
        num_operand_tensors = len(operand_shapes)
        builder = xla_client.XlaBuilder('comp')
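        # By convention, all arguments are supplied as a single tuple parameter
        # (see comments in `computation.proto`): the operand tensors come first,
        # followed by the scalar.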
        param = xla_client.ops.Parameter(
            builder, 0,
            xla_client.Shape.tuple_shape(operand_shapes + [scalar_shape]))
        scalar_ref = xla_client.ops.GetTupleElement(param, num_operand_tensors)
        result_tensors = []
        for idx in range(num_operand_tensors):
            result_tensors.append(
                xla_client.ops.Mul(xla_client.ops.GetTupleElement(param, idx),
                                   scalar_ref))
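        # Results are likewise returned as a single flat tuple; the nested TFF
        # structure is defined by the binding.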
        xla_client.ops.Tuple(builder, result_tensors)
        xla_computation = builder.build()

        comp_type = computation_types.FunctionType(
            computation_types.StructType([(None, operand_type),
                                          (None, scalar_type)]), operand_type)
        comp_pb = xla_serialization.create_xla_tff_computation(
            xla_computation, list(range(num_operand_tensors + 1)), comp_type)
        return (comp_pb, comp_type)
Example 2
def _apply_generic_op(op, arg):
    if not (arg.type_signature.is_federated()
            or type_analysis.is_structure_of_tensors(arg.type_signature)):
        # If there are federated elements nested in a struct, we need to zip these
        # together before passing to binary operator constructor.
        arg = building_block_factory.create_federated_zip(arg)
    return building_block_factory.apply_binary_operator_with_upcast(arg, op)
Example 3
def _check_value_type(value_type):
  """Check value_type meets documented criteria."""
  if not (value_type.is_tensor() or
          (value_type.is_struct_with_python() and
           type_analysis.is_structure_of_tensors(value_type))):
    raise TypeError('Expected `value_type` to be `TensorType` or '
                    '`StructWithPythonType` containing only `TensorType`. '
                    f'Found type: {repr(value_type)}')

  if not (type_analysis.is_structure_of_floats(value_type) or
          type_analysis.is_structure_of_integers(value_type)):
    raise TypeError('Component dtypes of `value_type` must be all integers or '
                    f'all floats. Found {value_type}.')
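
# A minimal usage sketch (hypothetical; assumes the TFF-internal
# `computation_types` module and `tf` are in scope):
#
#   _check_value_type(computation_types.TensorType(tf.float32))  # Passes.
#
#   mixed = computation_types.StructWithPythonType(
#       [('a', computation_types.TensorType(tf.int32)),
#        ('b', computation_types.TensorType(tf.float32))], dict)
#   _check_value_type(mixed)  # Raises TypeError: mixed int/float dtypes.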
Example 4
def _validate_value_type_and_encoders(value_type, encoders, encoder_type):
  """Validates if `value_type` and `encoders` are compatible."""
  if isinstance(encoders, _ALLOWED_ENCODERS):
    # If `encoders` is not a container, then `value_type` should be an instance
    # of `tff.TensorType.`
    if not isinstance(value_type, computation_types.TensorType):
      raise ValueError(
          '`value_type` and `encoders` do not have the same structure.')

    _validate_encoder(encoders, value_type, encoder_type)
  else:
    # If `encoders` is a container, then `value_type` should be an instance of
    # `tff.StructType.`
    if not type_analysis.is_structure_of_tensors(value_type):
      raise TypeError('`value_type` is not compatible with the expected input '
                      'of the `encoders`.')
    value_tensorspecs = type_conversions.type_to_tf_tensor_specs(value_type)
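    # `value_tensorspecs` mirrors the structure of `encoders`, so `map_structure`
    # pairs each encoder with the spec of the tensor it encodes.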
    tf.nest.map_structure(lambda e, v: _validate_encoder(e, v, encoder_type),
                          encoders, value_tensorspecs)
Example 5
def _create_xla_binary_op_computation(type_spec, xla_binary_op_constructor):
    """Helper for constructing computations that implement binary operators.

    The constructed computation is of type `(<T,T> -> T)`, where `T` is the type
    of the operand (`type_spec`).

    Args:
      type_spec: The type of a single operand.
      xla_binary_op_constructor: A two-argument callable that constructs a binary
        xla op from tensor parameters (such as `xla_client.ops.Add` or similar).

    Returns:
      An instance of `local_computation_factory_base.ComputationProtoAndType`.

    Raises:
      ValueError: if the arguments are invalid.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if not type_analysis.is_structure_of_tensors(type_spec):
        raise ValueError('Not a tensor or a structure of tensors: {}'.format(
            str(type_spec)))

    tensor_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
        type_spec)
    num_tensors = len(tensor_shapes)
    builder = xla_client.XlaBuilder('comp')
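    # Both operands are packed into one tuple parameter: indices [0, N) hold the
    # first operand's tensors and [N, 2N) the second's.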
    param = xla_client.ops.Parameter(
        builder, 0, xla_client.Shape.tuple_shape(tensor_shapes * 2))
    result_tensors = []
    for idx in range(num_tensors):
        result_tensors.append(
            xla_binary_op_constructor(
                xla_client.ops.GetTupleElement(param, idx),
                xla_client.ops.GetTupleElement(param, idx + num_tensors)))
    xla_client.ops.Tuple(builder, result_tensors)
    xla_computation = builder.build()

    comp_type = computation_types.FunctionType(
        computation_types.StructType([(None, type_spec)] * 2), type_spec)
    comp_pb = xla_serialization.create_xla_tff_computation(
        xla_computation, list(range(2 * num_tensors)), comp_type)
    return (comp_pb, comp_type)
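
# A minimal usage sketch (hypothetical): construct an int32 addition computation.
#
#   comp_pb, comp_type = _create_xla_binary_op_computation(
#       computation_types.TensorType(tf.int32), xla_client.ops.Add)
#   # comp_type is the TFF function type `(<int32,int32> -> int32)`.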
Example 6
    def create_constant_from_scalar(
        self, value, type_spec: computation_types.Type
    ) -> local_computation_factory_base.ComputationProtoAndType:
        py_typecheck.check_type(type_spec, computation_types.Type)
        if not type_analysis.is_structure_of_tensors(type_spec):
            raise ValueError(
                'Not a tensor or a structure of tensors: {}'.format(
                    str(type_spec)))

        builder = xla_client.XlaBuilder('comp')

        # We maintain the convention that arguments are supplied as a tuple for the
        # sake of consistency and uniformity (see comments in `computation.proto`).
        # Since there are no arguments here, we create an empty tuple.
        xla_client.ops.Parameter(builder, 0,
                                 xla_client.shape_from_pyval(tuple()))

        def _constant_from_tensor(tensor_type):
            py_typecheck.check_type(tensor_type, computation_types.TensorType)
            numpy_value = np.full(shape=tensor_type.shape.dims,
                                  fill_value=value,
                                  dtype=tensor_type.dtype.as_numpy_dtype)
            return xla_client.ops.Constant(builder, numpy_value)

        if isinstance(type_spec, computation_types.TensorType):
            tensors = [_constant_from_tensor(type_spec)]
        else:
            tensors = [
                _constant_from_tensor(x) for x in structure.flatten(type_spec)
            ]

        # Likewise, results are always returned as a single tuple with results.
        # This is always a flat tuple; the nested TFF structure is defined by the
        # binding.
        xla_client.ops.Tuple(builder, tensors)
        xla_computation = builder.build()

        comp_type = computation_types.FunctionType(None, type_spec)
        comp_pb = xla_serialization.create_xla_tff_computation(
            xla_computation, [], comp_type)
        return (comp_pb, comp_type)
Example 7
    def create(self, value_type):
        # Check `value_type` and compute the client data dimension.
        if (value_type.is_struct_with_python()
                and type_analysis.is_structure_of_tensors(value_type)):
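            # Count the total number of scalar elements across all leaf tensors.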
            num_elements_struct = type_conversions.structure_from_tensor_type_tree(
                lambda x: x.shape.num_elements(), value_type)
            self._client_dim = sum(tf.nest.flatten(num_elements_struct))
        elif value_type.is_tensor():
            self._client_dim = value_type.shape.num_elements()
        else:
            raise TypeError(
                'Expected `value_type` to be `TensorType` or '
                '`StructWithPythonType` containing only `TensorType`. '
                f'Found type: {repr(value_type)}')
        # Checks that all values are integers or floats.
        if not (type_analysis.is_structure_of_floats(value_type)
                or type_analysis.is_structure_of_integers(value_type)):
            raise TypeError(
                'Component dtypes of `value_type` must all be integers '
                f'or floats. Found {repr(value_type)}.')

        ddp_agg_process = self._build_aggregation_factory().create(value_type)
        init_fn = ddp_agg_process.initialize

        @federated_computation.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
            agg_output = ddp_agg_process.next(state, value)
            new_measurements = self._derive_measurements(
                agg_output.state, agg_output.measurements)
            new_state = agg_output.state
            if self._auto_l2_clip:
                new_state = self._autotune_component_states(agg_output.state)

            return measured_process.MeasuredProcessOutput(
                state=new_state,
                result=agg_output.result,
                measurements=new_measurements)

        return aggregation_process.AggregationProcess(init_fn, next_fn)
Example 8
    def create(
        self, value_type: factory.ValueType
    ) -> aggregation_process.AggregationProcess:
        # Check `value_type` and compute the client data dimension.
        if (value_type.is_struct()
                and type_analysis.is_structure_of_tensors(value_type)):
            num_elements_struct = type_conversions.structure_from_tensor_type_tree(
                lambda x: x.shape.num_elements(), value_type)
            client_dim = sum(tf.nest.flatten(num_elements_struct))
        elif value_type.is_tensor():
            client_dim = value_type.shape.num_elements()
        else:
            raise TypeError('Expected `value_type` to be `TensorType` or '
                            '`StructType` containing only `TensorType`. '
                            f'Found type: {repr(value_type)}')
        # Checks that all values are integers.
        if not type_analysis.is_structure_of_integers(value_type):
            raise TypeError(
                'Component dtypes of `value_type` must all be integers. '
                f'Found {repr(value_type)}.')
        # Checks that we have enough elements to estimate standard deviation.
        if self._estimate_stddev:
            if client_dim <= 1:
                raise ValueError(
                    'The stddev estimation procedure expects more than '
                    '1 element from `value_type`. Found `value_type` of '
                    f'{value_type} with {client_dim} elements.')
            elif client_dim <= 100:
                warnings.warn(
                    f'`value_type` has only {client_dim} elements. The '
                    'estimated standard deviation may be noisy. Consider '
                    'setting `estimate_stddev` to True only if the input '
                    'tensor/structure has more than 100 elements.')

        inner_agg_process = self._inner_agg_factory.create(value_type)
        init_fn = inner_agg_process.initialize
        next_fn = self._create_next_fn(inner_agg_process.next,
                                       init_fn.type_signature.result,
                                       value_type)
        return aggregation_process.AggregationProcess(init_fn, next_fn)
Example 9
def from_keras_model(
    keras_model: tf.keras.Model,
    loss: Loss,
    input_spec,
    loss_weights: Optional[List[float]] = None,
    metrics: Optional[List[tf.keras.metrics.Metric]] = None
) -> model_lib.Model:
    """Builds a `tff.learning.Model` from a `tf.keras.Model`.

    The `tff.learning.Model` returned by this function uses `keras_model` for
    its forward pass and autodifferentiation steps.

    Notice that since TFF couples the `tf.keras.Model` and `loss`,
    TFF needs a slightly different notion of "fully specified type" than
    pure Keras does. That is, the model `M` takes inputs of type `x` and
    produces predictions of type `p`; the loss function `L` takes inputs of type
    `<p, y>` (where `y` is the ground truth label type) and produces a scalar.
    Therefore in order to fully specify the type signatures for computations in
    which the generated `tff.learning.Model` will appear, TFF needs the type `y`
    in addition to the type `x`.

    Note: This function does not currently accept subclassed `tf.keras.Model`s,
    as it makes assumptions about the presence of certain attributes which are
    guaranteed to exist through the functional or Sequential API but are
    not necessarily present for subclassed models.

    Args:
      keras_model: A `tf.keras.Model` object that is not compiled.
      loss: A single `tf.keras.losses.Loss` or a list of losses-per-output. If a
        single loss is provided, then all model output (as well as all prediction
        information) is passed to the loss; this includes situations of multiple
        model outputs and/or predictions. If multiple losses are provided as a
        list, then each loss is expected to correspond to a model output; the
        model will attempt to minimize the sum of all individual losses
        (optionally weighted using the `loss_weights` argument).
      input_spec: A structure of `tf.TensorSpec`s or `tff.Type` specifying the
        type of arguments the model expects. If `input_spec` is a `tff.Type`, its
        leaf nodes must be `TensorType`s. Note that `input_spec` must be a
        compound structure of two elements, specifying both the data fed into the
        model (x) to generate predictions as well as the expected type of the
        ground truth (y). If provided as a list, it must be in the order [x, y].
        If provided as a dictionary, the keys must explicitly be named `'x'` and
        `'y'`.
      loss_weights: (Optional) A list of Python floats used to weight the loss
        contribution of each model output (when providing a list of losses for
        the `loss` argument).
      metrics: (Optional) A list of `tf.keras.metrics.Metric` objects.

    Returns:
      A `tff.learning.Model` object.

    Raises:
      TypeError: If `keras_model` is not an instance of `tf.keras.Model`, if
        `loss` is not an instance of `tf.keras.losses.Loss` nor a list of
        instances of `tf.keras.losses.Loss`, if `input_spec` is a `tff.Type` but
        the leaf nodes are not `tff.TensorType`s, if `loss_weights` is provided
        but is not a list of floats, or if `metrics` is provided but is not a
        list of instances of `tf.keras.metrics.Metric`.
      ValueError: If `keras_model` was compiled, if `loss` is a list of unequal
        length to the number of outputs of `keras_model`, if `loss_weights` is
        specified but `loss` is not a list, if `input_spec` does not contain
        exactly two elements, or if `input_spec` is a dictionary and does not
        contain keys `'x'` and `'y'`.
    """
    # Validate `keras_model`
    py_typecheck.check_type(keras_model, tf.keras.Model)
    if keras_model._is_compiled:  # pylint: disable=protected-access
        raise ValueError('`keras_model` must not be compiled')

    # Validate and normalize `loss` and `loss_weights`
    if not isinstance(loss, list):
        py_typecheck.check_type(loss, tf.keras.losses.Loss)
        if loss_weights is not None:
            raise ValueError(
                '`loss_weights` cannot be used if `loss` is not a list.')
        loss = [loss]
        loss_weights = [1.0]
    else:
        if len(loss) != len(keras_model.outputs):
            raise ValueError(
                'If a loss list is provided, `keras_model` must have '
                'equal number of outputs to the losses.\nloss: {}\nof '
                'length: {}.\noutputs: {}\nof length: {}.'.format(
                    loss, len(loss), keras_model.outputs,
                    len(keras_model.outputs)))
        for loss_fn in loss:
            py_typecheck.check_type(loss_fn, tf.keras.losses.Loss)

        if loss_weights is None:
            loss_weights = [1.0] * len(loss)
        else:
            if len(loss) != len(loss_weights):
                raise ValueError(
                    '`keras_model` must have equal number of losses and loss_weights.'
                    '\nloss: {}\nof length: {}.'
                    '\nloss_weights: {}\nof length: {}.'.format(
                        loss, len(loss), loss_weights, len(loss_weights)))
            for loss_weight in loss_weights:
                py_typecheck.check_type(loss_weight, float)

    if len(input_spec) != 2:
        raise ValueError(
            'The top-level structure in `input_spec` must contain '
            'exactly two top-level elements, as it must specify type '
            'information for both inputs to and predictions from the '
            'model. You passed input spec {}.'.format(input_spec))
    if isinstance(input_spec, computation_types.Type):
        if not type_analysis.is_structure_of_tensors(input_spec):
            raise TypeError(
                'Expected a `tff.Type` with all the leaf nodes being '
                '`tff.TensorType`s, found an input spec {}.'.format(
                    input_spec))
        input_spec = type_conversions.structure_from_tensor_type_tree(
            lambda tensor_type: tf.TensorSpec(tensor_type.shape,
                                              tensor_type.dtype), input_spec)
    else:
        tf.nest.map_structure(
            lambda s: py_typecheck.check_type(s, tf.TensorSpec,
                                              'input spec member'), input_spec)
    if isinstance(input_spec, collections.abc.Mapping):
        if 'x' not in input_spec:
            raise ValueError(
                'The `input_spec` is a collections.abc.Mapping (e.g., a dict), so it '
                'must contain an entry with key `\'x\'`, representing the input(s) '
                'to the Keras model.')
        if 'y' not in input_spec:
            raise ValueError(
                'The `input_spec` is a collections.abc.Mapping (e.g., a dict), so it '
                'must contain an entry with key `\'y\'`, representing the label(s) '
                'to be used in the Keras loss(es).')

    if metrics is None:
        metrics = []
    else:
        py_typecheck.check_type(metrics, list)
        for metric in metrics:
            py_typecheck.check_type(metric, tf.keras.metrics.Metric)

    return _KerasModel(keras_model,
                       input_spec=input_spec,
                       loss_fns=loss,
                       loss_weights=loss_weights,
                       metrics=metrics)
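
# A minimal usage sketch (hypothetical model and input spec):
#
#   keras_model = tf.keras.Sequential(
#       [tf.keras.layers.Dense(1, input_shape=(2,))])
#   tff_model = from_keras_model(
#       keras_model,
#       loss=tf.keras.losses.MeanSquaredError(),
#       input_spec=collections.OrderedDict(
#           x=tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
#           y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32)))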
Example 10
    def create(self, value_type):
        # Validate input args and value_type and parse out the TF dtypes.
        if value_type.is_tensor():
            tf_dtype = value_type.dtype
        elif (value_type.is_struct_with_python()
              and type_analysis.is_structure_of_tensors(value_type)):
            if self._prior_norm_bound:
                raise TypeError(
                    'If `prior_norm_bound` is specified, `value_type` must '
                    f'be `TensorType`. Found type: {repr(value_type)}.')
            tf_dtype = type_conversions.structure_from_tensor_type_tree(
                lambda x: x.dtype, value_type)
        else:
            raise TypeError(
                'Expected `value_type` to be `TensorType` or '
                '`StructWithPythonType` containing only `TensorType`. '
                f'Found type: {repr(value_type)}')

        # Check that all values are floats.
        if not type_analysis.is_structure_of_floats(value_type):
            raise TypeError(
                'Component dtypes of `value_type` must all be floats. '
                f'Found {repr(value_type)}.')

        discretize_fn = _build_discretize_fn(value_type, self._stochastic,
                                             self._beta)

        @tensorflow_computation.tf_computation(
            discretize_fn.type_signature.result, tf.float32)
        def undiscretize_fn(value, scale_factor):
            return _undiscretize_struct(value, scale_factor, tf_dtype)

        inner_value_type = discretize_fn.type_signature.result
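        # The inner aggregation operates on the discretized representation
        # produced by `discretize_fn`, not on the original float values.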
        inner_agg_process = self._inner_agg_factory.create(inner_value_type)

        @federated_computation.federated_computation()
        def init_fn():
            state = collections.OrderedDict(
                scale_factor=intrinsics.federated_value(
                    self._scale_factor, placements.SERVER),
                prior_norm_bound=intrinsics.federated_value(
                    self._prior_norm_bound, placements.SERVER),
                inner_agg_process=inner_agg_process.initialize())
            return intrinsics.federated_zip(state)

        @federated_computation.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
            server_scale_factor = state['scale_factor']
            client_scale_factor = intrinsics.federated_broadcast(
                server_scale_factor)
            server_prior_norm_bound = state['prior_norm_bound']
            prior_norm_bound = intrinsics.federated_broadcast(
                server_prior_norm_bound)

            discretized_value = intrinsics.federated_map(
                discretize_fn, (value, client_scale_factor, prior_norm_bound))

            inner_state = state['inner_agg_process']
            inner_agg_output = inner_agg_process.next(inner_state,
                                                      discretized_value)

            undiscretized_agg_value = intrinsics.federated_map(
                undiscretize_fn,
                (inner_agg_output.result, server_scale_factor))

            new_state = collections.OrderedDict(
                scale_factor=server_scale_factor,
                prior_norm_bound=server_prior_norm_bound,
                inner_agg_process=inner_agg_output.state)
            measurements = collections.OrderedDict(
                discretize=inner_agg_output.measurements)

            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(new_state),
                result=undiscretized_agg_value,
                measurements=intrinsics.federated_zip(measurements))

        return aggregation_process.AggregationProcess(init_fn, next_fn)
Example 11
  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    # Validate input args and value_type and parse out the TF dtypes.
    if value_type.is_tensor():
      tf_dtype = value_type.dtype
    elif (value_type.is_struct_with_python() and
          type_analysis.is_structure_of_tensors(value_type)):
      tf_dtype = type_conversions.structure_from_tensor_type_tree(
          lambda x: x.dtype, value_type)
    else:
      raise TypeError('Expected `value_type` to be `TensorType` or '
                      '`StructWithPythonType` containing only `TensorType`. '
                      f'Found type: {repr(value_type)}')

    # Check that all values are floats.
    if not type_analysis.is_structure_of_floats(value_type):
      raise TypeError('Component dtypes of `value_type` must all be floats. '
                      f'Found {repr(value_type)}.')

    if self._distortion_aggregation_factory is not None:
      distortion_aggregation_process = self._distortion_aggregation_factory.create(
          computation_types.to_type(tf.float32))

    @tensorflow_computation.tf_computation(value_type, tf.float32)
    def discretize_fn(value, step_size):
      return _discretize_struct(value, step_size)

    @tensorflow_computation.tf_computation(discretize_fn.type_signature.result,
                                           tf.float32)
    def undiscretize_fn(value, step_size):
      return _undiscretize_struct(value, step_size, tf_dtype)

    @tensorflow_computation.tf_computation(value_type, tf.float32)
    def distortion_measurement_fn(value, step_size):
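      # Round-trip the value through discretize/undiscretize and report the
      # mean squared error introduced by quantization.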
      reconstructed_value = undiscretize_fn(
          discretize_fn(value, step_size), step_size)
      err = tf.nest.map_structure(tf.subtract, reconstructed_value, value)
      squared_err = tf.nest.map_structure(tf.square, err)
      flat_squared_errs = [
          tf.cast(tf.reshape(t, [-1]), tf.float32)
          for t in tf.nest.flatten(squared_err)
      ]
      all_squared_errs = tf.concat(flat_squared_errs, axis=0)
      mean_squared_err = tf.reduce_mean(all_squared_errs)
      return mean_squared_err

    inner_agg_process = self._inner_agg_factory.create(
        discretize_fn.type_signature.result)

    @federated_computation.federated_computation()
    def init_fn():
      state = collections.OrderedDict(
          step_size=intrinsics.federated_value(self._step_size,
                                               placements.SERVER),
          inner_agg_process=inner_agg_process.initialize())
      return intrinsics.federated_zip(state)

    @federated_computation.federated_computation(
        init_fn.type_signature.result, computation_types.at_clients(value_type))
    def next_fn(state, value):
      server_step_size = state['step_size']
      client_step_size = intrinsics.federated_broadcast(server_step_size)
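      # The server's step size is broadcast so each client can discretize locally.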

      discretized_value = intrinsics.federated_map(discretize_fn,
                                                   (value, client_step_size))

      inner_state = state['inner_agg_process']
      inner_agg_output = inner_agg_process.next(inner_state, discretized_value)

      undiscretized_agg_value = intrinsics.federated_map(
          undiscretize_fn, (inner_agg_output.result, server_step_size))

      new_state = collections.OrderedDict(
          step_size=server_step_size, inner_agg_process=inner_agg_output.state)
      measurements = collections.OrderedDict(
          deterministic_discretization=inner_agg_output.measurements)

      if self._distortion_aggregation_factory is not None:
        distortions = intrinsics.federated_map(distortion_measurement_fn,
                                               (value, client_step_size))
        aggregate_distortion = distortion_aggregation_process.next(
            distortion_aggregation_process.initialize(), distortions).result
        measurements['distortion'] = aggregate_distortion

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(new_state),
          result=undiscretized_agg_value,
          measurements=intrinsics.federated_zip(measurements))

    return aggregation_process.AggregationProcess(init_fn, next_fn)