Code example #1
    def test_unweighted_federated_mean_in_xla_execution_context(self):
        @computations.federated_computation(
            computation_types.FederatedType(np.float32, placements.CLIENTS))
        def comp(x):
            return intrinsics.federated_mean(x)

        execution_contexts.set_local_execution_context()
        self.assertEqual(comp([1.0, 2.0, 3.0]), 2.0)
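
The snippet above lives inside a TFF test class. As a rough standalone equivalent, the same pattern can be written against the public tensorflow_federated API (a minimal sketch, assuming a TFF release that installs its default local execution context on import, as the official tutorials do; `mean_comp` is an illustrative name):

    import numpy as np
    import tensorflow_federated as tff

    @tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
    def mean_comp(client_values):
        # Average the float values contributed by the clients.
        return tff.federated_mean(client_values)

    print(mean_comp([1.0, 2.0, 3.0]))  # expected: 2.0
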
Code example #2
    def test_federated_sum_in_xla_execution_context(self):
        @computations.federated_computation(
            computation_types.FederatedType(np.int32, placements.CLIENTS))
        def comp(x):
            return intrinsics.federated_sum(x)

        execution_contexts.set_local_execution_context()
        self.assertEqual(comp([1, 2, 3]), 6)
Code example #3
    def test_set_local_execution_context_and_run_simple_xla_computation(self):
        # Build a trivial XLA computation: it declares an (unused) empty-tuple
        # parameter, and its root is the constant 10, so invoking it yields 10.
        builder = xla_client.XlaBuilder('comp')
        xla_client.ops.Parameter(builder, 0,
                                 xla_client.shape_from_pyval(tuple()))
        xla_client.ops.Constant(builder, np.int32(10))
        xla_comp = builder.build()
        # Serialize the XLA computation as a TFF computation proto with type
        # ( -> int32) and wrap it as an invokable TFF computation.
        comp_type = computation_types.FunctionType(None, np.int32)
        comp_pb = xla_serialization.create_xla_tff_computation(
            xla_comp, [], comp_type)
        ctx_stack = context_stack_impl.context_stack
        comp = computation_impl.ComputationImpl(comp_pb, ctx_stack)
        # With the local execution context installed, the wrapped computation
        # can be called like an ordinary Python function.
        execution_contexts.set_local_execution_context()
        self.assertEqual(comp(), 10)
Code example #4
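    def test_federated_training_in_local_execution_context(self):
        # NOTE: this example is truncated at the top in the source; everything
        # from here down to `labels = np.random.randint(low=0,` is a hedged
        # reconstruction of the presumed missing head (method name, batch type
        # and the random_batch() helper), inferred from the names and shapes
        # used further below.
        batch_type = collections.OrderedDict([
            ('pixels', computation_types.TensorType(np.float32, (50, 784))),
            ('labels', computation_types.TensorType(np.int32, (50, )))
        ])

        def random_batch():
            pixels = np.random.uniform(low=0.0, high=1.0,
                                       size=(50, 784)).astype(np.float32)
            labels = np.random.randint(low=0,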
                                       high=9,
                                       size=(50, ),
                                       dtype=np.int32)
            return collections.OrderedDict([('pixels', pixels),
                                            ('labels', labels)])

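        # The model is a single linear (dense) layer followed by softmax: a
        # (784, 10) weight matrix plus a length-10 bias vector.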
        model_type = collections.OrderedDict([
            ('weights', computation_types.TensorType(np.float32, (784, 10))),
            ('bias', computation_types.TensorType(np.float32, (10, )))
        ])

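        # Softmax cross-entropy loss: compare the predicted class distribution
        # against one-hot targets and average over the batch.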
        def loss(model, batch):
            y = jax.nn.softmax(
                jax.numpy.add(
                    jax.numpy.matmul(batch['pixels'], model['weights']),
                    model['bias']))
            targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1),
                                     10)
            return -jax.numpy.mean(
                jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

        trainer = jax_components.build_jax_federated_averaging_process(
            batch_type, model_type, loss, step_size=0.001)

        trainer.next(trainer.initialize(), [[random_batch()]])
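        # A typical driver would keep the returned server state and feed it
        # back into `next` on every round (hedged sketch, not part of the
        # original test):
        #
        #   state = trainer.initialize()
        #   for _ in range(num_rounds):
        #       state = trainer.next(state, [[random_batch()]])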


if __name__ == '__main__':
    execution_contexts.set_local_execution_context()
    absltest.main()