Example #1
  def __init__(self, comp_pb: pb.Computation,
               type_spec: computation_types.FunctionType,
               backend: xla_client.Client):
    """Creates this callable for a given computation, type, and backend.

    Args:
      comp_pb: An instance of `pb.Computation`.
      type_spec: An instance of `computation_types.FunctionType`.
      backend: An instance of `xla_client.Client`.

    Raises:
      ValueError: if the arguments are invalid.
    """
    py_typecheck.check_type(comp_pb, pb.Computation)
    py_typecheck.check_type(type_spec, computation_types.FunctionType)
    py_typecheck.check_type(backend, xla_client.Client)
    which_computation = comp_pb.WhichOneof('computation')
    if which_computation != 'xla':
      raise ValueError(
          'Unsupported computation type: {}'.format(which_computation))
    xla_comp = xla_serialization.unpack_xla_computation(
        comp_pb.xla.hlo_module)
    compile_options = xla_client.CompileOptions()
    compile_options.parameter_is_tupled_arguments = True
    self._executable = backend.compile(xla_comp, compile_options)
    self._inverted_parameter_tensor_indexes = list(
        np.argsort(_binding_to_tensor_indexes(comp_pb.xla.parameter)))
    self._result_tensor_indexes = _binding_to_tensor_indexes(
        comp_pb.xla.result)
    self._type_signature = type_spec
    self._backend = backend
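
A side note on the `np.argsort` line above: for a permutation, `argsort` returns its inverse, which is what lets the executor recover the original parameter order later. A minimal, self-contained sketch (the index list is a made-up stand-in for `_binding_to_tensor_indexes`):

import numpy as np

# Hypothetical stand-in for _binding_to_tensor_indexes(comp_pb.xla.parameter).
tensor_indexes = [2, 0, 1]
inverted = list(np.argsort(tensor_indexes))      # [1, 2, 0]

physical = ['a', 'b', 'c']                       # arguments in one order
logical = [physical[i] for i in tensor_indexes]  # permuted: ['c', 'a', 'b']
restored = [logical[i] for i in inverted]        # inverse permutation
assert restored == physical
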
Example #2
def _extract_call_ops(comp: pb.Computation) -> Iterator[tf.compat.v1.NodeDef]:
  """Yields the function-call ops in a `tensorflow` computation proto."""
  computation_oneof = comp.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError('`_extract_call_ops` only accepts `Computation` protos '
                    'of the "tensorflow" variety; you have passed one of '
                    'variety {}.'.format(computation_oneof))
  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  all_nodes = itertools.chain(graph_def.node,
                              *[f.node_def for f in graph_def.library.function])
  for node in all_nodes:
    if node.op in tensorflow_computation_transformations.CALL_OPS:
      yield node
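
To make the chaining idiom above concrete without the TFF and TensorFlow plumbing, here is a hedged, self-contained sketch; `SimpleNamespace` stands in for `NodeDef`, and the op names are assumed values for `CALL_OPS` (in TensorFlow these are the partitioned-call ops):

import itertools
from types import SimpleNamespace

CALL_OPS = frozenset({'PartitionedCall', 'StatefulPartitionedCall'})

# Stand-ins for graph_def.node and graph_def.library.function[*].node_def.
top_level_nodes = [SimpleNamespace(op='AddV2'),
                   SimpleNamespace(op='PartitionedCall')]
library_functions = [
    SimpleNamespace(node_def=[SimpleNamespace(op='StatefulPartitionedCall')]),
]

# Same traversal as _extract_call_ops: top-level nodes plus function bodies.
all_nodes = itertools.chain(top_level_nodes,
                            *[f.node_def for f in library_functions])
call_ops = [node.op for node in all_nodes if node.op in CALL_OPS]
assert call_ops == ['PartitionedCall', 'StatefulPartitionedCall']
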
Example #3
def _to_computation_internal_rep(*, value: pb.Computation,
                                 tf_function_cache: MutableMapping[str, Any],
                                 type_spec: computation_types.StructType,
                                 device: tf.config.LogicalDevice):
  """Converts a `pb.Computation` to a `tf.function`."""
  key = (value.SerializeToString(), str(type_spec),
         device.name if device else None)
  cached_fn = tf_function_cache.get(key)
  if cached_fn is not None:
    return cached_fn
  embedded_fn = embed_tensorflow_computation(value, type_spec, device)
  tf_function_cache[key] = embedded_fn
  return embedded_fn
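
The function above is a memoization wrapper: build the `tf.function` once per (serialized computation, type, device) triple, then serve repeats from the cache. The same shape, reduced to a self-contained sketch with hypothetical names:

from typing import Any, Callable, MutableMapping, Tuple

CacheKey = Tuple[bytes, str, str]

def get_or_build(cache: MutableMapping[CacheKey, Any], key: CacheKey,
                 build: Callable[[], Any]) -> Any:
  """Returns cache[key], invoking the expensive build() only on a miss."""
  cached = cache.get(key)
  if cached is not None:
    return cached
  built = build()  # stands in for embed_tensorflow_computation(...)
  cache[key] = built
  return built

cache: MutableMapping[CacheKey, Any] = {}
key = (b'serialized-proto', '(int32 -> int32)', '/device:CPU:0')
first = get_or_build(cache, key, object)
second = get_or_build(cache, key, object)
assert first is second  # the second lookup is a cache hit
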
Example #4
def _check_ops(proto: computation_pb2.Computation,
               allowed_op_names: Optional[FrozenSet[str]] = None,
               disallowed_op_names: Optional[FrozenSet[str]] = None):
  """Checks the ops in the TensorFlow computation.

  If `allowed_op_names` is specified, `_check_ops` checks that the incoming
  proto contains only ops in that set. If `disallowed_op_names` is specified
  instead, `_check_ops` checks that the proto contains no ops from that set.
  One of the two op set arguments must be non-empty; if both are,
  `allowed_op_names` takes precedence.

  Args:
    proto: Instance of `computation_pb2.Computation` with the `tensorflow`
      field populated.
    allowed_op_names: Set of allowed op names.
    disallowed_op_names: Set of disallowed op names.

  Raises:
    TypeError: If the proto is not of the `tensorflow` variety.
    DisallowedOpInTensorFlowComputationError: If the computation contains a
      disallowed op.
    RuntimeError: If both `allowed_op_names` and `disallowed_op_names` are
      empty.
  """
  py_typecheck.check_type(proto, computation_pb2.Computation)
  computation_oneof = proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError('`_check_ops` only accepts `Computation` protos of the '
                    '"tensorflow" variety; you have passed one of variety '
                    '{}.'.format(computation_oneof))
  graph_def = serialization_utils.unpack_graph_def(proto.tensorflow.graph_def)
  all_nodes = itertools.chain(
      graph_def.node, *[f.node_def for f in graph_def.library.function])
  found_disallowed_op_names = set()

  if allowed_op_names:
    for node in all_nodes:
      if node.op not in allowed_op_names:
        found_disallowed_op_names.add(node.op)
  elif disallowed_op_names:
    for node in all_nodes:
      if node.op in disallowed_op_names:
        found_disallowed_op_names.add(node.op)
  else:
    raise RuntimeError(
        'One of allowed_op_names or disallowed_op_names must be non-empty')

  if found_disallowed_op_names:
    found_disallowed_op_names_str = ', '.join(found_disallowed_op_names)
    raise DisallowedOpInTensorFlowComputationError(
        f'Found disallowed ops: {found_disallowed_op_names_str}')
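
The core allow-list/deny-list decision in `_check_ops` can be isolated from the proto plumbing. A hedged, self-contained restatement of that precedence logic (function and op names here are illustrative, not TFF API):

from typing import FrozenSet, Iterable, Optional, Set

def find_disallowed_ops(op_names: Iterable[str],
                        allowed: Optional[FrozenSet[str]] = None,
                        disallowed: Optional[FrozenSet[str]] = None) -> Set[str]:
  """Mirrors _check_ops: an allow-list, if given, takes precedence."""
  if allowed:
    return {op for op in op_names if op not in allowed}
  if disallowed:
    return {op for op in op_names if op in disallowed}
  raise RuntimeError('One of allowed or disallowed must be non-empty')

ops = ['AddV2', 'ReadFile', 'MatMul']
assert find_disallowed_ops(ops, allowed=frozenset({'AddV2', 'MatMul'})) == {'ReadFile'}
assert find_disallowed_ops(ops, disallowed=frozenset({'ReadFile'})) == {'ReadFile'}
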
Example #5
def _to_computation_internal_rep(*, value: pb.Computation,
                                 tf_function_cache: MutableMapping[str, Any],
                                 type_spec: computation_types.StructType,
                                 device: tf.config.LogicalDevice):
  """Converts a `pb.Computation` to a `tf.function`."""
  if value.tensorflow.cache_key.id:
    logging.debug('Using value id for cache key: %s',
                  value.tensorflow.cache_key.id)
    key = (value.tensorflow.cache_key.id,
           type_serialization.serialize_type(type_spec).SerializeToString(
               deterministic=True), device.name if device else None)
  else:
    logging.debug('Using hash of graph_def for cache key')
    key = (value.SerializeToString(deterministic=True),
           type_serialization.serialize_type(type_spec).SerializeToString(
               deterministic=True), device.name if device else None)
  cached_fn = tf_function_cache.get(key)
  if cached_fn is not None:
    return cached_fn
  embedded_fn = embed_tensorflow_computation(value, type_spec, device)
  tf_function_cache[key] = embedded_fn
  return embedded_fn

  async def _evaluate(
      self,
      comp: pb.Computation,
      scope=ReferenceResolvingExecutorScope({}),
  ) -> ReferenceResolvingExecutorValue:
    """Transforms `pb.Computation` into a `ReferenceResolvingExecutorValue`.

    Args:
      comp: An instance of `pb.Computation` to process.
      scope: A `ReferenceResolvingExecutorScope` to process it in. If
        omitted, defaults to an empty scope.

    Returns:
      An instance of `ReferenceResolvingExecutorValue`.
    """
    py_typecheck.check_type(comp, pb.Computation)
    py_typecheck.check_type(scope, ReferenceResolvingExecutorScope)
    which_computation = comp.WhichOneof('computation')
    if which_computation in [
        'tensorflow', 'intrinsic', 'data', 'placement', 'xla'
    ]:
      # Nothing interesting here; forward the creation to the child executor.
      return await self._evaluate_to_delegate(comp, scope)
    elif which_computation == 'lambda':
      return await self._evaluate_lambda(comp, scope)
    elif which_computation == 'reference':
      return await self._evaluate_reference(comp, scope)
    elif which_computation == 'call':
      return await self._evaluate_call(comp, scope)
    elif which_computation == 'selection':
      return await self._evaluate_selection(comp, scope)
    elif which_computation == 'struct':
      return await self._evaluate_struct(comp, scope)
    elif which_computation == 'block':
      return await self._evaluate_block(comp, scope)
    else:
      raise NotImplementedError(
          'Unsupported computation type "{}".'.format(which_computation))

  async def _evaluate(
      self,
      comp: pb.Computation,
      scope=ReferenceResolvingExecutorScope({}),
  ) -> ReferenceResolvingExecutorValue:
    """Transforms `pb.Computation` into a `ReferenceResolvingExecutorValue`.

    Args:
      comp: An instance of `pb.Computation` to process.
      scope: A `ReferenceResolvingExecutorScope` to process it in.
        If omitted, defaults to an empty scope.

    Returns:
      An instance of `ReferenceResolvingExecutorValue`.
    """
    py_typecheck.check_type(comp, pb.Computation)
    py_typecheck.check_type(scope, ReferenceResolvingExecutorScope)
    which_computation = comp.WhichOneof('computation')
    if which_computation in ['tensorflow', 'intrinsic', 'data', 'placement']:
      # Nothing interesting here; forward the creation to the child executor.
      return ReferenceResolvingExecutorValue(
          await self._target_executor.create_value(
              comp, type_serialization.deserialize_type(comp.type)))
    elif which_computation == 'lambda':
      type_spec = type_serialization.deserialize_type(comp.type)
      return ReferenceResolvingExecutorValue(
          ScopedLambda(comp, scope), type_spec=type_spec)
    elif which_computation == 'reference':
      return scope.resolve_reference(comp.reference.name)
    elif which_computation == 'call':
      func = self._evaluate(comp.call.function, scope=scope)

      async def get_arg():
        if comp.call.argument.WhichOneof('computation') is not None:
          return await self._evaluate(comp.call.argument, scope=scope)
        else:
          return None

      func, arg = await asyncio.gather(func, get_arg())
      return await self.create_call(func, arg=arg)
    elif which_computation == 'selection':
      which_selection = comp.selection.WhichOneof('selection')
      source = await self._evaluate(comp.selection.source, scope=scope)
      return await self.create_selection(
          source, **{which_selection: getattr(comp.selection, which_selection)})
    elif which_computation == 'tuple':
      names = [str(e.name) if e.name else None for e in comp.tuple.element]
      values = [
          self._evaluate(e.value, scope=scope) for e in comp.tuple.element
      ]
      values = await asyncio.gather(*values)
      return await self.create_tuple(
          anonymous_tuple.AnonymousTuple(zip(names, values)))
    elif which_computation == 'block':
      for loc in comp.block.local:
        value = await self._evaluate(loc.value, scope)
        scope = ReferenceResolvingExecutorScope({loc.name: value}, scope)
      return await self._evaluate(comp.block.result, scope)
    else:
      raise NotImplementedError(
          'Unsupported computation type "{}".'.format(which_computation))
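
The 'block' branch above is the interesting one for scoping: each local is evaluated in the scope built so far, then bound in a fresh child scope, so later locals and the block's result can reference earlier ones. A self-contained sketch of that chained-scope pattern (the class is a hypothetical stand-in for `ReferenceResolvingExecutorScope`):

from typing import Any, Dict, Optional

class Scope:
  """Hypothetical stand-in for ReferenceResolvingExecutorScope."""

  def __init__(self, bindings: Dict[str, Any],
               parent: Optional['Scope'] = None):
    self._bindings = bindings
    self._parent = parent

  def resolve(self, name: str) -> Any:
    """Looks up name here, then in the enclosing scopes."""
    if name in self._bindings:
      return self._bindings[name]
    if self._parent is not None:
      return self._parent.resolve(name)
    raise KeyError(name)

# Mirrors the 'block' branch: one new child scope per local binding.
scope = Scope({})
for name, value in [('x', 1), ('y', 2)]:
  scope = Scope({name: value}, scope)
assert scope.resolve('x') == 1 and scope.resolve('y') == 2
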
def _hash_proto(comp: pb.Computation) -> int:
  """Hash the `pb.Computation` for use as a cache key."""
  return hash(comp.SerializeToString())
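
One thing worth noting about the keys above: the Example #5 variant serializes with `deterministic=True` because ordinary proto serialization does not promise a stable byte order (notably for map fields), while `_hash_proto` simply hashes whatever bytes `SerializeToString` yields. A hedged sketch of building such a composite key from plain values:

from typing import Optional

def make_cache_key(serialized: bytes, type_str: str,
                   device_name: Optional[str]) -> int:
  """Builds a hashable key like the tuples above (a sketch, not TFF API)."""
  return hash((serialized, type_str, device_name))

key_a = make_cache_key(b'proto-bytes', '(int32 -> int32)', None)
key_b = make_cache_key(b'proto-bytes', '(int32 -> int32)', None)
assert key_a == key_b  # equal inputs give equal keys within one process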