Example #1
0
def _to_computation_internal_rep(*, value: pb.Computation,
                                 tf_function_cache: MutableMapping[str, Any],
                                 type_spec: computation_types.StructType,
                                 device: tf.config.LogicalDevice):
  """Converts a `pb.Computation` to a `tf.function`, memoizing the result.

  Args:
    value: The computation proto to embed. Its serialized bytes form the first
      component of the cache key.
    tf_function_cache: Mutable mapping used as the memoization table. It is
      read with `.get` and written on a cache miss.
    type_spec: Type signature forwarded to `embed_tensorflow_computation`; its
      `str()` form is the second component of the cache key.
    device: Device forwarded to `embed_tensorflow_computation`. A falsy value
      (for example `None`) is recorded as `None` in the cache key; otherwise
      `device.name` is used.

  Returns:
    The entry already stored in `tf_function_cache` for the computed key when
    that entry is not `None`; otherwise the freshly created result of
    `embed_tensorflow_computation(value, type_spec, device)`, which is also
    stored in the cache before being returned.
  """
  # Ask protobuf for deterministic serialization: by default the byte output
  # of equal messages is not guaranteed to be identical (e.g. map fields may
  # be emitted in any order), which would turn logically equal computations
  # into distinct cache keys and cause needless re-embedding.
  key = (value.SerializeToString(deterministic=True), str(type_spec),
         device.name if device else None)
  cached_fn = tf_function_cache.get(key)
  if cached_fn is not None:
    return cached_fn
  embedded_fn = embed_tensorflow_computation(value, type_spec, device)
  tf_function_cache[key] = embedded_fn
  return embedded_fn
Example #2
0
def _to_computation_internal_rep(*, value: pb.Computation,
                                 tf_function_cache: MutableMapping[str, Any],
                                 type_spec: computation_types.StructType,
                                 device: tf.config.LogicalDevice):
  """Returns the embedded form of `value`, reusing a cached one if present.

  The cache key is a 3-tuple of (computation identity, serialized type,
  device name). The computation identity is `value.tensorflow.cache_key.id`
  when that field is truthy, and the deterministic serialization of the whole
  proto otherwise.

  Args:
    value: The computation proto to embed.
    tf_function_cache: Mutable mapping used as the memoization table.
    type_spec: Type signature; serialized through
      `type_serialization.serialize_type` for the key and forwarded to
      `embed_tensorflow_computation`.
    device: Device forwarded to `embed_tensorflow_computation`; contributes
      `device.name` to the key, or `None` when falsy.

  Returns:
    The cached entry for the key if it is not `None`, otherwise the result of
    `embed_tensorflow_computation(value, type_spec, device)` (which is stored
    in the cache before returning).
  """
  computation_id = value.tensorflow.cache_key.id
  if not computation_id:
    # No explicit id was provided: fall back to the full serialized proto.
    logging.debug('Using hash of graph_def for cache key')
    computation_part = value.SerializeToString(deterministic=True)
  else:
    logging.debug('Using value id for cache key: %s', computation_id)
    computation_part = computation_id

  # The remaining key components are built the same way in both cases.
  serialized_type = type_serialization.serialize_type(type_spec)
  type_part = serialized_type.SerializeToString(deterministic=True)
  device_part = None if not device else device.name
  key = (computation_part, type_part, device_part)

  embedded_fn = tf_function_cache.get(key)
  if embedded_fn is None:
    embedded_fn = embed_tensorflow_computation(value, type_spec, device)
    tf_function_cache[key] = embedded_fn
  return embedded_fn
def _hash_proto(comp: pb.Computation) -> int:
  """Hash the `pb.Computation` for use as a cache key.

  Args:
    comp: The computation proto to hash.

  Returns:
    The built-in `hash` of the proto's deterministically serialized bytes.
    Python salts `bytes` hashes per interpreter process by default, so the
    value is only meaningful as an in-process key and must not be persisted.
  """
  # Deterministic serialization keeps equal protos mapping to equal bytes
  # (default serialization does not promise a stable byte layout, e.g. for
  # map fields), so equal computations produce the same hash.
  return hash(comp.SerializeToString(deterministic=True))