Example #1
    def __init__(self, keras_model: tf.keras.Model, input_spec,
                 loss_fns: List[tf.keras.losses.Loss],
                 loss_weights: List[float],
                 metrics: List[tf.keras.metrics.Metric]):
        self._keras_model = keras_model
        self._input_spec = input_spec
        self._loss_fns = loss_fns
        self._loss_weights = loss_weights
        self._metrics = metrics

        # This is defined here so that it closes over `loss_fns` and
        # `loss_weights`.
        class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
            """A `tf.keras.metrics.Metric` wrapper for the loss function."""
            def __init__(self, name='loss', dtype=tf.float32):
                super().__init__(name, dtype)
                self._loss_fns = loss_fns
                self._loss_weights = loss_weights

            def update_state(self, y_true, y_pred, sample_weight=None):
                if len(self._loss_fns) == 1:
                    batch_size = tf.shape(y_pred)[0]
                    batch_loss = self._loss_fns[0](y_true, y_pred)
                else:
                    batch_size = tf.shape(y_pred[0])[0]
                    batch_loss = tf.zeros(())
                    for i in range(len(self._loss_fns)):
                        batch_loss += (self._loss_weights[i] *
                                       self._loss_fns[i](y_true[i], y_pred[i]))

                return super().update_state(batch_loss, batch_size)

        self._loss_metric = _WeightedMeanLossMetric()

        metric_variable_type_dict = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, self.report_local_outputs())
        federated_local_outputs_type = tff.FederatedType(
            metric_variable_type_dict, tff.CLIENTS)

        def federated_output(local_outputs):
            return federated_aggregate_keras_metric(self.get_metrics(),
                                                    local_outputs)

        self._federated_output_computation = tff.federated_computation(
            federated_output, federated_local_outputs_type)
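
Note: a minimal usage sketch for Example #1. The enclosing class name
`_KerasModel` is an assumption (the snippet does not show it), and the model
here is illustrative:

import collections
import tensorflow as tf

keras_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,))])
input_spec = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64))
# One loss per model output; a single output here, so one loss and one weight.
wrapped = _KerasModel(
    keras_model,
    input_spec=input_spec,
    loss_fns=[tf.keras.losses.SparseCategoricalCrossentropy()],
    loss_weights=[1.0],
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])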
Example #2
  def __init__(self, inner_model, dummy_batch, loss_fn, metrics):
    # TODO(b/124477598): the following set_session() should be removed in the
    # future. This is a workaround for Keras' caching sessions in a way that
    # isn't compatible with TFF. This is already fixed in TF master, but not as
    # of v1.13.1.
    #
    # We do not use .clear_session() because it blows away the graph stack by
    # resetting the default graph.
    tf.keras.backend.set_session(None)

    if hasattr(dummy_batch, '_asdict'):
      dummy_batch = dummy_batch._asdict()
    # Convert input to tensors, possibly from nested lists that need to be
    # converted to a single top-level tensor.
    dummy_tensors = collections.OrderedDict([
        (k, tf.convert_to_tensor_or_sparse_tensor(v))
        for k, v in six.iteritems(dummy_batch)
    ])
    # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
    # variables until they are called on input. We force that here.
    inner_model(dummy_tensors['x'])

    def _tensor_spec_with_undefined_batch_dim(tensor):
      # Remove the batch dimension and leave it unspecified.
      spec = tf.TensorSpec(
          shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)
      return spec

    self._input_spec = nest.map_structure(_tensor_spec_with_undefined_batch_dim,
                                          dummy_tensors)

    self._keras_model = inner_model
    self._loss_fn = loss_fn
    self._metrics = metrics if metrics is not None else []

    # This is defined here so that it closes over the `loss_fn`.
    class _WeightedMeanLossMetric(keras_metrics.Metric):
      """A `tf.keras.metrics.Metric` wrapper for the loss function."""

      def __init__(self, name='loss', dtype=tf.float32):
        super(_WeightedMeanLossMetric, self).__init__(name, dtype)
        self._total_loss = self.add_weight('total_loss', initializer='zeros')
        self._total_weight = self.add_weight(
            'total_weight', initializer='zeros')
        self._loss_fn = loss_fn

      def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)

        # `_loss_fn` is expected to return the scalar mean loss, so we multiply
        # by the batch size to recover the total loss.
        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        batch_total_loss = self._loss_fn(y_true, y_pred) * batch_size

        op = self._total_loss.assign_add(batch_total_loss)
        with tf.control_dependencies([op]):
          return self._total_weight.assign_add(batch_size)

      def result(self):
        return tf.div_no_nan(self._total_loss, self._total_weight)

    self._loss_metric = _WeightedMeanLossMetric()

    metric_variable_type_dict = nest.map_structure(tf.TensorSpec.from_tensor,
                                                   self.report_local_outputs())
    federated_local_outputs_type = tff.FederatedType(
        metric_variable_type_dict, tff.CLIENTS, all_equal=False)

    def federated_output(local_outputs):
      results = collections.OrderedDict()
      for metric, variables in zip(self.get_metrics(), local_outputs):
        results[metric.name] = federated_aggregate_keras_metric(
            type(metric), metric.get_config(), variables)
      return results

    self._federated_output_computation = tff.federated_computation(
        federated_output, federated_local_outputs_type)

    # Keras creates variables that are not added to any collection, making it
    # impossible for TFF to extract them and create the appropriate initializer
    # before calling a tff.Computation. Here we store them in a TFF-specific
    # collection so that they can be retrieved later.
    # TODO(b/122081673): this likely goes away in TF2.0
    for variable in itertools.chain(self.trainable_variables,
                                    self.non_trainable_variables,
                                    self.local_variables):
      tf.add_to_collection(graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE,
                           variable)
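
Note: `federated_aggregate_keras_metric` itself is not shown in these
snippets. Below is a rough sketch of what the (type, config, variables) form
used in Example #2 might look like; it assumes metric state combines correctly
by summation (true for `Sum`- and `Mean`-style metrics, not for every metric),
so treat it as an illustration rather than the actual TFF implementation:

import tensorflow as tf
import tensorflow_federated as tff

def federated_aggregate_keras_metric(metric_type, metric_config,
                                     federated_variables):
  member_type = federated_variables.type_signature.member

  @tff.tf_computation(member_type)
  def report(summed_variables):
    # Rebuild the metric from its config, load the summed state into its
    # variables, and emit the final result.
    metric = metric_type.from_config(metric_config)
    assignments = [
        v.assign(s)
        for v, s in zip(metric.variables, tf.nest.flatten(summed_variables))
    ]
    with tf.control_dependencies(assignments):
      return metric.result()

  summed = tff.federated_sum(federated_variables)
  # Older TFF releases used `tff.federated_apply` for SERVER-placed values.
  return tff.federated_map(report, summed)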
Example #3
    def __init__(self,
                 inner_model,
                 dummy_batch,
                 loss_fns,
                 loss_weights=None,
                 metrics=None):

        # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
        # variables until they are called on input. We force that here.
        if isinstance(dummy_batch, collections.abc.Mapping):
            inner_model(dummy_batch['x'])
        else:
            inner_model(dummy_batch[0])

        def _tensor_spec_with_undefined_batch_dim(tensor):
            # Remove the batch dimension and leave it unspecified.
            spec = tf.TensorSpec(shape=[None] + tensor.shape.dims[1:],
                                 dtype=tensor.dtype)
            return spec

        self._input_spec = tf.nest.map_structure(
            _tensor_spec_with_undefined_batch_dim, dummy_batch)

        self._keras_model = inner_model
        self._loss_fns = loss_fns

        if isinstance(loss_weights, collections.abc.Mapping):
            self._loss_weights = []
            for name in inner_model.output_names:
                if name not in loss_weights:
                    raise KeyError(
                        'Output missing from loss_weights dictionary'
                        '\nloss_weights: {}\noutputs: {}'.format(
                            list(loss_weights.keys()),
                            inner_model.output_names))
                else:
                    self._loss_weights.append(loss_weights[name])
        else:
            if loss_weights is None:
                self._loss_weights = [1.0 for _ in range(len(loss_fns))]
            else:
                self._loss_weights = loss_weights

        loss_weights = self._loss_weights
        self._metrics = metrics if metrics is not None else []

        # This is defined here so that it closes over `loss_fns` and
        # `loss_weights`.
        class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
            """A `tf.keras.metrics.Metric` wrapper for the loss function."""
            def __init__(self, name='loss', dtype=tf.float32):
                super(_WeightedMeanLossMetric, self).__init__(name, dtype)
                self._loss_fns = loss_fns
                self._loss_weights = loss_weights

            def update_state(self, y_true, y_pred, sample_weight=None):
                if len(self._loss_fns) == 1:
                    batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
                    y_true = tf.cast(y_true, self._dtype)
                    y_pred = tf.cast(y_pred, self._dtype)
                    batch_loss = self._loss_fns[0](y_true, y_pred)

                else:
                    batch_loss = tf.zeros(())
                    for i in range(len(self._loss_fns)):
                        y_t = tf.cast(y_true[i], self._dtype)
                        y_p = tf.cast(y_pred[i], self._dtype)
                        batch_loss += (self._loss_weights[i] *
                                       self._loss_fns[i](y_t, y_p))

                    batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

                return super(_WeightedMeanLossMetric,
                             self).update_state(batch_loss, batch_size)

        class _TrainingTimeHistory(tf.keras.metrics.Sum):
            def update_state(self, y_true, y_pred, sample_weight=None):
                pass

            def log_time(self, time_value):
                return super(_TrainingTimeHistory,
                             self).update_state(values=time_value)

        self._loss_metric = _WeightedMeanLossMetric()
        self._training_timing = _TrainingTimeHistory(name='training_time_sec')

        metric_variable_type_dict = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, self.report_local_outputs())
        federated_local_outputs_type = tff.FederatedType(
            metric_variable_type_dict, tff.CLIENTS)

        def federated_output(local_outputs):
            results = collections.OrderedDict()
            for metric, variables in zip(self.get_metrics(), local_outputs):
                results[metric.name] = federated_aggregate_keras_metric(
                    type(metric), metric.get_config(), variables)
            return results

        self._federated_output_computation = tff.federated_computation(
            federated_output, federated_local_outputs_type)

        # Keras creates variables that are not added to any collection, making it
        # impossible for TFF to extract them and create the appropriate initializer
        # before calling a tff.Computation. Here we store them in a TFF-specific
        # collection so that they can be retrieved later.
        # TODO(b/122081673): this likely goes away in TF2.0
        for variable in itertools.chain(self.trainable_variables,
                                        self.non_trainable_variables,
                                        self.local_variables):
            tf.compat.v1.add_to_collection(
                graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
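
Note: Example #3's `_TrainingTimeHistory` is only ever fed through
`log_time`. Below is a hedged, eager-mode sketch of a training step that uses
it; `train_on_batch` and `forward_pass` are illustrative names, not part of
the snippet:

import time
import tensorflow as tf

def train_on_batch(self, batch_input):
  start = time.time()
  output = self.forward_pass(batch_input, training=True)
  # Accumulate wall-clock seconds into the Sum-backed metric.
  self._training_timing.log_time(tf.constant(time.time() - start))
  return output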
Example #4
    def federated_output_computation(self):
        return tff.federated_computation(
            lambda metrics: {'num_over': tff.federated_sum(metrics.num_over)})
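
Note: for the `num_over` aggregation above to type-check, the model's local
outputs must expose a matching field. A hypothetical sketch (`self._num_over`
is an assumed local counter variable, not shown in the snippet):

import collections

def report_local_outputs(self):
  return collections.OrderedDict(num_over=self._num_over.read_value())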
Example #5
  def __init__(self, inner_model, dummy_batch, loss_fn, metrics):

    # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
    # variables until they are called on input. We force that here.
    inner_model(dummy_batch['x'])

    def _tensor_spec_with_undefined_batch_dim(tensor):
      # Remove the batch dimension and leave it unspecified.
      spec = tf.TensorSpec(
          shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)
      return spec

    self._input_spec = tf.nest.map_structure(
        _tensor_spec_with_undefined_batch_dim, dummy_batch)

    self._keras_model = inner_model
    self._loss_fn = loss_fn
    self._metrics = metrics if metrics is not None else []

    # This is defined here so that it closes over the `loss_fn`.
    class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
      """A `tf.keras.metrics.Metric` wrapper for the loss function."""

      def __init__(self, name='loss', dtype=tf.float32):
        super(_WeightedMeanLossMetric, self).__init__(name, dtype)
        self._loss_fn = loss_fn

      def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)

        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        batch_loss = self._loss_fn(y_true, y_pred)

        return super(_WeightedMeanLossMetric,
                     self).update_state(batch_loss, batch_size)

    self._loss_metric = _WeightedMeanLossMetric()

    metric_variable_type_dict = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, self.report_local_outputs())
    federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                     tff.CLIENTS)

    def federated_output(local_outputs):
      results = collections.OrderedDict()
      for metric, variables in zip(self.get_metrics(), local_outputs):
        results[metric.name] = federated_aggregate_keras_metric(
            type(metric), metric.get_config(), variables)
      return results

    self._federated_output_computation = tff.federated_computation(
        federated_output, federated_local_outputs_type)

    # Keras creates variables that are not added to any collection, making it
    # impossible for TFF to extract them and create the appropriate initializer
    # before calling a tff.Computation. Here we store them in a TFF-specific
    # collection so that they can be retrieved later.
    # TODO(b/122081673): this likely goes away in TF2.0
    for variable in itertools.chain(self.trainable_variables,
                                    self.non_trainable_variables,
                                    self.local_variables):
      tf.add_to_collection(graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE,
                           variable)
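
Note: to illustrate the batch-dimension stripping used in Examples #2, #3 and
#5, with assumed shapes:

import collections
import tensorflow as tf

dummy_batch = collections.OrderedDict(
    x=tf.zeros([5, 784], tf.float32),
    y=tf.zeros([5, 1], tf.int64))
# Mapping _tensor_spec_with_undefined_batch_dim over this batch yields:
#   x -> tf.TensorSpec(shape=[None, 784], dtype=tf.float32)
#   y -> tf.TensorSpec(shape=[None, 1], dtype=tf.int64)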
Example #6
    def __init__(self,
                 inner_model,
                 input_spec,
                 loss_fns,
                 loss_weights=None,
                 metrics=None):
        self._input_spec = input_spec

        if not loss_fns:
            raise ValueError(
                'Must specify at least one loss function in loss_fns, '
                'got: {l}'.format(l=loss_fns))
        # Exactly one loss_fn is allowed only when the model has a single
        # (non-list) output; otherwise loss_fns must match the outputs
        # one-for-one.
        if (bool(len(loss_fns) == 1) != tf.is_tensor(inner_model.output)
                or (isinstance(inner_model.output, list)
                    and len(loss_fns) != len(inner_model.output))):
            raise ValueError(
                'Must specify the same number of loss_fns as model '
                'outputs.\nloss_fns: {l}\nmodel outputs: {o}'.format(
                    l=loss_fns, o=inner_model.output))
        self._loss_fns = loss_fns

        if loss_weights is None:
            loss_weights = [1.0] * len(loss_fns)
        else:
            py_typecheck.check_type(loss_weights, collections.abc.Sequence)
            if len(loss_weights) != len(loss_fns):
                raise ValueError(
                    'Must specify the same number of '
                    'loss_weights (got {llw}) as loss_fns (got {llf}).\n'
                    'loss_weights: {lw}\nloss_fns: {lf}'.format(
                        lw=loss_weights,
                        llw=len(loss_weights),
                        lf=loss_fns,
                        llf=len(loss_fns)))
        self._loss_weights = loss_weights
        self._keras_model = inner_model
        self._metrics = metrics if metrics is not None else []

        # This is defined here so that it closes over `loss_fns` and
        # `loss_weights`.
        class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
            """A `tf.keras.metrics.Metric` wrapper for the loss function."""
            def __init__(self, name='loss', dtype=tf.float32):
                super().__init__(name, dtype)
                self._loss_fns = loss_fns
                self._loss_weights = loss_weights

            def update_state(self, y_true, y_pred, sample_weight=None):
                if len(self._loss_fns) == 1:
                    batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
                    y_true = tf.cast(y_true, self._dtype)
                    y_pred = tf.cast(y_pred, self._dtype)
                    batch_loss = self._loss_fns[0](y_true, y_pred)

                else:
                    batch_loss = tf.zeros(())
                    for i in range(len(self._loss_fns)):
                        y_t = tf.cast(y_true[i], self._dtype)
                        y_p = tf.cast(y_pred[i], self._dtype)
                        batch_loss += (self._loss_weights[i] *
                                       self._loss_fns[i](y_t, y_p))

                    batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

                return super().update_state(batch_loss, batch_size)

        self._loss_metric = _WeightedMeanLossMetric()

        metric_variable_type_dict = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, self.report_local_outputs())
        federated_local_outputs_type = tff.FederatedType(
            metric_variable_type_dict, tff.CLIENTS)

        def federated_output(local_outputs):
            results = collections.OrderedDict()
            for metric, variables in zip(self.get_metrics(), local_outputs):
                results[metric.name] = federated_aggregate_keras_metric(
                    type(metric), metric.get_config(), variables)
            return results

        self._federated_output_computation = tff.federated_computation(
            federated_output, federated_local_outputs_type)

        # Keras creates variables that are not added to any collection, making it
        # impossible for TFF to extract them and create the appropriate initializer
        # before calling a tff.Computation. Here we store them in a TFF-specific
        # collection so that they can be retrieved later.
        # TODO(b/122081673): this likely goes away in TF2.0
        for variable in itertools.chain(self.trainable_variables,
                                        self.non_trainable_variables,
                                        self.local_variables):
            tf.compat.v1.add_to_collection(
                graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
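
Note: a hedged usage sketch for Example #6's multi-output validation; the
model and the enclosing class name `_KerasModel` are illustrative assumptions:

import collections
import tensorflow as tf

inputs = tf.keras.Input(shape=(16,))
out_a = tf.keras.layers.Dense(1, name='a')(inputs)
out_b = tf.keras.layers.Dense(1, name='b')(inputs)
multi_output_model = tf.keras.Model(inputs=inputs, outputs=[out_a, out_b])

input_spec = collections.OrderedDict(
    x=tf.TensorSpec([None, 16], tf.float32),
    y=(tf.TensorSpec([None, 1], tf.float32),
       tf.TensorSpec([None, 1], tf.float32)))
# Two outputs, so exactly two loss_fns; loss_weights must match one-for-one.
wrapped = _KerasModel(
    multi_output_model,
    input_spec=input_spec,
    loss_fns=[tf.keras.losses.MeanSquaredError(),
              tf.keras.losses.MeanAbsoluteError()],
    loss_weights=[0.7, 0.3])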
Example #7
    def federated_output_computation(self):
        return tff_core.federated_computation(lambda x: x)
Example #8
    def __init__(self,
                 inner_model,
                 input_spec,
                 loss_fns,
                 loss_weights=None,
                 metrics=None):
        self._input_spec = input_spec

        if not loss_fns:
            raise ValueError(
                'Must specify at least one loss function in loss_fns, '
                'got: {l}'.format(l=loss_fns))
        if (len(tf.nest.flatten(loss_fns)) != len(
                tf.nest.flatten(inner_model.output))):
            raise ValueError(
                'Must specify the same number of loss_fns as model '
                'outputs.\nloss_fns: {l}\nmodel outputs: {o}'.format(
                    l=loss_fns, o=inner_model.output))
        self._loss_fns = loss_fns

        if loss_weights is None:
            loss_weights = [1.0] * len(loss_fns)
        else:
            py_typecheck.check_type(loss_weights, collections.abc.Sequence)
            if len(loss_weights) != len(loss_fns):
                raise ValueError(
                    'Must specify the same number of '
                    'loss_weights (got {llw}) as loss_fns (got {llf}).\n'
                    'loss_weights: {lw}\nloss_fns: {lf}'.format(
                        lw=loss_weights,
                        llw=len(loss_weights),
                        lf=loss_fns,
                        llf=len(loss_fns)))
        self._loss_weights = loss_weights
        self._keras_model = inner_model
        self._metrics = metrics if metrics is not None else []

        # This is defined here so that it closes over `loss_fns` and
        # `loss_weights`.
        class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
            """A `tf.keras.metrics.Metric` wrapper for the loss function."""
            def __init__(self, name='loss', dtype=tf.float32):
                super().__init__(name, dtype)
                self._loss_fns = loss_fns
                self._loss_weights = loss_weights

            def update_state(self, y_true, y_pred, sample_weight=None):
                if len(self._loss_fns) == 1:
                    batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
                    y_true = tf.cast(y_true, self._dtype)
                    y_pred = tf.cast(y_pred, self._dtype)
                    batch_loss = self._loss_fns[0](y_true, y_pred)

                else:
                    batch_loss = tf.zeros(())
                    for i in range(len(self._loss_fns)):
                        y_t = tf.cast(y_true[i], self._dtype)
                        y_p = tf.cast(y_pred[i], self._dtype)
                        batch_loss += (self._loss_weights[i] *
                                       self._loss_fns[i](y_t, y_p))

                    batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

                return super().update_state(batch_loss, batch_size)

        self._loss_metric = _WeightedMeanLossMetric()

        metric_variable_type_dict = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, self.report_local_outputs())
        federated_local_outputs_type = tff.FederatedType(
            metric_variable_type_dict, tff.CLIENTS)

        def federated_output(local_outputs):
            return federated_aggregate_keras_metric(self.get_metrics(),
                                                    local_outputs)

        self._federated_output_computation = tff.federated_computation(
            federated_output, federated_local_outputs_type)
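
Note: once constructed, the stored computation aggregates per-client metric
state. A hedged invocation sketch (in TFF, a CLIENTS-placed argument is
supplied as a Python list with one element per client; `wrapped` is as in the
usage sketches above):

client_outputs = [wrapped.report_local_outputs() for _ in range(3)]
aggregated_metrics = wrapped._federated_output_computation(client_outputs)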