def test_execution_stateful_optimizer(self):
        """Checks one `next` call of model-delta client work with a momentum optimizer.

        Both clients start from the same all-zero weights. The second client's
        dataset repeats the single batch twice, so it sees 4 examples, not 2.
        """
        optimizer = sgdm.build_sgdm(0.1, momentum=0.9)
        process = client_works.build_model_delta_client_work(
            model_examples.LinearRegression, optimizer)

        examples = collections.OrderedDict(
            x=[[1.0, 2.0], [3.0, 4.0]],
            y=[[5.0], [6.0]],
        )
        one_batch = tf.data.Dataset.from_tensor_slices(examples).batch(2)
        client_datasets = [one_batch, one_batch.repeat(2)]

        zero_weights = model_utils.ModelWeights(
            trainable=[[[0.0], [0.0]], 0.0], non_trainable=[0.0])
        per_client_weights = [zero_weights] * 2

        initial_state = process.initialize()
        output = process.next(initial_state, per_client_weights,
                              client_datasets)

        expected_per_client = (
            client_works.ClientResult([[[-1.15], [-1.7]], -0.55], 2.0),
            client_works.ClientResult([[[-1.46], [-2.26]], -0.8], 4.0),
        )

        self.assertEqual((), output.state)
        # Index `output.result` explicitly so a missing client result raises
        # instead of being silently skipped.
        for client_index, expected in enumerate(expected_per_client):
            actual = output.result[client_index]
            self.assertAllClose(expected.update, actual.update)
            self.assertAllClose(expected.update_weight, actual.update_weight)
        self.assertEqual((), output.measurements)
Exemplo n.º 2
0
 def initial_weights(self):
   """Returns `ModelWeights` whose dict entries 'a', 'b' and 'c' are all zero."""
   zero_trainable = {
       'a': tf.constant([[0.0], [0.0]]),
       'b': tf.constant(0.0),
   }
   zero_non_trainable = {'c': 0.0}
   return model_utils.ModelWeights(
       trainable=zero_trainable, non_trainable=zero_non_trainable)
Exemplo n.º 3
0
 def initial_weights(self):
   """Returns `ModelWeights` holding list-structured, all-zero weights."""
   zero_matrix = tf.constant([[0.0], [0.0]])
   zero_scalar = tf.constant(0.0)
   return model_utils.ModelWeights(
       trainable=[zero_matrix, zero_scalar], non_trainable=[0.0])
Exemplo n.º 4
0
 def next_fn(state, weights, update):
     """Adds `update` to the trainable weights and returns a measured output.

     The state is passed through unchanged, the non-trainable part of the
     resulting weights is an empty tuple, and the measurements come from
     `server_zero()`.
     """
     summed_trainable = federated_add(weights['trainable'], update)
     zipped_weights = intrinsics.federated_zip(
         model_utils.ModelWeights(summed_trainable, ()))
     measurements = server_zero()
     return MeasuredProcessOutput(state, zipped_weights, measurements)
Exemplo n.º 5
0
    def test_type_properties(self):
        """Checks the `initialize`/`next` signatures of the optimizer finalizer."""
        weights_type = computation_types.to_type(
            model_utils.ModelWeights(
                trainable=(tf.float32, tf.float32), non_trainable=tf.float32))

        finalizer = finalizers.build_apply_optimizer_finalizer(
            sgdm.build_sgdm(1.0), weights_type)
        self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

        weights_at_server = computation_types.at_server(weights_type)
        update_at_server = computation_types.at_server(weights_type.trainable)
        result_at_server = computation_types.at_server(weights_type)
        # Both the state and the measurements are expected to be empty.
        state_at_server = computation_types.at_server(())
        measurements_at_server = computation_types.at_server(())

        initialize_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        initialize_signature.check_equivalent_to(
            finalizer.initialize.type_signature)

        next_parameter = collections.OrderedDict(
            state=state_at_server,
            weights=weights_at_server,
            update=update_at_server)
        next_result = MeasuredProcessOutput(
            state_at_server, result_at_server, measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_signature.check_equivalent_to(finalizer.next.type_signature)
Exemplo n.º 6
0
def state_with_new_model_weights(
    server_state: ServerState,
    trainable_weights: List[np.ndarray],
    non_trainable_weights: List[np.ndarray],
) -> ServerState:
    """Builds a `ServerState` that carries replacement model weights.

    The replacement weights are first validated against the weights held in
    `server_state.model`. All other fields of the state (`optimizer_state`,
    `delta_aggregate_state`, `model_broadcast_state`) are copied unchanged.

    Args:
      server_state: A server state object returned by an iterative training
        process like `tff.learning.build_federated_averaging_process`.
      trainable_weights: A list of `numpy` arrays in the order of the original
        model's `trainable_variables`.
      non_trainable_weights: A list of `numpy` arrays in the order of the
        original model's `non_trainable_variables`.

    Returns:
      A new `ServerState` object which can be passed to the `next` method of
      the iterative process.

    Raises:
      TypeError: If a leaf's dtype or shape differs from the current weights,
        if two sequences have different lengths, or if the structures contain
        types other than the supported leaves and sequences.
    """
    py_typecheck.check_type(server_state, ServerState)
    leaf_types = (int, float, np.ndarray, tf.Tensor)
    sequence_type = collections.abc.Sequence

    def assert_weight_lists_match(old_value, new_value):
        """Recursively checks that `new_value` mirrors `old_value`."""
        if isinstance(new_value, leaf_types) and isinstance(
                old_value, leaf_types):
            # NOTE(review): plain Python `int`/`float` pass the isinstance
            # check above but have no `.dtype`/`.shape` attributes, so they
            # would raise AttributeError here - confirm callers only pass
            # arrays or tensors as leaves.
            mismatch = (old_value.dtype != new_value.dtype
                        or old_value.shape != new_value.shape)
            if mismatch:
                raise TypeError('Element is not the same tensor type. old '
                                f'({old_value.dtype}, {old_value.shape}) != '
                                f'new ({new_value.dtype}, {new_value.shape})')
            return

        both_sequences = (isinstance(new_value, sequence_type)
                          and isinstance(old_value, sequence_type))
        if not both_sequences:
            raise TypeError(
                'Model weights structures contains types that cannot be '
                'handled.\nOld weights structure: {old}\n'
                'New weights structure: {new}\n'
                'Must be one of (int, float, np.ndarray, tf.Tensor, '
                'collections.abc.Sequence)'.format(
                    old=tf.nest.map_structure(type, old_value),
                    new=tf.nest.map_structure(type, new_value)))

        if len(old_value) != len(new_value):
            raise TypeError(
                'Model weights have different lengths: '
                f'(old) {len(old_value)} != (new) {len(new_value)})\n'
                f'Old values: {old_value}\nNew values: {new_value}')
        for old_item, new_item in zip(old_value, new_value):
            assert_weight_lists_match(old_item, new_item)

    assert_weight_lists_match(server_state.model.trainable, trainable_weights)
    assert_weight_lists_match(server_state.model.non_trainable,
                              non_trainable_weights)

    replacement_model = model_utils.ModelWeights(
        trainable=trainable_weights, non_trainable=non_trainable_weights)
    return ServerState(
        model=replacement_model,
        optimizer_state=server_state.optimizer_state,
        delta_aggregate_state=server_state.delta_aggregate_state,
        model_broadcast_state=server_state.model_broadcast_state)
Exemplo n.º 7
0
 def initial_weights(self):
     """Returns all-zero `ModelWeights` keyed by 'a'/'b' (trainable) and 'c'."""
     trainable = collections.OrderedDict(
         a=tf.constant([[0.0], [0.0]]),
         b=tf.constant(0.0),
     )
     non_trainable = collections.OrderedDict(c=0.0)
     return model_utils.ModelWeights(
         trainable=trainable, non_trainable=non_trainable)
Exemplo n.º 8
0
 def next_fn(state, weights, updates):
     """Adds `updates` to the trainable weights via a mapped TF computation.

     The state is returned unchanged, the non-trainable part of the new
     weights is an empty tuple, and the measurements come from
     `empty_at_server()`.
     """
     add_fn = tensorflow_computation.tf_computation(lambda x, y: x + y)
     summed_trainable = intrinsics.federated_map(
         add_fn, (weights.trainable, updates))
     zipped_weights = intrinsics.federated_zip(
         model_utils.ModelWeights(summed_trainable, ()))
     return measured_process.MeasuredProcessOutput(
         state, zipped_weights, empty_at_server())
    def test_construction(self, weighted):
        """Checks the type signatures of the model-delta optimizer process.

        Args:
          weighted: If true, `mean.MeanFactory` is used and the aggregator
            state and metrics are expected to be two-entry ordered dicts;
            otherwise `sum_factory.SumFactory` is used and both are empty.
        """
        if weighted:
            aggregation_factory = mean.MeanFactory()
            aggregate_state = collections.OrderedDict(
                value_sum_process=(), weight_sum_process=())
            aggregate_metrics = collections.OrderedDict(
                mean_value=(), mean_weight=())
        else:
            aggregation_factory = sum_factory.SumFactory()
            aggregate_state = ()
            aggregate_metrics = ()

        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            model_update_aggregation_factory=aggregation_factory)

        model_type = model_utils.ModelWeights(
            trainable=[
                computation_types.TensorType(tf.float32, [2, 1]),
                computation_types.TensorType(tf.float32)
            ],
            non_trainable=[computation_types.TensorType(tf.float32)])
        server_state_type = computation_types.FederatedType(
            optimizer_utils.ServerState(
                model=model_type,
                optimizer_state=[tf.int64],
                delta_aggregate_state=aggregate_state,
                model_broadcast_state=()),
            placements.SERVER)
        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=server_state_type)
        self.assert_types_equivalent(
            expected_initialize_type,
            iterative_process.initialize.type_signature)

        batch_type = collections.OrderedDict(
            x=computation_types.TensorType(tf.float32, [None, 2]),
            y=computation_types.TensorType(tf.float32, [None, 1]))
        dataset_type = computation_types.FederatedType(
            computation_types.SequenceType(batch_type), placements.CLIENTS)

        train_metrics = collections.OrderedDict(
            loss=computation_types.TensorType(tf.float32),
            num_examples=computation_types.TensorType(tf.int32))
        stat_metrics = collections.OrderedDict(
            num_examples=computation_types.TensorType(tf.float32))
        metrics_type = computation_types.FederatedType(
            collections.OrderedDict(
                broadcast=(),
                aggregation=aggregate_metrics,
                train=train_metrics,
                stat=stat_metrics),
            placements.SERVER)

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                server_state=server_state_type,
                federated_dataset=dataset_type,
            ),
            result=(server_state_type, metrics_type))
        self.assert_types_equivalent(
            expected_next_type, iterative_process.next.type_signature)
Exemplo n.º 10
0
class ApplyOptimizerFinalizerComputationTest(tf.test.TestCase,
                                             parameterized.TestCase):
  """Construction-time checks for `finalizers.build_apply_optimizer_finalizer`."""

  def test_type_properties(self):
    """Checks the `initialize`/`next` signatures of the built finalizer."""
    weights_type = computation_types.to_type(
        model_utils.ModelWeights(
            trainable=(tf.float32, tf.float32), non_trainable=tf.float32))

    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), weights_type)
    self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

    weights_at_server = computation_types.at_server(weights_type)
    update_at_server = computation_types.at_server(weights_type.trainable)
    result_at_server = computation_types.at_server(weights_type)
    # The expected state is a single float32 entry under the learning-rate key.
    optimizer_state_type = computation_types.to_type(
        collections.OrderedDict([(optimizer_base.LEARNING_RATE_KEY,
                                  tf.float32)]))
    state_at_server = computation_types.at_server(optimizer_state_type)
    measurements_at_server = computation_types.at_server(())

    initialize_signature = computation_types.FunctionType(
        parameter=None, result=state_at_server)
    initialize_signature.check_equivalent_to(
        finalizer.initialize.type_signature)

    next_parameter = collections.OrderedDict(
        state=state_at_server,
        weights=weights_at_server,
        update=update_at_server)
    next_result = MeasuredProcessOutput(state_at_server, result_at_server,
                                        measurements_at_server)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_signature.check_equivalent_to(finalizer.next.type_signature)

  @parameterized.named_parameters(
      ('not_struct', computation_types.TensorType(tf.float32)),
      ('federated_type', MODEL_WEIGHTS_TYPE),
      ('model_weights_of_federated_types',
       computation_types.to_type(
           model_utils.ModelWeights(SERVER_FLOAT, SERVER_FLOAT))),
      ('not_model_weights', computation_types.to_type(
          (tf.float32, tf.float32))),
      ('function_type', computation_types.FunctionType(None,
                                                       MODEL_WEIGHTS_TYPE)),
      ('sequence_type', computation_types.SequenceType(
          MODEL_WEIGHTS_TYPE.member)))
  def test_incorrect_value_type_raises(self, bad_type):
    """Each parameterized `bad_type` must be rejected with a `TypeError`."""
    optimizer = sgdm.build_sgdm(1.0)
    with self.assertRaises(TypeError):
      finalizers.build_apply_optimizer_finalizer(optimizer, bad_type)

  def test_unexpected_optimizer_fn_raises(self):
    """Passing a Keras optimizer instance must raise a `TypeError`."""
    keras_optimizer = tf.keras.optimizers.SGD(1.0)
    with self.assertRaises(TypeError):
      finalizers.build_apply_optimizer_finalizer(keras_optimizer,
                                                 MODEL_WEIGHTS_TYPE.member)
Exemplo n.º 11
0
    def test_execution(self):
        """Applies a single update of 0.1 to a weight of 1.0 and expects 0.9."""
        optimizer = sgdm.build_sgdm(1.0)
        finalizer = finalizers.build_apply_optimizer_finalizer(
            optimizer, MODEL_WEIGHTS_TYPE.member)

        initial_weights = model_utils.ModelWeights(1.0, ())
        update = 0.1
        initial_state = finalizer.initialize()
        output = finalizer.next(initial_state, initial_weights, update)

        self.assertEqual((), output.state)
        self.assertAllClose(0.9, output.result.trainable)
        self.assertEqual((), output.measurements)
Exemplo n.º 12
0
 def next_fn(state, weights, update):
     """Maps `next_tf` over (state, trainable weights, update) and repacks.

     Returns a measured output holding the new optimizer state, the new
     trainable weights zipped with the unchanged non-trainable weights, and
     empty measurements placed at the server.
     """
     mapped = intrinsics.federated_map(
         next_tf, (state, weights.trainable, update))
     optimizer_state, new_trainable_weights = mapped
     repacked = model_utils.ModelWeights(new_trainable_weights,
                                         weights.non_trainable)
     new_weights = intrinsics.federated_zip(repacked)
     empty_measurements = intrinsics.federated_value((), placements.SERVER)
     return measured_process.MeasuredProcessOutput(
         optimizer_state, new_weights, empty_measurements)
 def _model_fn_with_zero_weights():
     """Returns a `LinearRegression` model with every weight overwritten by zeros."""
     model = model_examples.LinearRegression()
     current = model_utils.ModelWeights.from_model(model)
     zeros = model_utils.ModelWeights(
         trainable=[tf.zeros_like(v) for v in current.trainable],
         non_trainable=[tf.zeros_like(v) for v in current.non_trainable])
     zeros.assign_weights_to(model)
     return model
Exemplo n.º 14
0
 def _model_fn_with_one_weights():
     """Returns a `LinearRegression` model with every weight overwritten by ones."""
     # Instantiate the model: `from_model` and `assign_weights_to` operate on a
     # model instance, and a model function must return an instance rather
     # than the `LinearRegression` class itself.
     linear_regression_model = model_examples.LinearRegression()
     weights = model_utils.ModelWeights.from_model(
         linear_regression_model)
     ones_trainable = [tf.ones_like(x) for x in weights.trainable]
     ones_non_trainable = [
         tf.ones_like(x) for x in weights.non_trainable
     ]
     ones_weights = model_utils.ModelWeights(
         trainable=ones_trainable, non_trainable=ones_non_trainable)
     ones_weights.assign_weights_to(linear_regression_model)
     return linear_regression_model
Exemplo n.º 15
0
def state_with_new_model_weights(server_state, trainable_weights,
                                 non_trainable_weights):
    """Builds a `ServerState` that carries replacement model weights.

    The new values are paired, in order, with the keys of the weights held in
    `server_state.model`. All other state fields are copied unchanged.

    Args:
      server_state: A server state object returned by an iterative training
        process like `tff.learning.build_federated_averaging_process`.
      trainable_weights: A list of `numpy` arrays in the order of the original
        model's `trainable_variables`.
      non_trainable_weights: A list of `numpy` arrays in the order of the
        original model's `non_trainable_variables`.

    Returns:
      A new server `ServerState` object which can be passed to the `next`
      method of the iterative process.

    Raises:
      ValueError: If the number of new values differs from the number of
        current weights, or if any dtype or shape does not match.
    """
    # TODO(b/123092620): Simplify this.
    py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)

    def pack_values(old, new_values, name):
        """Packs new_values in an OrderedDict keyed like `old`."""
        if len(old) != len(new_values):
            raise ValueError('Lengths differ for {} weights: {} vs {}'.format(
                name, len(old), len(new_values)))
        packed = collections.OrderedDict()
        named_old = anonymous_tuple.to_elements(old)
        for (key, old_value), new_value in zip(named_old, new_values):
            mismatch = (old_value.dtype != new_value.dtype
                        or old_value.shape != new_value.shape)
            if mismatch:
                raise ValueError(
                    'The shapes or dtypes do not match for {} weight {}:\n'
                    'current weights: shape {} dtype {}\n'
                    '    new weights: shape {} dtype {}'.format(
                        name, key, old_value.shape, old_value.dtype,
                        new_value.shape, new_value.dtype))
            packed[key] = new_value
        return packed

    packed_trainable = pack_values(server_state.model.trainable,
                                   trainable_weights, 'trainable')
    packed_non_trainable = pack_values(server_state.model.non_trainable,
                                       non_trainable_weights, 'non_trainable')
    renamed_new_weights = model_utils.ModelWeights(
        trainable=packed_trainable, non_trainable=packed_non_trainable)
    # TODO(b/123092620): We can't use tff.utils.update_state because this
    # is an AnonymousTuple, not a ServerState. We should do something
    # that doesn't mention every entry in the state.
    return ServerState(
        model=renamed_new_weights,
        optimizer_state=server_state.optimizer_state,
        delta_aggregate_state=server_state.delta_aggregate_state,
        model_broadcast_state=server_state.model_broadcast_state)
Exemplo n.º 16
0
  def test_non_federated_init_next_raises(self):
    """`FinalizerProcess` must reject plain TF `initialize`/`next` computations."""
    initialize_fn = tensorflow_computation.tf_computation(lambda: 0)
    weights_type = computation_types.to_type(
        model_utils.ModelWeights(tf.float32, ()))

    @tensorflow_computation.tf_computation(tf.int32, weights_type, tf.float32)
    def next_fn(state, weights, update):
      updated_weights = model_utils.ModelWeights(
          weights.trainable + update, ())
      return MeasuredProcessOutput(state, updated_weights, 0)

    with self.assertRaises(errors.TemplateNotFederatedError):
      finalizers.FinalizerProcess(initialize_fn, next_fn)
Exemplo n.º 17
0
 def test_model_weights_from_python_structure(self):
     """`from_python_structure` on an ordered dict must match direct construction."""
     trainable = [tf.constant([1., 1.])]
     non_trainable = [tf.constant(1)]
     expected = model_utils.ModelWeights(
         trainable=trainable, non_trainable=non_trainable)

     as_python_structure = collections.OrderedDict(
         trainable=trainable, non_trainable=non_trainable)
     actual = model_utils.ModelWeights.from_python_structure(
         as_python_structure)

     self.assertEqual(expected.trainable, actual.trainable)
     self.assertEqual(expected.non_trainable, actual.non_trainable)
Exemplo n.º 18
0
  def test_execution_with_stateless_tff_optimizer(self):
    """Runs five rounds; each should subtract 0.1 and keep the learning rate at 1.0."""
    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), MODEL_WEIGHTS_TYPE.member)

    weights = model_utils.ModelWeights(1.0, ())
    update = 0.1
    optimizer_state = finalizer.initialize()
    for rounds_done in range(1, 6):
      output = finalizer.next(optimizer_state, weights, update)
      # Feed this round's state and weights into the next round.
      optimizer_state, weights = output.state, output.result
      self.assertEqual(1.0, optimizer_state[optimizer_base.LEARNING_RATE_KEY])
      self.assertAllClose(1.0 - 0.1 * rounds_done, weights.trainable)
      self.assertEqual((), output.measurements)
Exemplo n.º 19
0
  def test_orchestration_typecheck(self):
    """Checks the `initialize`/`next` signatures of the federated SGD process."""
    iterative_process = federated_sgd.build_federated_sgd_process(
        model_fn=model_examples.LinearRegression)

    trainable_type = collections.OrderedDict([
        ('a', tff.TensorType(tf.float32, [2, 1])),
        ('b', tf.float32),
    ])
    non_trainable_type = collections.OrderedDict([('c', tf.float32)])
    model_weights_type = model_utils.ModelWeights(trainable_type,
                                                  non_trainable_type)

    # ServerState consists of a model and optimizer_state. The optimizer_state
    # is provided by TensorFlow, TFF doesn't care what the actual value is.
    server_state_type = tff.FederatedType(
        optimizer_utils.ServerState(model_weights_type, test.AnyType()),
        placement=tff.SERVER,
        all_equal=True)

    batch_type = model_examples.LinearRegression.make_batch(
        tff.TensorType(tf.float32, [None, 2]),
        tff.TensorType(tf.float32, [None, 1]))
    dataset_type = tff.FederatedType(
        tff.SequenceType(batch_type), tff.CLIENTS, all_equal=False)

    model_output_type = tff.FederatedType(
        collections.OrderedDict([
            ('loss', tff.TensorType(tf.float32, [])),
            ('num_examples', tff.TensorType(tf.int32, [])),
        ]),
        tff.SERVER,
        all_equal=True)

    # `initialize` is expected to be a function of no arguments to a
    # ServerState.
    expected_initialize = tff.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(expected_initialize,
                     iterative_process.initialize.type_signature)

    # `next` is expected to be a function of (ServerState, Datasets) to
    # (ServerState, model outputs).
    expected_next = tff.FunctionType(
        parameter=[server_state_type, dataset_type],
        result=(server_state_type, model_output_type))
    self.assertEqual(expected_next, iterative_process.next.type_signature)
Exemplo n.º 20
0
    def test_orchestration_type_signature(self):
        """Checks the `initialize`/`next` signatures of the model-delta process."""
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.TrainableLinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=lambda: gradient_descent.SGD(
                learning_rate=1.0))

        trainable_type = collections.OrderedDict([
            ('a', tff.TensorType(tf.float32, [2, 1])),
            ('b', tf.float32),
        ])
        non_trainable_type = collections.OrderedDict([('c', tf.float32)])
        model_weights_type = model_utils.ModelWeights(trainable_type,
                                                      non_trainable_type)

        # ServerState consists of a model and optimizer_state. The
        # optimizer_state is provided by TensorFlow, TFF doesn't care what the
        # actual value is.
        server_state_type = tff.FederatedType(
            optimizer_utils.ServerState(model_weights_type, test.AnyType(),
                                        test.AnyType(), test.AnyType()),
            placement=tff.SERVER,
            all_equal=True)

        input_spec = model_examples.TrainableLinearRegression().input_spec
        dataset_type = tff.FederatedType(
            tff.SequenceType(input_spec), tff.CLIENTS, all_equal=False)

        model_output_type = tff.FederatedType(
            collections.OrderedDict([
                ('loss', tff.TensorType(tf.float32, [])),
                ('num_examples', tff.TensorType(tf.int32, [])),
            ]),
            tff.SERVER,
            all_equal=True)

        # `initialize` is expected to be a function of no arguments to a
        # ServerState.
        expected_initialize = tff.FunctionType(
            parameter=None, result=server_state_type)
        self.assertEqual(expected_initialize,
                         iterative_process.initialize.type_signature)

        # `next` is expected to be a function of (ServerState, Datasets) to
        # (ServerState, model outputs).
        expected_next = tff.FunctionType(
            parameter=[server_state_type, dataset_type],
            result=(server_state_type, model_output_type))
        self.assertEqual(expected_next,
                         iterative_process.next.type_signature)
Exemplo n.º 21
0
  def test_construction(self):
    """Compares the process's type signatures, as strings, to the expected ones."""
    iterative_process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    model_type = model_utils.ModelWeights(
        trainable=[
            computation_types.TensorType(tf.float32, [2, 1]),
            computation_types.TensorType(tf.float32)
        ],
        non_trainable=[computation_types.TensorType(tf.float32)])
    server_state_type = computation_types.FederatedType(
        optimizer_utils.ServerState(
            model=model_type,
            optimizer_state=[tf.int64],
            delta_aggregate_state=(),
            model_broadcast_state=()), placements.SERVER)

    expected_initialize = computation_types.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(
        str(iterative_process.initialize.type_signature),
        str(expected_initialize))

    batch_type = collections.OrderedDict(
        x=computation_types.TensorType(tf.float32, [None, 2]),
        y=computation_types.TensorType(tf.float32, [None, 1]))
    dataset_type = computation_types.FederatedType(
        computation_types.SequenceType(batch_type), placements.CLIENTS)

    train_metrics = collections.OrderedDict(
        loss=computation_types.TensorType(tf.float32),
        num_examples=computation_types.TensorType(tf.int32))
    metrics_type = computation_types.FederatedType(
        collections.OrderedDict(
            broadcast=(), aggregation=(), train=train_metrics),
        placements.SERVER)

    expected_next = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=dataset_type,
        ),
        result=(server_state_type, metrics_type))
    self.assertEqual(
        str(iterative_process.next.type_signature), str(expected_next))
Exemplo n.º 22
0
 def test_model_weights_from_tff_struct(self):
     """`from_tff_result` on a `structure.Struct` must match direct construction."""
     trainable = [tf.constant([1., 1.])]
     non_trainable = [tf.constant(1)]
     expected = model_utils.ModelWeights(
         trainable=trainable, non_trainable=non_trainable)

     named_fields = [
         ('trainable', structure.from_container(trainable)),
         ('non_trainable', structure.from_container(non_trainable)),
     ]
     actual = model_utils.ModelWeights.from_tff_result(
         structure.Struct(named_fields))

     self.assertEqual(expected.trainable, actual.trainable)
     self.assertEqual(expected.non_trainable, actual.non_trainable)
Exemplo n.º 23
0
    def test_type_properties(self, weighting):
        """Checks the `initialize`/`next` signatures of Mime Lite client work.

        Args:
          weighting: Forwarded unchanged to `mime._build_mime_lite_client_work`.
        """
        model_fn = model_examples.LinearRegression
        optimizer = sgdm.build_sgdm(learning_rate=0.1, momentum=0.9)
        process = mime._build_mime_lite_client_work(
            model_fn, optimizer, weighting)
        self.assertIsInstance(process, client_works.ClientWorkProcess)

        weights_type = model_utils.ModelWeights(
            trainable=computation_types.to_type([(tf.float32, (2, 1)),
                                                 tf.float32]),
            non_trainable=computation_types.to_type([tf.float32]))
        weights_at_clients = computation_types.at_clients(weights_type)

        input_spec_type = computation_types.to_type(model_fn().input_spec)
        data_at_clients = computation_types.at_clients(
            computation_types.SequenceType(input_spec_type))

        result_at_clients = computation_types.at_clients(
            client_works.ClientResult(
                update=weights_type.trainable,
                update_weight=computation_types.TensorType(tf.float32)))

        # The expected state pairs the optimizer's initial state type with the
        # aggregator's state type.
        trainable_specs = type_conversions.type_to_tf_tensor_specs(
            weights_type.trainable)
        optimizer_state_type = type_conversions.type_from_tensors(
            optimizer.initialize(trainable_specs))
        aggregator_state_type = computation_types.to_type(
            collections.OrderedDict(
                value_sum_process=(), weight_sum_process=()))
        state_at_server = computation_types.at_server(
            (optimizer_state_type, aggregator_state_type))

        train_metrics = collections.OrderedDict(
            loss=tf.float32, num_examples=tf.int32)
        measurements_at_server = computation_types.at_server(
            collections.OrderedDict(train=train_metrics))

        initialize_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        initialize_signature.check_equivalent_to(
            process.initialize.type_signature)

        next_parameter = collections.OrderedDict(
            state=state_at_server,
            weights=weights_at_clients,
            client_data=data_at_clients)
        next_result = measured_process.MeasuredProcessOutput(
            state_at_server, result_at_clients, measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_signature.check_equivalent_to(process.next.type_signature)
Exemplo n.º 24
0
  def test_state_with_new_model_weights(self):
    """Covers the success path and each `TypeError` of weight replacement."""
    original_model = model_utils.ModelWeights(
        trainable=[np.array([1.0, 2.0]), np.array([[1.0]])],
        non_trainable=[np.array(1)])
    server_state = optimizer_utils.ServerState(
        model=original_model,
        optimizer_state=[],
        delta_aggregate_state=tf.constant(0),
        model_broadcast_state=tf.constant(0))
    state = anonymous_tuple.from_container(server_state, recursive=True)
    replace_weights = optimizer_utils.state_with_new_model_weights

    new_state = replace_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0]),
                           np.array([[3.0]])],
        non_trainable_weights=[np.array(3)])
    self.assertAllClose(
        new_state.model.trainable,
        [np.array([3.0, 3.0]), np.array([[3.0]])])
    self.assertAllClose(new_state.model.non_trainable, [3])

    # dtypes differ from the originals (int vs. float and float vs. int).
    with self.assertRaisesRegex(TypeError, 'tensor type'):
      replace_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0]),
                             np.array([[3]])],
          non_trainable_weights=[np.array(3.0)])

    # The second trainable weight has shape (1,) instead of (1, 1).
    with self.assertRaisesRegex(TypeError, 'tensor type'):
      replace_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0]),
                             np.array([3.0])],
          non_trainable_weights=[np.array(3)])

    # Only one trainable weight is supplied where two are expected.
    with self.assertRaisesRegex(TypeError, 'different lengths'):
      replace_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0])],
          non_trainable_weights=[np.array(3)])

    # A dict is passed where a sequence of weights is expected.
    with self.assertRaisesRegex(TypeError, 'cannot be handled'):
      replace_weights(
          state,
          trainable_weights={'a': np.array([3.0, 3.0])},
          non_trainable_weights=[np.array(3)])
Exemplo n.º 25
0
  def test_execution_with_nearly_stateless_keras_optimizer(self):
    """Runs the optimizer finalizer for five rounds with Keras SGD.

    After every round the test asserts that the optimizer state equals the
    number of calls so far, that the trainable weight has decreased by 0.1
    per call from its initial value of 1.0, and that no measurements are
    produced.
    """

    # Note that SGD only maintains a counter of how many times it has been
    # called. No other state is used.
    def server_optimizer_fn():
      return tf.keras.optimizers.SGD(learning_rate=1.0)

    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, MODEL_WEIGHTS_TYPE.member)

    update = 0.1
    weights = model_utils.ModelWeights(1.0, ())
    optimizer_state = finalizer.initialize()
    for call_count in range(1, 6):
      output = finalizer.next(optimizer_state, weights, update)
      # Feed this round's state and weights into the next round.
      optimizer_state, weights = output.state, output.result
      # We check that the optimizer state is the number of calls.
      self.assertEqual([call_count], optimizer_state)
      self.assertAllClose(1.0 - 0.1 * call_count, weights.trainable)
      self.assertEqual((), output.measurements)
Exemplo n.º 26
0
  def test_execution_with_stateful_tff_optimizer(self):
    """Runs the optimizer finalizer for five rounds with SGD with momentum.

    Each round the test tracks the expected momentum accumulator
    (`velocity = velocity * momentum + update`) and asserts that the
    optimizer state's 'accumulator' entry matches it, that the new trainable
    weight equals the previous weight minus that velocity, and that no
    measurements are produced. The resulting weights are fed into the next
    round so the check covers an evolving model rather than the initial
    weights five times.
    """
    momentum = 0.5
    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0, momentum=momentum), MODEL_WEIGHTS_TYPE.member)

    weights = model_utils.ModelWeights(1.0, ())
    update = 0.1
    expected_velocity = 0.0
    optimizer_state = finalizer.initialize()
    for _ in range(5):
      output = finalizer.next(optimizer_state, weights, update)
      optimizer_state = output.state
      expected_velocity = expected_velocity * momentum + update
      self.assertNear(expected_velocity, optimizer_state['accumulator'], 1e-6)
      # `weights` still holds the pre-update value here, which is what the
      # expected result is computed from.
      self.assertAllClose(weights.trainable - expected_velocity,
                          output.result.trainable)
      self.assertEqual((), output.measurements)
      # Carry the updated weights forward; previously this assignment sat
      # outside the loop and was a dead store.
      weights = output.result
Exemplo n.º 27
0
  def test_state_with_new_model_weights(self):
    """Checks weight replacement on an OrderedDict-structured server state.

    The model weights are stored as `OrderedDict`s keyed 'b', 'a'
    (trainable) and 'c' (non-trainable). Replacement values are passed as
    plain lists; the test asserts that the key order of the state is kept and
    that the values are substituted positionally. It then checks that
    replacements with a mismatched dtype, shape, or length raise
    `ValueError` with a matching message.
    """
    trainable = [('b', np.array([1.0, 2.0])), ('a', np.array([[1.0]]))]
    non_trainable = [('c', np.array(1))]
    state = anonymous_tuple.from_container(
        optimizer_utils.ServerState(
            model=model_utils.ModelWeights(
                trainable=collections.OrderedDict(trainable),
                non_trainable=collections.OrderedDict(non_trainable)),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0)),
        recursive=True)

    new_state = optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0]),
                           np.array([[3.0]])],
        non_trainable_weights=[np.array(3)])
    self.assertEqual(list(new_state.model.trainable.keys()), ['b', 'a'])
    self.assertEqual(list(new_state.model.non_trainable.keys()), ['c'])
    self.assertAllClose(new_state.model.trainable['b'], [3.0, 3.0])
    self.assertAllClose(new_state.model.trainable['a'], [[3.0]])
    self.assertAllClose(new_state.model.non_trainable['c'], 3)

    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which no longer exists on `unittest.TestCase` in Python 3.12+.

    # Integer second trainable array and float non-trainable value.
    with self.assertRaisesRegex(ValueError, 'dtype'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0]),
                             np.array([[3]])],
          non_trainable_weights=[np.array(3.0)])

    # Second trainable array is rank 1 instead of rank 2.
    with self.assertRaisesRegex(ValueError, 'shape'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0]),
                             np.array([3.0])],
          non_trainable_weights=[np.array(3)])

    # Only one trainable array where the state holds two.
    with self.assertRaisesRegex(ValueError, 'Lengths differ'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0])],
          non_trainable_weights=[np.array(3)])
Exemplo n.º 28
0
    def test_state_with_new_model_weights_failure(self, new_trainable,
                                                  new_non_trainable,
                                                  expected_err_msg):
        """Checks that an invalid weight replacement raises `TypeError`.

        Presumably driven by a parameterized decorator that is not visible
        in this excerpt.

        Args:
            new_trainable: Replacement trainable weights, or `None` to reuse
                the state's original trainable weights.
            new_non_trainable: Replacement non-trainable weights, or `None`
                to reuse the state's original non-trainable weights.
            expected_err_msg: Regex the raised `TypeError` message must match.
        """
        trainable = [np.array([1.0, 2.0]), np.array([[1.0]]), np.int64(3)]
        non_trainable = [np.array(1), b'bytes type', 5, 2.0]
        state = optimizer_utils.ServerState(
            model=model_utils.ModelWeights(trainable=trainable,
                                           non_trainable=non_trainable),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0))

        # A `None` argument means "leave this half of the weights valid", so
        # the failure under test comes only from the other half. The original
        # code reassigned `non_trainable` to itself here, which left
        # `new_non_trainable` as `None`.
        if new_trainable is None:
            new_trainable = trainable
        if new_non_trainable is None:
            new_non_trainable = non_trainable

        with self.assertRaisesRegex(TypeError, expected_err_msg):
            optimizer_utils.state_with_new_model_weights(
                state,
                trainable_weights=new_trainable,
                non_trainable_weights=new_non_trainable)
Exemplo n.º 29
0
    def test_state_with_model_weights_success(self):
        """Checks that structurally matching replacement weights are applied.

        The state's weights mix NumPy arrays, a NumPy integer scalar, bytes,
        and Python int/float values. The replacements use the same kinds of
        values in the same positions, and the test asserts that the returned
        state carries exactly those replacements.
        """
        original_model = model_utils.ModelWeights(
            trainable=[np.array([1.0, 2.0]),
                       np.array([[1.0]]),
                       np.int64(3)],
            non_trainable=[np.array(1), b'bytes type', 5, 2.0])
        state = optimizer_utils.ServerState(
            model=original_model,
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0))

        # Same positions and value kinds as the originals, different values.
        new_trainable = [np.array([3.0, 3.0]), np.array([[3.0]]), np.int64(4)]
        new_non_trainable = [np.array(3), b'bytes check', 6, 3.0]

        updated_model = optimizer_utils.state_with_new_model_weights(
            state,
            trainable_weights=new_trainable,
            non_trainable_weights=new_non_trainable).model
        self.assertAllClose(updated_model.trainable, new_trainable)
        self.assertEqual(updated_model.non_trainable, new_non_trainable)
Exemplo n.º 30
0
    def test_type_properties(self):
        """Checks the type signatures of the model-delta client work process.

        Builds the process for `LinearRegression` with an SGD optimizer,
        asserts it is a `ClientWorkProcess`, and verifies that the type
        signatures of `initialize` and `next` are equivalent to the expected
        federated types: empty server-placed state and measurements,
        client-placed model weights and data in, client-placed
        `ClientResult` out.
        """
        model_fn = model_examples.LinearRegression
        optimizer = sgdm.build_sgdm(1.0)
        process = client_works.build_model_delta_client_work(
            model_fn, optimizer)
        self.assertIsInstance(process, client_works.ClientWorkProcess)

        # Expected model weights type: a (2, 1) float tensor and a float
        # scalar as trainable, one float scalar as non-trainable.
        trainable_type = computation_types.to_type([(tf.float32, (2, 1)),
                                                    tf.float32])
        non_trainable_type = computation_types.to_type([tf.float32])
        mw_type = model_utils.ModelWeights(trainable=trainable_type,
                                           non_trainable=non_trainable_type)

        weights_param_type = computation_types.at_clients(mw_type)
        batch_type = computation_types.to_type(model_fn().input_spec)
        data_param_type = computation_types.at_clients(
            computation_types.SequenceType(batch_type))
        client_result = client_works.ClientResult(
            update=mw_type.trainable,
            update_weight=computation_types.TensorType(tf.float32))
        result_type = computation_types.at_clients(client_result)
        state_type = computation_types.at_server(())
        measurements_type = computation_types.at_server(())

        # `initialize` takes no argument and returns the state.
        initialize_type = computation_types.FunctionType(
            parameter=None, result=state_type)
        initialize_type.check_equivalent_to(
            process.initialize.type_signature)

        # `next` maps (state, weights, client_data) to a measured output.
        next_parameter = collections.OrderedDict(
            state=state_type,
            weights=weights_param_type,
            client_data=data_param_type)
        next_result = MeasuredProcessOutput(state_type, result_type,
                                            measurements_type)
        next_type = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_type.check_equivalent_to(process.next.type_signature)