def make_lstm_mpo_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                        hyperparams: Dict, checkpoint_path: str):
  params = DEFAULT_PARAMS.copy()
  params.update(hyperparams)
  action_size = np.prod(env_spec.actions.shape, dtype=int).item()

  policy_network = snt.Sequential([
      networks.LayerNormMLP(
          layer_sizes=[*params.pop('policy_layers'), action_size]),
      networks.MultivariateNormalDiagHead(num_dimensions=action_size)
  ])
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
          layer_sizes=[*params.pop('critic_layers'), 1]))
  ])
  observation_network = networks.DeepRNN([
      networks.LayerNormMLP(layer_sizes=params.pop('observation_layers')),
      networks.LSTM(hidden_size=200)
  ])

  # Collect the 'loss_*' hyperparameters and build the MPO loss from them.
  loss_param_keys = list(
      filter(lambda key: key.startswith('loss_'), params.keys()))
  loss_params = dict(
      [(k.replace('loss_', ''), params.pop(k)) for k in loss_param_keys])
  policy_loss_module = losses.MPO(**loss_params)

  # Make sure observation network is a Sonnet Module.
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  # Create optimizers.
  policy_optimizer = Adam(params.pop('policy_lr'))
  critic_optimizer = Adam(params.pop('critic_lr'))

  # The actor selects actions from the mode of the recurrent stochastic policy.
  actor = RecurrentActor(
      networks.DeepRNN([
          observation_network, policy_network,
          networks.StochasticModeHead()
      ]))

  # The learner updates the parameters (and initializes them).
  return RecurrentMPO(environment_spec=env_spec,
                      policy_network=policy_network,
                      critic_network=critic_network,
                      observation_network=observation_network,
                      policy_loss_module=policy_loss_module,
                      policy_optimizer=policy_optimizer,
                      critic_optimizer=critic_optimizer,
                      logger=logger,
                      checkpoint_path=checkpoint_path,
                      **params), actor
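# A minimal usage sketch (added for illustration, not from the original source).
# It assumes dm-acme is installed; `make_environment` is a hypothetical helper
# returning a dm_env.Environment, and the hyperparameter overrides and checkpoint
# path below are placeholders. `DEFAULT_PARAMS`, `RecurrentActor`, and
# `RecurrentMPO` are assumed to be defined elsewhere in this module.
def _example_make_lstm_mpo_agent_usage():
  from acme import specs
  from acme.utils import loggers

  environment = make_environment()  # Hypothetical environment factory.
  env_spec = specs.make_environment_spec(environment)
  logger = loggers.make_default_logger('learner')

  # `hyperparams` overrides entries of DEFAULT_PARAMS; any remaining keys are
  # forwarded to RecurrentMPO by the factory above.
  learner, actor = make_lstm_mpo_agent(
      env_spec=env_spec,
      logger=logger,
      hyperparams={'policy_lr': 1e-4, 'critic_lr': 1e-4},  # Example overrides.
      checkpoint_path='/tmp/lstm_mpo_checkpoints')  # Hypothetical path.
  return learner, actor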
def __init__(
    self,
    policy_network: snt.Module,
    critic_network: snt.Module,
    target_policy_network: snt.Module,
    target_critic_network: snt.Module,
    discount: float,
    num_samples: int,
    target_policy_update_period: int,
    target_critic_update_period: int,
    dataset: tf.data.Dataset,
    observation_network: types.TensorTransformation = tf.identity,
    target_observation_network: types.TensorTransformation = tf.identity,
    policy_loss_module: Optional[snt.Module] = None,
    policy_optimizer: Optional[snt.Optimizer] = None,
    critic_optimizer: Optional[snt.Optimizer] = None,
    dual_optimizer: Optional[snt.Optimizer] = None,
    clipping: bool = True,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    checkpoint: bool = True,
):
  # Store online and target networks.
  self._policy_network = policy_network
  self._critic_network = critic_network
  self._target_policy_network = target_policy_network
  self._target_critic_network = target_critic_network

  # Make sure observation networks are snt.Module's so they have variables.
  self._observation_network = tf2_utils.to_sonnet_module(observation_network)
  self._target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')

  # Other learner parameters.
  self._discount = discount
  self._num_samples = num_samples
  self._clipping = clipping

  # Necessary to track when to update target networks.
  self._num_steps = tf.Variable(0, dtype=tf.int32)
  self._target_policy_update_period = target_policy_update_period
  self._target_critic_update_period = target_critic_update_period

  # Batch dataset and create iterator.
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  self._policy_loss_module = policy_loss_module or losses.MPO(
      epsilon=1e-1,
      epsilon_penalty=1e-3,
      epsilon_mean=1e-3,
      epsilon_stddev=1e-6,
      init_log_temperature=1.,
      init_log_alpha_mean=1.,
      init_log_alpha_stddev=10.)

  # Create the optimizers.
  self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
  self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
  self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

  # Expose the variables.
  policy_network_to_expose = snt.Sequential(
      [self._target_observation_network, self._target_policy_network])
  self._variables = {
      'critic': self._target_critic_network.variables,
      'policy': policy_network_to_expose.variables,
  }

  # Create a checkpointer and snapshotter object.
  self._checkpointer = None
  self._snapshotter = None

  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='dmpo_learner',
        objects_to_save={
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'observation': self._observation_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'target_observation': self._target_observation_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'dual_optimizer': self._dual_optimizer,
            'policy_loss_module': self._policy_loss_module,
            'num_steps': self._num_steps,
        })
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={
            'policy':
                snt.Sequential([
                    self._target_observation_network,
                    self._target_policy_network
                ]),
        })

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
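# A minimal sketch (added for illustration, not from the original source) of how
# the `_num_steps` counter and the two `*_update_period` values stored above are
# conventionally used: every `period` learner steps, the online variables are
# copied into the corresponding target variables. The helper below is
# hypothetical; the learner's actual update step is not shown in this excerpt.
def _maybe_update_target_variables(num_steps: tf.Variable, period: int,
                                   online_variables, target_variables) -> None:
  """Copies online variables into target variables every `period` steps."""
  if tf.math.mod(num_steps, period) == 0:
    for src, dest in zip(online_variables, target_variables):
      dest.assign(src)


# Illustrative call site, e.g. for the critic (and analogously for the policy),
# before incrementing self._num_steps:
#
#   _maybe_update_target_variables(
#       self._num_steps, self._target_critic_update_period,
#       (*self._observation_network.variables,
#        *self._critic_network.variables),
#       (*self._target_observation_network.variables,
#        *self._target_critic_network.variables))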