def __init__(self, comp_pb: pb.Computation,
             type_spec: computation_types.FunctionType,
             backend: xla_client.Client):
  """Compiles an XLA computation on `backend` and records its bindings.

  Args:
    comp_pb: The `pb.Computation` to compile. Its `computation` oneof must be
      set to `xla`.
    type_spec: The `computation_types.FunctionType` describing `comp_pb`.
    backend: The `xla_client.Client` used to compile the computation.

  Raises:
    ValueError: If `comp_pb` holds something other than an XLA computation.
    Argument type mismatches are reported by `py_typecheck.check_type`.
  """
  py_typecheck.check_type(comp_pb, pb.Computation)
  py_typecheck.check_type(type_spec, computation_types.FunctionType)
  py_typecheck.check_type(backend, xla_client.Client)

  computation_kind = comp_pb.WhichOneof('computation')
  if computation_kind != 'xla':
    raise ValueError(f'Unsupported computation type: {computation_kind}')

  xla_computation = xla_serialization.unpack_xla_computation(
      comp_pb.xla.hlo_module)
  options = xla_client.CompileOptions()
  # Compile with the `parameter_is_tupled_arguments` option enabled.
  options.parameter_is_tupled_arguments = True
  self._executable = backend.compile(xla_computation, options)

  # argsort of the parameter binding's tensor indexes; when those indexes form
  # a permutation, this is its inverse (hence the attribute name).
  parameter_indexes = _binding_to_tensor_indexes(comp_pb.xla.parameter)
  self._inverted_parameter_tensor_indexes = list(np.argsort(parameter_indexes))
  self._result_tensor_indexes = _binding_to_tensor_indexes(comp_pb.xla.result)
  self._type_signature = type_spec
  self._backend = backend
def _extract_call_ops(comp: pb.Computation) -> Iterator[tf.compat.v1.NodeDef]:
  """Yields the call ops found in a TensorFlow computation's graph.

  Scans the graph's top-level nodes and then the nodes of every function in
  the graph's function library. It yields each node whose op type is in
  `tensorflow_computation_transformations.CALL_OPS`.

  Because this is a generator, the argument check below runs only when
  iteration begins, not when the function is called.

  Args:
    comp: A `pb.Computation` whose `computation` oneof is `tensorflow`.

  Yields:
    The matching `NodeDef`s: top-level nodes first, then library function
    nodes in library order.

  Raises:
    TypeError: If `comp` is not a TensorFlow computation.
  """
  computation_oneof = comp.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    # The message previously named `prune_tensorflow_proto` (copy-paste error).
    raise TypeError('`_extract_call_ops` only accepts `Computation` '
                    'protos of the "tensorflow" variety; you have passed '
                    'one of variety {}.'.format(computation_oneof))
  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  all_nodes = itertools.chain(
      graph_def.node, *[f.node_def for f in graph_def.library.function])
  for node in all_nodes:
    if node.op in tensorflow_computation_transformations.CALL_OPS:
      yield node
def _to_computation_internal_rep(*, value: pb.Computation,
                                 tf_function_cache: MutableMapping[Any, Any],
                                 type_spec: computation_types.StructType,
                                 device: tf.config.LogicalDevice):
  """Converts a `pb.Computation` to a `tf.function`, memoizing the result.

  Args:
    value: The `pb.Computation` to embed.
    tf_function_cache: Mutable mapping used as a memo table. Keys are
      `(serialized value, str(type_spec), device name or None)` tuples.
    type_spec: Passed through to `embed_tensorflow_computation`; its `str`
      form is part of the cache key.
    device: Passed through to `embed_tensorflow_computation`. May be falsy,
      in which case `None` is used in the cache key.

  Returns:
    The cached or newly created result of `embed_tensorflow_computation`.
  """
  # Deterministic serialization makes equal protos produce identical bytes.
  # Plain serialization gives no such guarantee (e.g. for map fields), which
  # could cause spurious cache misses.
  key = (value.SerializeToString(deterministic=True), str(type_spec),
         device.name if device else None)
  cached_fn = tf_function_cache.get(key)
  if cached_fn is not None:
    return cached_fn
  embedded_fn = embed_tensorflow_computation(value, type_spec, device)
  tf_function_cache[key] = embedded_fn
  return embedded_fn
def _check_ops(proto: computation_pb2.Computation,
               allowed_op_names: Optional[FrozenSet[str]] = None,
               disallowed_op_names: Optional[FrozenSet[str]] = None):
  """Checks the ops in a TensorFlow computation against an allow or deny set.

  If `allowed_op_names` is non-empty, every op in the computation must be in
  it. Otherwise, if `disallowed_op_names` is non-empty, no op in the
  computation may be in it. At least one of the two sets must be non-empty.
  If both are, `allowed_op_names` takes precedence and `disallowed_op_names`
  is ignored.

  Ops are collected from the graph's top-level nodes and from the nodes of
  every function in the graph's function library.

  Args:
    proto: Instance of `computation_pb2.Computation` with the `tensorflow`
      field populated.
    allowed_op_names: Set of allowed op names.
    disallowed_op_names: Set of disallowed op names.

  Raises:
    TypeError: If `proto` is not a TensorFlow computation.
    RuntimeError: If both `allowed_op_names` and `disallowed_op_names` are
      empty.
    DisallowedOpInTensorFlowComputationError: If the computation contains a
      disallowed op. The offending op names are listed in sorted order.
  """
  py_typecheck.check_type(proto, computation_pb2.Computation)
  computation_oneof = proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    # The message previously named `prune_tensorflow_proto` (copy-paste error).
    raise TypeError('`_check_ops` only accepts `Computation` '
                    'protos of the "tensorflow" variety; you have passed '
                    'one of variety {}.'.format(computation_oneof))
  graph_def = serialization_utils.unpack_graph_def(proto.tensorflow.graph_def)
  all_nodes = itertools.chain(
      graph_def.node, *[f.node_def for f in graph_def.library.function])

  if allowed_op_names:
    found_disallowed_op_names = {
        node.op for node in all_nodes if node.op not in allowed_op_names
    }
  elif disallowed_op_names:
    found_disallowed_op_names = {
        node.op for node in all_nodes if node.op in disallowed_op_names
    }
  else:
    raise RuntimeError(
        'One of allowed_op_names or disallowed_op_names must be non-empty')

  if found_disallowed_op_names:
    # Sort so the message is reproducible. Set iteration order of strings
    # depends on the per-process hash seed.
    found_disallowed_op_names_str = ', '.join(sorted(found_disallowed_op_names))
    raise DisallowedOpInTensorFlowComputationError(
        f'Found disallowed ops: {found_disallowed_op_names_str}')
def _to_computation_internal_rep(*, value: pb.Computation,
                                 tf_function_cache: MutableMapping[str, Any],
                                 type_spec: computation_types.StructType,
                                 device: tf.config.LogicalDevice):
  """Embeds `value` as a `tf.function`, reusing a cached one when possible.

  The cache key is a 3-tuple:
    1. `value.tensorflow.cache_key.id` when it is set (truthy); otherwise the
       deterministic serialization of the whole `value` proto.
    2. The deterministic serialization of
       `type_serialization.serialize_type(type_spec)`.
    3. `device.name`, or `None` when `device` is falsy.

  Args:
    value: The `pb.Computation` to embed.
    tf_function_cache: Mutable mapping used as a memo table, keyed as above.
    type_spec: Passed through to `embed_tensorflow_computation` and used in
      the cache key.
    device: Passed through to `embed_tensorflow_computation` and used in the
      cache key.

  Returns:
    The cached or newly created result of `embed_tensorflow_computation`.
  """
  cache_key_id = value.tensorflow.cache_key.id
  if cache_key_id:
    logging.debug('Using value id for cache key: %s', cache_key_id)
    computation_key = cache_key_id
  else:
    logging.debug('Using hash of graph_def for cache key')
    computation_key = value.SerializeToString(deterministic=True)
  type_key = type_serialization.serialize_type(type_spec).SerializeToString(
      deterministic=True)
  key = (computation_key, type_key, device.name if device else None)

  embedded_fn = tf_function_cache.get(key)
  if embedded_fn is None:
    embedded_fn = embed_tensorflow_computation(value, type_spec, device)
    tf_function_cache[key] = embedded_fn
  return embedded_fn
async def _evaluate(
    self,
    comp: pb.Computation,
    scope=ReferenceResolvingExecutorScope({}),
) -> ReferenceResolvingExecutorValue:
  """Evaluates `comp` within `scope` into a `ReferenceResolvingExecutorValue`.

  Computation kinds with nothing to resolve locally (`tensorflow`,
  `intrinsic`, `data`, `placement`, `xla`) go to
  `self._evaluate_to_delegate`. Every other supported kind is routed to its
  own `_evaluate_<kind>` method.

  Args:
    comp: The `pb.Computation` to evaluate.
    scope: The `ReferenceResolvingExecutorScope` to evaluate in. Defaults to
      an empty scope. The default object is created once, when the method is
      defined, and is shared across calls.

  Returns:
    The resulting `ReferenceResolvingExecutorValue`.

  Raises:
    NotImplementedError: If `comp` holds an unsupported computation kind,
      including when no kind is set.
  """
  py_typecheck.check_type(comp, pb.Computation)
  py_typecheck.check_type(scope, ReferenceResolvingExecutorScope)
  kind = comp.WhichOneof('computation')

  if kind in ('tensorflow', 'intrinsic', 'data', 'placement', 'xla'):
    # Leaf computations: forward creation to the child executor.
    return await self._evaluate_to_delegate(comp, scope)

  handlers = {
      'lambda': self._evaluate_lambda,
      'reference': self._evaluate_reference,
      'call': self._evaluate_call,
      'selection': self._evaluate_selection,
      'struct': self._evaluate_struct,
      'block': self._evaluate_block,
  }
  handler = handlers.get(kind)
  if handler is None:
    raise NotImplementedError(f'Unsupported computation type "{kind}".')
  return await handler(comp, scope)
async def _evaluate(
    self,
    comp: pb.Computation,
    scope=ReferenceResolvingExecutorScope({}),
) -> ReferenceResolvingExecutorValue:
  """Evaluates `comp` within `scope` into a `ReferenceResolvingExecutorValue`.

  - Leaf kinds (`tensorflow`, `intrinsic`, `data`, `placement`) are created
    by the target executor.
  - Lambdas are captured together with `scope` without being run.
  - References are resolved in `scope`.
  - Calls, selections, tuples and blocks are evaluated recursively.

  Args:
    comp: The `pb.Computation` to evaluate.
    scope: The `ReferenceResolvingExecutorScope` to evaluate in. Defaults to
      an empty scope. The default object is created once, when the method is
      defined, and is shared across calls.

  Returns:
    The resulting `ReferenceResolvingExecutorValue`.

  Raises:
    NotImplementedError: If `comp` holds an unsupported computation kind.
  """
  py_typecheck.check_type(comp, pb.Computation)
  py_typecheck.check_type(scope, ReferenceResolvingExecutorScope)
  kind = comp.WhichOneof('computation')

  if kind in ('tensorflow', 'intrinsic', 'data', 'placement'):
    # Nothing to resolve here: the target executor builds the value.
    delegate_type = type_serialization.deserialize_type(comp.type)
    delegate_value = await self._target_executor.create_value(
        comp, delegate_type)
    return ReferenceResolvingExecutorValue(delegate_value)

  if kind == 'lambda':
    lambda_type = type_serialization.deserialize_type(comp.type)
    return ReferenceResolvingExecutorValue(
        ScopedLambda(comp, scope), type_spec=lambda_type)

  if kind == 'reference':
    return scope.resolve_reference(comp.reference.name)

  if kind == 'call':
    function_coro = self._evaluate(comp.call.function, scope=scope)

    async def _argument_or_none():
      # A call without an argument has no computation set in `argument`.
      if comp.call.argument.WhichOneof('computation') is None:
        return None
      return await self._evaluate(comp.call.argument, scope=scope)

    # The function and its argument are evaluated concurrently.
    function_value, argument_value = await asyncio.gather(
        function_coro, _argument_or_none())
    return await self.create_call(function_value, arg=argument_value)

  if kind == 'selection':
    selection_kind = comp.selection.WhichOneof('selection')
    source = await self._evaluate(comp.selection.source, scope=scope)
    selector = {selection_kind: getattr(comp.selection, selection_kind)}
    return await self.create_selection(source, **selector)

  if kind == 'tuple':
    elements = comp.tuple.element
    names = [str(e.name) if e.name else None for e in elements]
    values = await asyncio.gather(
        *(self._evaluate(e.value, scope=scope) for e in elements))
    return await self.create_tuple(
        anonymous_tuple.AnonymousTuple(zip(names, values)))

  if kind == 'block':
    # Each local is evaluated in a scope that already holds the earlier ones.
    for local in comp.block.local:
      bound_value = await self._evaluate(local.value, scope)
      scope = ReferenceResolvingExecutorScope({local.name: bound_value}, scope)
    return await self._evaluate(comp.block.result, scope)

  raise NotImplementedError(f'Unsupported computation type "{kind}".')
def _hash_proto(comp: pb.Computation) -> int:
  """Hashes a `pb.Computation` for use as an in-process cache key.

  The proto is serialized deterministically so that equal messages produce
  identical bytes, and therefore the same hash. Plain serialization makes no
  such guarantee (e.g. for map fields).

  Python's `hash` of `bytes` is salted per interpreter process. The result
  must therefore not be persisted or compared across processes.
  """
  return hash(comp.SerializeToString(deterministic=True))