def _maybe_convert_to_spec(p):
  """Convert `p` to an unbatched spec when it is a Params or a tensor.

  Anything that is neither a `distribution_utils.Params` nor a tensor is
  returned unchanged.
  """
  if isinstance(p, distribution_utils.Params):
    # Recurse into the distribution parameters via the shared helper.
    return _convert_to_spec_and_remove_singleton_batch_dim(p, outer_ndim)
  if tf.is_tensor(p):
    value_spec = tf.type_spec_from_value(p)
    return nest_utils.remove_singleton_batch_spec_dim(
        value_spec, outer_ndim=outer_ndim)
  return p
def _calc_unbatched_spec(x):
  """Return an unbatched spec describing `x`.

  Distributions become a `DistributionSpecV2` whose parameters are
  themselves converted to unbatched specs; everything else becomes its
  `tf.TypeSpec` with the singleton batch dimensions removed.
  """
  if not isinstance(x, tfp.distributions.Distribution):
    return nest_utils.remove_singleton_batch_spec_dim(
        tf.type_spec_from_value(x), outer_ndim=outer_ndim)
  params = distribution_utils.get_parameters(x)
  param_specs = _convert_to_spec_and_remove_singleton_batch_dim(
      params, outer_ndim=outer_ndim)
  return distribution_utils.DistributionSpecV2(
      event_shape=x.event_shape, dtype=x.dtype, parameters=param_specs)
def remove_singleton_batch_spec_dim(t):
  """Return `t`'s type spec with its singleton (size-1) batch dimension removed."""
  return nest_utils.remove_singleton_batch_spec_dim(
      tf.type_spec_from_value(t), outer_ndim=1)
def _calc_unbatched_spec(x):
  """Map `x` to an unbatched spec; distributions map to None (no spec emitted)."""
  if isinstance(x, tfp.distributions.Distribution):
    return None
  value_spec = tf.type_spec_from_value(x)
  return nest_utils.remove_singleton_batch_spec_dim(value_spec, outer_ndim=1)
def create_variables(module: typing.Union[Network, tf.keras.layers.Layer], input_spec: typing.Optional[types.NestedTensorSpec] = None, **kwargs: typing.Any) -> types.NestedTensorSpec: """Create variables in `module` given `input_spec`; return `output_spec`. Here `module` can be a `Network`, and we will soon also support Keras layers (and possibly Sonnet layers). Args: module: The instance we would like to create layers on. input_spec: The input spec (excluding batch dimensions). **kwargs: Extra arguments to `module.__call__`, e.g. `training=True`. Returns: Output specs, a nested `tf.TypeSpec` describing the output signature. """ # NOTE(ebrevdo): As a side effect, for generic keras Layers (not Networks) # this method stores new hidden properties in `module`: # `_network_output_spec`, `_network_state_spec`, `_merged_output_and_state` # - which internal TF-Agents libraries make use of. if isinstance(module, Network): return module.create_variables(input_spec, **kwargs) # Generic keras layer if input_spec is None: raise ValueError( "Module is a Keras layer; an input_spec is required but saw " "None: {}".format(module)) maybe_spec = getattr(module, "_network_output_spec", None) if maybe_spec is not None: return maybe_spec # Has state outputs - so expect that a state input is required, # and output[1:] are output states. recurrent_layer = getattr(module, "get_initial_state", None) is not None # Required input rank outer_ndim = _get_input_outer_ndim(module, input_spec) random_input = tensor_spec.sample_spec_nest( input_spec, outer_dims=(1,) * outer_ndim) if recurrent_layer: state = module.get_initial_state(random_input) state_spec = tf.nest.map_structure( lambda s: nest_utils.remove_singleton_batch_spec_dim( # pylint: disable=g-long-lambda tf.type_spec_from_value(s), outer_ndim=1), state) outputs = module(random_input, state, **kwargs) # tf.keras.layers.{LSTM,RNN,GRU} with this return_state==True # return outputs of the form [output, state1, state2, ...] 
# # While tf.keras.layers.{LSTMCell, ...} return # (output, [state1, state2,...]). layer_config = module.get_config() merged_output_and_state = layer_config.get("return_state", False) if isinstance(module, recurrent.RNN): if not merged_output_and_state: # This is an RNN layer that doesn't return state. Excludes individual # cells. raise ValueError( "Provided a Keras RNN layer with return_state==False. " "This configuration is not supported. Layer: {}".format(module)) if not layer_config.get("return_sequences", False): raise ValueError( "Provided a Keras RNN layer with return_sequences==False. " "This configuration is not supported. Layer: {}".format(module)) output = outputs[0] else: output = module(random_input, **kwargs) state_spec = () merged_output_and_state = False def _calc_unbatched_spec(x): if isinstance(x, tfp.distributions.Distribution): return None else: return nest_utils.remove_singleton_batch_spec_dim( tf.type_spec_from_value(x), outer_ndim=outer_ndim) # pylint: disable=protected-access module._network_output_spec = tf.nest.map_structure(_calc_unbatched_spec, output) module._network_state_spec = state_spec module._merged_output_and_state = merged_output_and_state return module._network_output_spec