Пример #1
0
    def __init__(self, comp_pb: pb.Computation,
                 type_spec: computation_types.FunctionType,
                 backend: xla_client.Client):
        """Builds this callable by compiling an XLA computation on a backend.

        Each argument is validated with `py_typecheck.check_type` before use.

        Args:
          comp_pb: The `pb.Computation` to compile; its `computation` oneof
            must be set to `xla`.
          type_spec: The `computation_types.FunctionType` stored as this
            callable's type signature.
          backend: The `xla_client.Client` that compiles the computation and
            is retained on the instance.

        Raises:
          ValueError: If `comp_pb` holds anything other than an `xla`
            computation.
        """
        expected_types = (
            (comp_pb, pb.Computation),
            (type_spec, computation_types.FunctionType),
            (backend, xla_client.Client),
        )
        for argument, expected in expected_types:
            py_typecheck.check_type(argument, expected)

        # Only the `xla` variant of the oneof is supported by this callable.
        computation_kind = comp_pb.WhichOneof('computation')
        if computation_kind != 'xla':
            raise ValueError(
                f'Unsupported computation type: {computation_kind}')

        xla_pb = comp_pb.xla
        unpacked = xla_serialization.unpack_xla_computation(xla_pb.hlo_module)
        # All parameters are handed to the executable as a single tuple.
        options = xla_client.CompileOptions()
        options.parameter_is_tupled_arguments = True
        executable = backend.compile(unpacked, options)

        # NOTE(review): presumably the parameter binding indexes form a
        # permutation, in which case argsort yields its inverse - confirm.
        parameter_indexes = _binding_to_tensor_indexes(xla_pb.parameter)
        inverted_indexes = list(np.argsort(parameter_indexes))
        result_indexes = _binding_to_tensor_indexes(xla_pb.result)

        self._executable = executable
        self._inverted_parameter_tensor_indexes = inverted_indexes
        self._result_tensor_indexes = result_indexes
        self._type_signature = type_spec
        self._backend = backend
Пример #2
0
  def __init__(self, comp_pb: pb.Computation,
               type_spec: computation_types.FunctionType,
               backend: xla_client.Client):
    """Creates this callable for a given computation, type, and backend.

    Each argument is validated with `py_typecheck.check_type`, then the XLA
    computation carried by `comp_pb` is unpacked and compiled on `backend`.

    Args:
      comp_pb: An instance of `pb.Computation`; its `computation` oneof must
        be set to `xla`.
      type_spec: An instance of `computation_types.FunctionType`, stored as
        this callable's type signature.
      backend: An instance of `xla_client.Client` used to compile the
        computation and retained on the instance.

    Raises:
      ValueError: If `comp_pb` holds anything other than an `xla` computation.
    """
    py_typecheck.check_type(comp_pb, pb.Computation)
    py_typecheck.check_type(type_spec, computation_types.FunctionType)
    py_typecheck.check_type(backend, xla_client.Client)
    # Reading `comp_pb.xla` when a different oneof member is set yields a
    # default (empty) message rather than an error, so reject non-XLA
    # computations explicitly before touching `hlo_module`.
    which_computation = comp_pb.WhichOneof('computation')
    if which_computation != 'xla':
      raise ValueError(
          'Unsupported computation type: {}'.format(which_computation))
    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    compile_options = xla_client.CompileOptions()
    # All parameters are handed to the executable as a single tuple.
    compile_options.parameter_is_tupled_arguments = True
    self._executable = backend.compile(xla_comp, compile_options)
    self._type_signature = type_spec
    self._backend = backend