def test_construction(self, weighted):
        """Checks the `initialize` and `next` type signatures of the process.

        Args:
          weighted: When truthy, the process is built with a
            `mean.MeanFactory` aggregator and the expected aggregator state
            and metrics are the matching ordered dicts. Otherwise a
            `sum_factory.SumFactory` is used and both are expected to be
            empty tuples.
        """
        # Pick the aggregator and the aggregator-specific pieces of the
        # expected types together, since they all hinge on `weighted`.
        if weighted:
            aggregation_factory = mean.MeanFactory()
            expected_aggregate_state = collections.OrderedDict(
                value_sum_process=(), weight_sum_process=())
            expected_aggregate_metrics = collections.OrderedDict(
                mean_value=(), mean_weight=())
        else:
            aggregation_factory = sum_factory.SumFactory()
            expected_aggregate_state = ()
            expected_aggregate_metrics = ()

        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            model_update_aggregation_factory=aggregation_factory)

        # Expected server state: model weights, optimizer state, aggregator
        # state and an empty broadcast state, all placed at the server.
        expected_model_weights = model_utils.ModelWeights(
            trainable=[
                computation_types.TensorType(tf.float32, [2, 1]),
                computation_types.TensorType(tf.float32),
            ],
            non_trainable=[computation_types.TensorType(tf.float32)])
        expected_state_type = computation_types.FederatedType(
            optimizer_utils.ServerState(
                model=expected_model_weights,
                optimizer_state=[tf.int64],
                delta_aggregate_state=expected_aggregate_state,
                model_broadcast_state=()),
            placements.SERVER)

        # `initialize` takes no arguments and yields the server state.
        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=expected_state_type)
        self.assert_types_equivalent(
            expected_initialize_type,
            iterative_process.initialize.type_signature)

        # Client-placed sequences of batched (x, y) examples.
        expected_batch_type = collections.OrderedDict(
            x=computation_types.TensorType(tf.float32, [None, 2]),
            y=computation_types.TensorType(tf.float32, [None, 1]))
        expected_dataset_type = computation_types.FederatedType(
            computation_types.SequenceType(expected_batch_type),
            placements.CLIENTS)

        # Server-placed metrics reported by `next`.
        expected_metrics_type = computation_types.FederatedType(
            collections.OrderedDict(
                broadcast=(),
                aggregation=expected_aggregate_metrics,
                train=collections.OrderedDict(
                    loss=computation_types.TensorType(tf.float32),
                    num_examples=computation_types.TensorType(tf.int32)),
                stat=collections.OrderedDict(
                    num_examples=computation_types.TensorType(tf.float32))),
            placements.SERVER)

        # `next` maps (server_state, federated_dataset) to
        # (server_state, metrics).
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                server_state=expected_state_type,
                federated_dataset=expected_dataset_type,
            ),
            result=(expected_state_type, expected_metrics_type))
        self.assert_types_equivalent(
            expected_next_type,
            iterative_process.next.type_signature)
예제 #2
0
 def server_init_tff():
     """Returns a state placed at `tff.SERVER`.

     The model and optimizer state are elements 0 and 1 of the result of
     evaluating `server_init_tf` at the server. The aggregation and broadcast
     states come from calling their respective init computations. The
     assembled `ServerState` is then passed through
     `intrinsics.federated_zip`.
     """
     model_and_optimizer_state = intrinsics.federated_eval(
         server_init_tf, placements.SERVER)
     unzipped_server_state = optimizer_utils.ServerState(
         model=model_and_optimizer_state[0],
         optimizer_state=model_and_optimizer_state[1],
         delta_aggregate_state=aggregation_init_computation(),
         model_broadcast_state=broadcast_init_computation())
     return intrinsics.federated_zip(unzipped_server_state)
예제 #3
0
    def test_orchestration_type_signature(self):
        """Verifies the `initialize` and `next` type signatures of the process.

        The process is built from `TrainableLinearRegression` with an SGD
        server optimizer (learning rate 1.0). Only the model-weights field of
        the server state is pinned to a concrete type; the other three
        `ServerState` fields are filled with `test.AnyType()`.
        """
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.TrainableLinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0))

        trainable_types = collections.OrderedDict([
            ('a', tff.TensorType(tf.float32, [2, 1])),
            ('b', tf.float32),
        ])
        non_trainable_types = collections.OrderedDict([('c', tf.float32)])
        model_weights_type = model_utils.ModelWeights(trainable_types,
                                                      non_trainable_types)

        # Server state placed at the server: concrete model weights followed
        # by three `test.AnyType()` placeholders for the remaining fields.
        server_state_type = tff.FederatedType(
            optimizer_utils.ServerState(model_weights_type, test.AnyType(),
                                        test.AnyType(), test.AnyType()),
            placement=tff.SERVER,
            all_equal=True)

        # Client datasets are sequences of the model's own input spec.
        dataset_type = tff.FederatedType(
            tff.SequenceType(
                model_examples.TrainableLinearRegression().input_spec),
            tff.CLIENTS,
            all_equal=False)

        model_output_type = tff.FederatedType(
            collections.OrderedDict([
                ('loss', tff.TensorType(tf.float32, [])),
                ('num_examples', tff.TensorType(tf.int32, [])),
            ]),
            tff.SERVER,
            all_equal=True)

        # `initialize` is expected to be a function of no arguments to a
        # ServerState.
        expected_initialize_type = tff.FunctionType(
            parameter=None, result=server_state_type)
        self.assertEqual(expected_initialize_type,
                         iterative_process.initialize.type_signature)

        # `next` is expected to be a function of (ServerState, Datasets) to
        # (ServerState, model outputs).
        expected_next_type = tff.FunctionType(
            parameter=[server_state_type, dataset_type],
            result=(server_state_type, model_output_type))
        self.assertEqual(expected_next_type,
                         iterative_process.next.type_signature)
예제 #4
0
  def test_orchestration_typecheck(self):
    """Verifies the type signatures of the federated SGD process.

    The process is built from `LinearRegression`. The expected server state
    pins the model-weights type and uses `test.AnyType()` for the second
    `ServerState` field.
    """
    iterative_process = federated_sgd.build_federated_sgd_process(
        model_fn=model_examples.LinearRegression)

    trainable_types = collections.OrderedDict([
        ('a', tff.TensorType(tf.float32, [2, 1])),
        ('b', tf.float32),
    ])
    non_trainable_types = collections.OrderedDict([('c', tf.float32)])
    model_weights_type = model_utils.ModelWeights(trainable_types,
                                                  non_trainable_types)

    # Server state placed at the server: concrete model weights plus a
    # `test.AnyType()` placeholder for the remaining field.
    server_state_type = tff.FederatedType(
        optimizer_utils.ServerState(model_weights_type, test.AnyType()),
        placement=tff.SERVER,
        all_equal=True)

    # Client datasets are sequences of batches built by the model's own
    # `make_batch` from the two tensor types below.
    batch_type = model_examples.LinearRegression.make_batch(
        tff.TensorType(tf.float32, [None, 2]),
        tff.TensorType(tf.float32, [None, 1]))
    dataset_type = tff.FederatedType(
        tff.SequenceType(batch_type), tff.CLIENTS, all_equal=False)

    model_output_type = tff.FederatedType(
        collections.OrderedDict([
            ('loss', tff.TensorType(tf.float32, [])),
            ('num_examples', tff.TensorType(tf.int32, [])),
        ]),
        tff.SERVER,
        all_equal=True)

    # `initialize` is expected to be a function of no arguments to a
    # ServerState.
    expected_initialize_type = tff.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(expected_initialize_type,
                     iterative_process.initialize.type_signature)

    # `next` is expected to be a function of (ServerState, Datasets) to
    # (ServerState, model outputs).
    expected_next_type = tff.FunctionType(
        parameter=[server_state_type, dataset_type],
        result=(server_state_type, model_output_type))
    self.assertEqual(expected_next_type,
                     iterative_process.next.type_signature)
예제 #5
0
  def test_construction(self):
    """Checks the `initialize`/`next` signatures via their string forms.

    The process is built from `LinearRegression` with a Keras SGD server
    optimizer. Each type signature is compared to the expected type by
    converting both sides with `str`.
    """
    iterative_process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    # Expected server state: model weights, optimizer state, and empty
    # aggregator / broadcast states, placed at the server.
    expected_model_weights = model_utils.ModelWeights(
        trainable=[
            computation_types.TensorType(tf.float32, [2, 1]),
            computation_types.TensorType(tf.float32),
        ],
        non_trainable=[computation_types.TensorType(tf.float32)])
    expected_state_type = computation_types.FederatedType(
        optimizer_utils.ServerState(
            model=expected_model_weights,
            optimizer_state=[tf.int64],
            delta_aggregate_state=(),
            model_broadcast_state=()),
        placements.SERVER)

    # `initialize` takes no arguments and yields the server state.
    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    self.assertEqual(
        str(iterative_process.initialize.type_signature),
        str(expected_initialize_type))

    # Client-placed sequences of batched (x, y) examples.
    expected_batch_type = collections.OrderedDict(
        x=computation_types.TensorType(tf.float32, [None, 2]),
        y=computation_types.TensorType(tf.float32, [None, 1]))
    expected_dataset_type = computation_types.FederatedType(
        computation_types.SequenceType(expected_batch_type),
        placements.CLIENTS)

    # Server-placed metrics reported by `next`.
    expected_metrics_type = computation_types.FederatedType(
        collections.OrderedDict(
            broadcast=(),
            aggregation=(),
            train=collections.OrderedDict(
                loss=computation_types.TensorType(tf.float32),
                num_examples=computation_types.TensorType(tf.int32))),
        placements.SERVER)

    # `next` maps (server_state, federated_dataset) to
    # (server_state, metrics).
    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            server_state=expected_state_type,
            federated_dataset=expected_dataset_type,
        ),
        result=(expected_state_type, expected_metrics_type))
    self.assertEqual(
        str(iterative_process.next.type_signature),
        str(expected_next_type))
  def test_state_with_new_model_weights(self):
    """Exercises `state_with_new_model_weights` on list-structured weights.

    First checks that compatible replacement weights show up in the returned
    state, then checks that several incompatible replacements raise
    `TypeError` with the expected message.
    """
    state = anonymous_tuple.from_container(
        optimizer_utils.ServerState(
            model=model_utils.ModelWeights(
                trainable=[np.array([1.0, 2.0]), np.array([[1.0]])],
                non_trainable=[np.array(1)]),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0)),
        recursive=True)

    # Happy path: same structure, dtypes and shapes as the original weights.
    updated_state = optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0]), np.array([[3.0]])],
        non_trainable_weights=[np.array(3)])
    self.assertAllClose(
        updated_state.model.trainable,
        [np.array([3.0, 3.0]), np.array([[3.0]])])
    self.assertAllClose(updated_state.model.non_trainable, [3])

    # Each entry: (expected message regex, trainable, non-trainable).
    rejected_updates = [
        # Integer trainable / float non-trainable where the state holds the
        # opposite.
        ('tensor type',
         [np.array([3.0, 3.0]), np.array([[3]])],
         [np.array(3.0)]),
        # Second trainable weight is 1-D where the state holds a 2-D array.
        ('tensor type',
         [np.array([3.0, 3.0]), np.array([3.0])],
         [np.array(3)]),
        # One trainable weight instead of two.
        ('different lengths',
         [np.array([3.0, 3.0])],
         [np.array(3)]),
        # A dict where the state holds a list.
        ('cannot be handled',
         {'a': np.array([3.0, 3.0])},
         [np.array(3)]),
    ]
    for message_regex, bad_trainable, bad_non_trainable in rejected_updates:
      with self.assertRaisesRegex(TypeError, message_regex):
        optimizer_utils.state_with_new_model_weights(
            state,
            trainable_weights=bad_trainable,
            non_trainable_weights=bad_non_trainable)
예제 #7
0
  def test_state_with_new_model_weights(self):
    """Exercises `state_with_new_model_weights` on OrderedDict weights.

    The state's weights are keyed ordered dicts, while the replacement
    weights are passed as plain lists. The test checks that the returned
    state keeps the original keys (in order) with the new values, and that
    incompatible replacements raise `ValueError` with the expected message.
    """
    trainable = [('b', np.array([1.0, 2.0])), ('a', np.array([[1.0]]))]
    non_trainable = [('c', np.array(1))]
    state = anonymous_tuple.from_container(
        optimizer_utils.ServerState(
            model=model_utils.ModelWeights(
                trainable=collections.OrderedDict(trainable),
                non_trainable=collections.OrderedDict(non_trainable)),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0)),
        recursive=True)

    # Happy path: key order ('b' before 'a') must be preserved.
    new_state = optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0]),
                           np.array([[3.0]])],
        non_trainable_weights=[np.array(3)])
    self.assertEqual(list(new_state.model.trainable.keys()), ['b', 'a'])
    self.assertEqual(list(new_state.model.non_trainable.keys()), ['c'])
    self.assertAllClose(new_state.model.trainable['b'], [3.0, 3.0])
    self.assertAllClose(new_state.model.trainable['a'], [[3.0]])
    self.assertAllClose(new_state.model.non_trainable['c'], 3)

    # NOTE: `assertRaisesRegex` is used instead of the deprecated
    # `assertRaisesRegexp` alias, which was removed in Python 3.12.

    # Integer trainable / float non-trainable where the state holds the
    # opposite.
    with self.assertRaisesRegex(ValueError, 'dtype'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0]),
                             np.array([[3]])],
          non_trainable_weights=[np.array(3.0)])

    # Second trainable weight is 1-D where the state holds a 2-D array.
    with self.assertRaisesRegex(ValueError, 'shape'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0]),
                             np.array([3.0])],
          non_trainable_weights=[np.array(3)])

    # One trainable weight instead of two.
    with self.assertRaisesRegex(ValueError, 'Lengths differ'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0])],
          non_trainable_weights=[np.array(3)])
예제 #8
0
    def test_state_with_new_model_weights_failure(self, new_trainable,
                                                  new_non_trainable,
                                                  expected_err_msg):
        """Checks that incompatible replacement weights raise `TypeError`.

        Args:
          new_trainable: Replacement trainable weights, or `None` to reuse
            the state's own trainable weights.
          new_non_trainable: Replacement non-trainable weights, or `None` to
            reuse the state's own non-trainable weights.
          expected_err_msg: Regex the raised `TypeError` message must match.
        """
        trainable = [np.array([1.0, 2.0]), np.array([[1.0]]), np.int64(3)]
        non_trainable = [np.array(1), b'bytes type', 5, 2.0]
        state = optimizer_utils.ServerState(
            model=model_utils.ModelWeights(trainable=trainable,
                                           non_trainable=non_trainable),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0))

        # A `None` argument means "keep this half of the weights valid", so
        # the failure under test comes from the other half. The original code
        # assigned the non-trainable fallback to the wrong variable, leaving
        # `new_non_trainable` as `None`.
        if new_trainable is None:
            new_trainable = trainable
        if new_non_trainable is None:
            new_non_trainable = non_trainable

        with self.assertRaisesRegex(TypeError, expected_err_msg):
            optimizer_utils.state_with_new_model_weights(
                state,
                trainable_weights=new_trainable,
                non_trainable_weights=new_non_trainable)
예제 #9
0
    def test_state_with_model_weights_success(self):
        """Checks that compatible replacement weights land in the new state.

        The weights mix NumPy arrays, a NumPy scalar, bytes and plain Python
        numbers. The replacements mirror the originals element for element
        with different values.
        """
        initial_state = optimizer_utils.ServerState(
            model=model_utils.ModelWeights(
                trainable=[
                    np.array([1.0, 2.0]),
                    np.array([[1.0]]),
                    np.int64(3),
                ],
                non_trainable=[np.array(1), b'bytes type', 5, 2.0]),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0))

        replacement_trainable = [
            np.array([3.0, 3.0]),
            np.array([[3.0]]),
            np.int64(4),
        ]
        replacement_non_trainable = [np.array(3), b'bytes check', 6, 3.0]

        updated_state = optimizer_utils.state_with_new_model_weights(
            initial_state,
            trainable_weights=replacement_trainable,
            non_trainable_weights=replacement_non_trainable)

        # Trainable weights are compared numerically; non-trainable weights
        # (which include bytes) are compared with plain equality.
        self.assertAllClose(updated_state.model.trainable,
                            replacement_trainable)
        self.assertEqual(updated_state.model.non_trainable,
                         replacement_non_trainable)