Example #1
# Imports assumed by this snippet (the original listing omits them); `networks`,
# `losses`, `RecurrentActor`, `RecurrentMPO` and DEFAULT_PARAMS come from the
# surrounding Acme-based project and are not reproduced here.
from typing import Dict

import numpy as np
import sonnet as snt
from acme import specs
from acme.tf import utils as tf2_utils
from acme.utils.loggers import Logger
from sonnet.optimizers import Adam


def make_lstm_mpo_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                        hyperparams: Dict, checkpoint_path: str):
    params = DEFAULT_PARAMS.copy()
    params.update(hyperparams)
    action_size = np.prod(env_spec.actions.shape, dtype=int).item()

    # Policy network: LayerNormMLP torso feeding a diagonal Gaussian head.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(
            layer_sizes=[*params.pop('policy_layers'), action_size]),
        networks.MultivariateNormalDiagHead(num_dimensions=action_size)
    ])

    # Critic network: multiplexes observation embedding and action into a scalar value.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
            layer_sizes=[*params.pop('critic_layers'), 1]))
    ])

    # Recurrent observation network: MLP torso followed by an LSTM.
    observation_network = networks.DeepRNN([
        networks.LayerNormMLP(layer_sizes=params.pop('observation_layers')),
        networks.LSTM(hidden_size=200)
    ])

    # Collect the MPO loss hyperparameters (keys prefixed with 'loss_') and
    # strip the prefix before passing them to the loss module.
    loss_param_keys = [key for key in params if key.startswith('loss_')]
    loss_params = {key[len('loss_'):]: params.pop(key) for key in loss_param_keys}
    policy_loss_module = losses.MPO(**loss_params)

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create optimizers.
    policy_optimizer = Adam(params.pop('policy_lr'))
    critic_optimizer = Adam(params.pop('critic_lr'))

    # Actor: stack the observation network, policy and a mode head into one recurrent module.
    actor = RecurrentActor(
        networks.DeepRNN([
            observation_network, policy_network,
            networks.StochasticModeHead()
        ]))

    # The learner updates the parameters (and initializes them).
    return RecurrentMPO(environment_spec=env_spec,
                        policy_network=policy_network,
                        critic_network=critic_network,
                        observation_network=observation_network,
                        policy_loss_module=policy_loss_module,
                        policy_optimizer=policy_optimizer,
                        critic_optimizer=critic_optimizer,
                        logger=logger,
                        checkpoint_path=checkpoint_path,
                        **params), actor
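
A minimal usage sketch for the factory above, assuming an Acme-wrapped dm_control
environment and Acme's terminal logger; the environment, the hyperparameter
overrides and the checkpoint path are illustrative placeholders, not the project's
actual DEFAULT_PARAMS.

# Hypothetical call site for make_lstm_mpo_agent; all concrete values below are
# placeholders chosen for illustration.
from acme import specs, wrappers
from acme.utils import loggers
from dm_control import suite

environment = wrappers.SinglePrecisionWrapper(suite.load('cartpole', 'balance'))
env_spec = specs.make_environment_spec(environment)

agent, actor = make_lstm_mpo_agent(
    env_spec=env_spec,
    logger=loggers.TerminalLogger(label='learner'),
    hyperparams={'policy_lr': 1e-4, 'critic_lr': 1e-4},  # overrides merged into DEFAULT_PARAMS
    checkpoint_path='/tmp/lstm_mpo')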
Example #2
File: learning.py  Project: vballoli/acme
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        num_samples: int,
        target_policy_update_period: int,
        target_critic_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = tf.identity,
        target_observation_network: types.TensorTransformation = tf.identity,
        policy_loss_module: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        dual_optimizer: Optional[snt.Optimizer] = None,
        clipping: bool = True,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):
        """Initializes the learner; builds a default MPO loss and Adam optimizers
        when none are provided, and sets up target-network update bookkeeping."""

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure the observation networks are snt.Module instances so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._num_samples = num_samples
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_policy_update_period = target_policy_update_period
        self._target_critic_update_period = target_critic_update_period

        # Batch dataset and create iterator.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self._policy_loss_module = policy_loss_module or losses.MPO(
            epsilon=1e-1,
            epsilon_penalty=1e-3,
            epsilon_mean=1e-3,
            epsilon_stddev=1e-6,
            init_log_temperature=1.,
            init_log_alpha_mean=1.,
            init_log_alpha_stddev=10.)

        # Create the optimizers.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create a checkpointer and snapshotter object.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='dmpo_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'dual_optimizer': self._dual_optimizer,
                    'policy_loss_module': self._policy_loss_module,
                    'num_steps': self._num_steps,
                })

            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'policy':
                    snt.Sequential([
                        self._target_observation_network,
                        self._target_policy_network
                    ]),
                })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
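
A hedged construction sketch for the learner whose constructor is shown above; the
checkpoint subdirectory suggests it is Acme's distributional MPO learner, referred to
here as DistributionalMPOLearner. The network sizes, distributional-head bounds,
replay server address and the make_reverb_dataset call are assumptions for
illustration (dataset helpers vary between Acme versions).

# Hypothetical wiring for the constructor above; every concrete value is a placeholder.
import copy

import sonnet as snt
import tensorflow as tf
from acme import datasets, specs
from acme.tf import networks
from acme.tf import utils as tf2_utils


def make_dmpo_learner(env_spec: specs.EnvironmentSpec, replay_address: str):
    act_dim = env_spec.actions.shape[0]  # assumes a flat continuous action spec

    # Online networks; the targets start out as exact copies.
    observation_network = tf2_utils.to_sonnet_module(tf.identity)
    policy_network = snt.Sequential([
        networks.LayerNormMLP([256, 256], activate_final=True),
        networks.MultivariateNormalDiagHead(act_dim),
    ])
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=networks.LayerNormMLP([512, 512], activate_final=True)),
        networks.DiscreteValuedHead(vmin=-150.0, vmax=150.0, num_atoms=51),
    ])
    target_observation_network = copy.deepcopy(observation_network)
    target_policy_network = copy.deepcopy(policy_network)
    target_critic_network = copy.deepcopy(critic_network)

    # Create variables before handing the modules to the learner.
    obs_spec = env_spec.observations
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, env_spec.actions])
    tf2_utils.create_variables(target_observation_network, [obs_spec])
    tf2_utils.create_variables(target_policy_network, [emb_spec])
    tf2_utils.create_variables(target_critic_network, [emb_spec, env_spec.actions])

    # Sampled transitions served by a reverb replay server at `replay_address`.
    dataset = datasets.make_reverb_dataset(
        server_address=replay_address, batch_size=256)

    # DistributionalMPOLearner stands in for the class whose __init__ is shown above.
    return DistributionalMPOLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        discount=0.99,
        num_samples=20,
        target_policy_update_period=100,
        target_critic_update_period=100,
        dataset=dataset,
        observation_network=observation_network,
        target_observation_network=target_observation_network)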