コード例 #1
0
ファイル: executor.py プロジェクト: isabella232/federated
  def __init__(self, backend=None):
    """Initializes this XLA executor.

    Args:
      backend: Optional name of a local XLA backend. The value is forwarded
        unchanged to `xla_client.get_local_backend`, and the resulting backend
        object is stored on the executor.
    """
    local_backend = xla_client.get_local_backend(backend)
    self._backend = local_backend
コード例 #2
0
ファイル: executor.py プロジェクト: xingzhis/federated
    def __init__(self, device=None):
        """Initializes this XLA executor with `xla_client.get_local_backend(None)`.

        Args:
          device: Must be `None`. Choosing an explicit device is not supported
            yet.

        Raises:
          ValueError: If `device` is anything other than `None`.
        """
        if device is None:
            self._backend = xla_client.get_local_backend(None)
            return
        raise ValueError(
            'Explicitly specifying a device is currently not supported.')
コード例 #3
0
 def _run_comp(self, comp_pb, comp_type, arg=None):
     """Wraps a computation in `runtime.ComputationCallable` and invokes it.

     Args:
       comp_pb: The computation to run; asserted to be a `pb.Computation`.
       comp_type: Its type signature; asserted to be a
         `computation_types.FunctionType`.
       arg: Optional argument forwarded to the callable. When `None`, the
         callable is invoked with no positional arguments at all.

     Returns:
       Whatever the constructed callable returns.
     """
     self.assertIsInstance(comp_pb, pb.Computation)
     self.assertIsInstance(comp_type, computation_types.FunctionType)
     comp_callable = runtime.ComputationCallable(
         comp_pb, comp_type, xla_client.get_local_backend(None))
     # `None` means "no argument", not "pass None as the argument".
     call_args = () if arg is None else (arg,)
     return comp_callable(*call_args)
コード例 #4
0
ファイル: runtime_test.py プロジェクト: xingzhis/federated
 def test_computation_callable_return_one_number(self):
     """A no-argument computation returning the int32 constant 10 runs."""
     # The XLA computation declares an empty-tuple parameter that it never
     # reads; its only other op is the constant 10.
     xla_builder = xla_client.XlaBuilder('comp')
     empty_tuple_shape = xla_client.shape_from_pyval(())
     xla_client.ops.Parameter(xla_builder, 0, empty_tuple_shape)
     xla_client.ops.Constant(xla_builder, np.int32(10))
     xla_computation = xla_builder.build()

     fn_type = computation_types.FunctionType(None, np.int32)
     serialized = xla_serialization.create_xla_tff_computation(
         xla_computation, [], fn_type)
     local_backend = xla_client.get_local_backend(None)
     callable_under_test = runtime.ComputationCallable(
         serialized, fn_type, local_backend)

     self.assertIsInstance(callable_under_test, runtime.ComputationCallable)
     self.assertEqual(str(callable_under_test.type_signature), '( -> int32)')
     self.assertEqual(callable_under_test(), 10)
コード例 #5
0
ファイル: runtime_test.py プロジェクト: xingzhis/federated
 def test_computation_callable_add_two_numbers(self):
     """A computation summing a pair of int32 values returns 2 + 3 == 5."""
     # The XLA computation takes one parameter, a 2-tuple of int32 scalars,
     # and adds its two elements together.
     xla_builder = xla_client.XlaBuilder('comp')
     pair_shape = xla_client.shape_from_pyval(
         (np.array(0, dtype=np.int32),) * 2)
     pair = xla_client.ops.Parameter(xla_builder, 0, pair_shape)
     lhs = xla_client.ops.GetTupleElement(pair, 0)
     rhs = xla_client.ops.GetTupleElement(pair, 1)
     xla_client.ops.Add(lhs, rhs)
     xla_computation = xla_builder.build()

     fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
     serialized = xla_serialization.create_xla_tff_computation(
         xla_computation, [0, 1], fn_type)
     local_backend = xla_client.get_local_backend(None)
     callable_under_test = runtime.ComputationCallable(
         serialized, fn_type, local_backend)

     self.assertIsInstance(callable_under_test, runtime.ComputationCallable)
     self.assertEqual(str(callable_under_test.type_signature),
                      '(<int32,int32> -> int32)')
     two_and_three = structure.Struct([(None, np.int32(2)),
                                       (None, np.int32(3))])
     self.assertEqual(callable_under_test(two_and_three), 5)
コード例 #6
0
 def setUp(self):
     """Runs the base fixture, then fetches a local XLA backend for the test."""
     super().setUp()
     self._backend = xla_client.get_local_backend()