Example #1
def _maybe_convert_to_spec(p):
    # `outer_ndim` and `_convert_to_spec_and_remove_singleton_batch_dim`
    # are closed over from the enclosing function.
    if isinstance(p, distribution_utils.Params):
        return _convert_to_spec_and_remove_singleton_batch_dim(
            p, outer_ndim)
    elif tf.is_tensor(p):
        return nest_utils.remove_singleton_batch_spec_dim(
            tf.type_spec_from_value(p), outer_ndim=outer_ndim)
    else:
        return p
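
A minimal sketch of the tensor branch, assuming TF-Agents is installed (the tensor shape is illustrative):

import tensorflow as tf
from tf_agents.utils import nest_utils

t = tf.zeros([1, 4])  # leading singleton batch dimension
spec = nest_utils.remove_singleton_batch_spec_dim(
    tf.type_spec_from_value(t), outer_ndim=1)
# spec is now TensorSpec(shape=(4,), dtype=tf.float32)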
Example #2
def _calc_unbatched_spec(x):
    # `outer_ndim` is closed over from the enclosing function.
    if isinstance(x, tfp.distributions.Distribution):
        parameters = distribution_utils.get_parameters(x)
        parameter_specs = _convert_to_spec_and_remove_singleton_batch_dim(
            parameters, outer_ndim=outer_ndim)
        return distribution_utils.DistributionSpecV2(
            event_shape=x.event_shape,
            dtype=x.dtype,
            parameters=parameter_specs)
    else:
        return nest_utils.remove_singleton_batch_spec_dim(
            tf.type_spec_from_value(x), outer_ndim=outer_ndim)
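
What the distribution branch operates on, sketched under the assumption that `tensorflow_probability` is available (shapes are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

d = tfp.distributions.Normal(loc=tf.zeros([1, 3]), scale=tf.ones([1, 3]))
# distribution_utils.get_parameters(d) extracts the constructor
# parameters (loc, scale, ...); each tensor parameter is converted to an
# unbatched spec, and the result is wrapped in a DistributionSpecV2 that
# records d.event_shape and d.dtype alongside those parameter specs.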
Example #3
def remove_singleton_batch_spec_dim(t):
    # Convert the tensor to its type spec, then strip the singleton
    # batch dimension from that spec.
    spec = tf.type_spec_from_value(t)
    return nest_utils.remove_singleton_batch_spec_dim(spec, outer_ndim=1)
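
Unlike Example #1, this wrapper takes a concrete tensor rather than a spec and hard-codes `outer_ndim=1`. A hypothetical call (the shape is illustrative):

import tensorflow as tf

state = tf.zeros([1, 32])  # batch size 1
state_spec = remove_singleton_batch_spec_dim(state)
# -> TensorSpec(shape=(32,), dtype=tf.float32)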
Example #4
def _calc_unbatched_spec(x):
    # Distributions carry no single tensor spec; map them to None.
    if isinstance(x, tfp.distributions.Distribution):
        return None
    else:
        return nest_utils.remove_singleton_batch_spec_dim(
            tf.type_spec_from_value(x), outer_ndim=1)
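
Compared with Example #2, this variant simply maps distributions to None instead of building a DistributionSpecV2. Illustrative behavior, under the same assumptions as the sketches above:

_calc_unbatched_spec(tfp.distributions.Normal(0., 1.))  # -> None
_calc_unbatched_spec(tf.zeros([1, 4]))  # -> TensorSpec(shape=(4,), dtype=tf.float32)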
Example #5
def create_variables(module: typing.Union[Network, tf.keras.layers.Layer],
                     input_spec: typing.Optional[types.NestedTensorSpec] = None,
                     **kwargs: typing.Any) -> types.NestedTensorSpec:
  """Create variables in `module` given `input_spec`; return `output_spec`.

  Here `module` can be a `Network` or a generic Keras layer; Sonnet
  layers may be supported in the future.

  Args:
    module: The instance we would like to create layers on.
    input_spec: The input spec (excluding batch dimensions).
    **kwargs: Extra arguments to `module.__call__`, e.g. `training=True`.

  Returns:
    Output specs, a nested `tf.TypeSpec` describing the output signature.
  """
  # NOTE(ebrevdo): As a side effect, for generic keras Layers (not Networks)
  # this method stores new hidden properties in `module`:
  # `_network_output_spec`, `_network_state_spec`, `_merged_output_and_state`
  # - which internal TF-Agents libraries make use of.
  if isinstance(module, Network):
    return module.create_variables(input_spec, **kwargs)

  # Generic keras layer
  if input_spec is None:
    raise ValueError(
        "Module is a Keras layer; an input_spec is required but saw "
        "None: {}".format(module))

  maybe_spec = getattr(module, "_network_output_spec", None)
  if maybe_spec is not None:
    return maybe_spec

  # If the layer exposes `get_initial_state`, treat it as recurrent: it
  # requires a state input, and outputs[1:] are the output states.
  recurrent_layer = getattr(module, "get_initial_state", None) is not None

  # Required input rank
  outer_ndim = _get_input_outer_ndim(module, input_spec)

  random_input = tensor_spec.sample_spec_nest(
      input_spec, outer_dims=(1,) * outer_ndim)

  if recurrent_layer:
    state = module.get_initial_state(random_input)
    state_spec = tf.nest.map_structure(
        lambda s: nest_utils.remove_singleton_batch_spec_dim(  # pylint: disable=g-long-lambda
            tf.type_spec_from_value(s),
            outer_ndim=1),
        state)
    outputs = module(random_input, state, **kwargs)
    # tf.keras.layers.{LSTM,RNN,GRU} with return_state==True return
    # outputs of the form [output, state1, state2, ...], while
    # tf.keras.layers.{LSTMCell, ...} return (output, [state1, state2, ...]).
    layer_config = module.get_config()
    merged_output_and_state = layer_config.get("return_state", False)
    if isinstance(module, recurrent.RNN):
      if not merged_output_and_state:
        # A full RNN layer, configured not to return state (individual
        # cells are excluded by the isinstance check above).
        raise ValueError(
            "Provided a Keras RNN layer with return_state==False. "
            "This configuration is not supported.  Layer: {}".format(module))
      if not layer_config.get("return_sequences", False):
        raise ValueError(
            "Provided a Keras RNN layer with return_sequences==False. "
            "This configuration is not supported.  Layer: {}".format(module))
    output = outputs[0]
  else:
    output = module(random_input, **kwargs)
    state_spec = ()
    merged_output_and_state = False

  def _calc_unbatched_spec(x):
    if isinstance(x, tfp.distributions.Distribution):
      return None
    else:
      return nest_utils.remove_singleton_batch_spec_dim(
          tf.type_spec_from_value(x), outer_ndim=outer_ndim)

  # pylint: disable=protected-access
  module._network_output_spec = tf.nest.map_structure(_calc_unbatched_spec,
                                                      output)
  module._network_state_spec = state_spec
  module._merged_output_and_state = merged_output_and_state

  return module._network_output_spec
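
A minimal usage sketch of the generic-Keras-layer branch (the layer and shapes are illustrative, not part of the original source):

import tensorflow as tf

layer = tf.keras.layers.Dense(8)
input_spec = tf.TensorSpec([4], tf.float32)  # no batch dimension
output_spec = create_variables(layer, input_spec)
# output_spec is roughly TensorSpec(shape=(8,), dtype=tf.float32); as a
# side effect, layer._network_output_spec is now populated.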