def test_returns_model_weights_for_model_callable(self):
    """Checks the weights type computed from the `TestModel` callable.

    The expected type is a named tuple with a `trainable` entry holding two
    float32 tensors (shapes [3] and [1]) and a `non_trainable` entry holding
    a single int32 tensor.
    """
    actual_type = model_utils.weights_type_from_model(TestModel)

    trainable_types = [
        tff_core.TensorType(tf.float32, [3]),
        tff_core.TensorType(tf.float32, [1]),
    ]
    non_trainable_types = [
        tff_core.TensorType(tf.int32),
    ]
    expected_type = tff_core.NamedTupleType([
        ('trainable', trainable_types),
        ('non_trainable', non_trainable_types),
    ])

    self.assertEqual(expected_type, actual_type)
Exemplo n.º 2
0
    def benchmark_fc_api_mnist(self):
        """Benchmark the FC API MNIST training loop (adapted from the tutorial).

        Two results are reported through `self.report_benchmark`: the wall
        time taken to build the TFF computations, and the mean wall time of a
        `federated_train` round over 10 rounds (the standard deviation is
        passed along as an extra).
        """
        num_rounds = 10

        # Type signatures for one batch, the model, and a sequence of batches.
        batch_type = tff.NamedTupleType([
            ("x", tff.TensorType(tf.float32, [None, 784])),
            ("y", tff.TensorType(tf.int32, [None])),
        ])
        model_type = tff.NamedTupleType([
            ("weights", tff.TensorType(tf.float32, [784, 10])),
            ("bias", tff.TensorType(tf.float32, [10])),
        ])
        local_data_type = tff.SequenceType(batch_type)

        # Federated placements for the model, the client data and the
        # learning rate.
        server_model_type = tff.FederatedType(
            model_type, tff.SERVER, all_equal=True)
        client_data_type = tff.FederatedType(local_data_type, tff.CLIENTS)
        server_float_type = tff.FederatedType(
            tf.float32, tff.SERVER, all_equal=True)

        # Everything from here up to the first report is counted as
        # computation-building time.
        build_started_at = time.time()

        # pylint: disable=missing-docstring
        @tff.tf_computation(model_type, batch_type)
        def batch_loss(model, batch):
            logits = tf.matmul(batch.x, model.weights) + model.bias
            predicted_y = tf.nn.softmax(logits)
            one_hot_labels = tf.one_hot(batch.y, 10)
            per_example = tf.reduce_sum(
                one_hot_labels * tf.log(predicted_y), reduction_indices=[1])
            return -tf.reduce_mean(per_example)

        zero_model = {
            "weights": np.zeros([784, 10], dtype=np.float32),
            "bias": np.zeros([10], dtype=np.float32),
        }

        @tff.tf_computation(model_type, batch_type, tf.float32)
        def batch_train(initial_model, batch, learning_rate):
            model_vars = tff.utils.get_variables("v", model_type)
            assign_initial = tff.utils.assign(model_vars, initial_model)

            sgd = tf.train.GradientDescentOptimizer(learning_rate)
            # Build the training step under a dependency on the assignment of
            # the incoming model to the variables.
            with tf.control_dependencies([assign_initial]):
                train_step = sgd.minimize(batch_loss(model_vars, batch))

            # Return the variables under a dependency on the training step.
            with tf.control_dependencies([train_step]):
                return tff.utils.identity(model_vars)

        @tff.federated_computation(model_type, tf.float32, local_data_type)
        def local_train(initial_model, learning_rate, all_batches):
            # `batch_fn` closes over `learning_rate` so that it can be used
            # as the two-argument reduction operator below.
            @tff.federated_computation(model_type, batch_type)
            def batch_fn(model, batch):
                return batch_train(model, batch, learning_rate)

            return tff.sequence_reduce(all_batches, initial_model, batch_fn)

        @tff.federated_computation(
            server_model_type, server_float_type, client_data_type)
        def federated_train(model, learning_rate, data):
            client_models = tff.federated_map(local_train, [
                tff.federated_broadcast(model),
                tff.federated_broadcast(learning_rate),
                data,
            ])
            return tff.federated_average(client_models)

        building_time = time.time() - build_started_at
        self.report_benchmark(
            name="computation_building_time, FC API",
            wall_time=building_time,
            iters=1)

        model = zero_model
        learning_rate = 0.1
        federated_data = generate_fake_mnist_data()

        # Time each round separately; the model from one round feeds the next.
        round_durations = []
        for _ in range(num_rounds):
            round_started_at = time.time()
            model = federated_train(model, learning_rate, federated_data)
            round_durations.append(time.time() - round_started_at)

        self.report_benchmark(
            name="Average per round execution time, FC API",
            wall_time=np.mean(round_durations),
            iters=num_rounds,
            extras={"std_dev": np.std(round_durations)})
Exemplo n.º 3
0
  Args:
    process: A measured process to validate.

  Returns:
    `True` iff the process is a valid aggregation process, otherwise `False`.
  """
  next_type = process.next.type_signature
  return (isinstance(process, tff.templates.MeasuredProcess) and
          _is_valid_stateful_process(process) and
          next_type.parameter[1].placement is tff.CLIENTS and
          next_type.result.result.placement is tff.SERVER)


# ============================================================================

# Federated type of an empty named tuple placed at `tff.SERVER`.
# NOTE(review): presumably used as the "no value" server-side type when
# wrapping stateful functions — confirm against its users.
NONE_SERVER_TYPE = tff.FederatedType(tff.NamedTupleType([]), tff.SERVER)


def _wrap_in_measured_process(
    stateful_fn: Union[tff.utils.StatefulBroadcastFn,
                       tff.utils.StatefulAggregateFn],
    input_type: tff.Type) -> tff.templates.MeasuredProcess:
  """Converts a `tff.utils.StatefulFn` to a `tff.templates.MeasuredProcess`."""
  py_typecheck.check_type(
      stateful_fn,
      (tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn))

  @tff.federated_computation()
  def initialize_comp():
    if not isinstance(stateful_fn.initialize, tff.Computation):
      initialize = tff.tf_computation(stateful_fn.initialize)
 def input_spec(self):
     """Returns a two-element unnamed `NamedTupleType` describing the input.

     Both elements are float32 specs, with shapes [3] and [1] respectively.
     """
     # NOTE(review): other code in this file builds tensor types with
     # `tff_core.TensorType`; confirm that `tff_core.TensorSpec` exists and
     # takes (dtype, shape) in this order.
     first_spec = tff_core.TensorSpec(tf.float32, [3])
     second_spec = tff_core.TensorSpec(tf.float32, [1])
     return tff_core.NamedTupleType((first_spec, second_spec))