Example #1
def load_policy_net(
    task_name: str,
    noise_level: float,
    dataset_path: str,
    environment_spec: specs.EnvironmentSpec,
    near_policy_dataset: bool = False,
    ):
    dataset_path = Path(dataset_path)
    if task_name.startswith("bsuite"):
        # BSuite tasks.
        bsuite_id = task_name[len("bsuite_"):] + "/0"
        path = bsuite_policy_path(
            bsuite_id, noise_level, near_policy_dataset, dataset_path)
        logging.info("Policy path: %s", path)
        policy_net = tf.saved_model.load(path)

        policy_noise_level = 0.1  # params["policy_noise_level"]
        observation_network = tf2_utils.to_sonnet_module(functools.partial(
            tf.reshape, shape=(-1,) + environment_spec.observations.shape))
        policy_net = snt.Sequential([
            observation_network,
            policy_net,
        # Add epsilon-greedy action noise to the target policy.
            lambda q: trfl.epsilon_greedy(q, epsilon=policy_noise_level).sample(),
        ])
    elif task_name.startswith("dm_control"):
        # DM Control tasks.
        if near_policy_dataset:
            raise ValueError(
                "Near-policy dataset is not available for dm_control tasks.")
        dm_control_task = task_name[len("dm_control_"):]
        path = dm_control_policy_path(
            dm_control_task, noise_level, dataset_path)
        logging.info("Policy path: %s", path)
        policy_net = tf.saved_model.load(path)

        policy_noise_level = 0.2  # params["policy_noise_level"]
        observation_network = tf2_utils.to_sonnet_module(tf2_utils.batch_concat)
        policy_net = snt.Sequential([
            observation_network,
            policy_net,
        # Add Gaussian action noise to the target policy and clip to the action spec.
            acme_utils.GaussianNoise(policy_noise_level),
            networks.ClipToSpec(environment_spec.actions),
        ])
    else:
        raise ValueError(f"task name {task_name} is unsupported.")
    return policy_net
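
A minimal sketch of the epsilon-greedy head used in the bsuite branch above, with toy Q-values that are placeholders rather than values from the original code:

# trfl.epsilon_greedy returns a distribution whose sample() is mostly the
# greedy action and occasionally a uniformly random one.
import tensorflow as tf
import trfl

q_values = tf.constant([[1.0, 3.0, 2.0]])            # one state, three actions
noisy_policy = trfl.epsilon_greedy(q_values, epsilon=0.1)
action = noisy_policy.sample()                        # usually the argmax, action 1
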
Example #2
 def test_scalar_output(self):
   model = tf2_utils.to_sonnet_module(tf.reduce_sum)
   input_spec = specs.Array(shape=(10,), dtype=np.float32)
   expected_spec = tf.TensorSpec(shape=(), dtype=tf.float32)
   output_spec = tf2_utils.create_variables(model, [input_spec])
   self.assertEqual(model.variables, ())
   self.assertEqual(output_spec, expected_spec)
Example #3
 def test_none_output(self):
   model = tf2_utils.to_sonnet_module(lambda x: None)
   input_spec = specs.Array(shape=(10,), dtype=np.float32)
   expected_spec = None
   output_spec = tf2_utils.create_variables(model, [input_spec])
   self.assertEqual(model.variables, ())
   self.assertEqual(output_spec, expected_spec)
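
The two tests above exercise the core contract of tf2_utils.to_sonnet_module: an arbitrary TensorFlow callable becomes a variable-free Sonnet module whose output spec create_variables can infer. A minimal stand-alone sketch, with toy shapes chosen only for illustration:

import numpy as np
import tensorflow as tf
from acme import specs
from acme.tf import utils as tf2_utils

# Wrap a plain callable; the result behaves like any other snt.Module.
module = tf2_utils.to_sonnet_module(lambda x: tf.reduce_mean(x, axis=-1))
input_spec = specs.Array(shape=(4,), dtype=np.float32)
output_spec = tf2_utils.create_variables(module, [input_spec])
# output_spec is tf.TensorSpec(shape=(), dtype=tf.float32); the wrapper is
# stateless, so module.variables == ().
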
Example #4
def make_lstm_mpo_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                        hyperparams: Dict, checkpoint_path: str):
    params = DEFAULT_PARAMS.copy()
    params.update(hyperparams)
    action_size = np.prod(env_spec.actions.shape, dtype=int).item()
    policy_network = snt.Sequential([
        networks.LayerNormMLP(
            layer_sizes=[*params.pop('policy_layers'), action_size]),
        networks.MultivariateNormalDiagHead(num_dimensions=action_size)
    ])

    critic_network = snt.Sequential([
        networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
            layer_sizes=[*params.pop('critic_layers'), 1]))
    ])

    observation_network = networks.DeepRNN([
        networks.LayerNormMLP(layer_sizes=params.pop('observation_layers')),
        networks.LSTM(hidden_size=200)
    ])

    loss_param_keys = list(
        filter(lambda key: key.startswith('loss_'), params.keys()))
    loss_params = dict([(k.replace('loss_', ''), params.pop(k))
                        for k in loss_param_keys])
    policy_loss_module = losses.MPO(**loss_params)

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create optimizers.
    policy_optimizer = Adam(params.pop('policy_lr'))
    critic_optimizer = Adam(params.pop('critic_lr'))

    actor = RecurrentActor(
        networks.DeepRNN([
            observation_network, policy_network,
            networks.StochasticModeHead()
        ]))

    # The learner updates the parameters (and initializes them).
    return RecurrentMPO(environment_spec=env_spec,
                        policy_network=policy_network,
                        critic_network=critic_network,
                        observation_network=observation_network,
                        policy_loss_module=policy_loss_module,
                        policy_optimizer=policy_optimizer,
                        critic_optimizer=critic_optimizer,
                        logger=logger,
                        checkpoint_path=checkpoint_path,
                        **params), actor
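
Example #4 routes every hyperparameter whose key starts with 'loss_' into losses.MPO. A small self-contained sketch of that prefix-splitting convention, with hypothetical keys and values:

def split_prefixed(params: dict, prefix: str) -> dict:
    """Pops every `prefix`-keyed entry out of `params` and strips the prefix."""
    keys = [k for k in params if k.startswith(prefix)]
    return {k[len(prefix):]: params.pop(k) for k in keys}

params = {'loss_epsilon': 1e-1, 'loss_epsilon_mean': 1e-3, 'policy_lr': 1e-4}
loss_params = split_prefixed(params, 'loss_')
# loss_params == {'epsilon': 0.1, 'epsilon_mean': 0.001}
# params is left with only {'policy_lr': 0.0001}, ready to be passed as **params.
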
Example #5
 def __init__(
     self,
     policy_network: snt.Module,
     critic_network: snt.Module,
     observation_network: types.TensorTransformation,
 ):
   # This method is implemented (rather than added by the dataclass decorator)
   # in order to allow observation network to be passed as an arbitrary tensor
   # transformation rather than as a snt Module.
   # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation.
   self.policy_network = policy_network
   self.critic_network = critic_network
   self.observation_network = utils.to_sonnet_module(observation_network)
Example #6
def make_default_networks(
    environment_spec: specs.EnvironmentSpec,
    *,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    policy_init_scale: float = 0.7,
    critic_init_scale: float = 1e-3,
    critic_num_components: int = 5,
) -> Mapping[str, snt.Module]:
    """Creates networks used by the agent."""

    # Unpack the environment spec to get appropriate shapes, dtypes, etc.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    num_dimensions = np.prod(act_spec.shape, dtype=int)

    # Create the observation network and make sure it's a Sonnet module.
    observation_network = tf2_utils.batch_concat
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create the policy network.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(num_dimensions,
                                            init_scale=policy_init_scale,
                                            use_tfd_independent=True)
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            action_network=networks.ClipToSpec(act_spec)),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.GaussianMixtureHead(num_dimensions=1,
                                     num_components=critic_num_components,
                                     init_scale=critic_init_scale)
    ])

    # Create network variables.
    # Get embedding spec by creating observation network variables.
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }
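
A hedged usage sketch for the factory above: the environment object is an assumption (any dm_env-compatible environment), and acme_specs.make_environment_spec builds the EnvironmentSpec the factory expects. Because create_variables is called inside, the returned modules already own their variables.

from acme import specs as acme_specs

environment_spec = acme_specs.make_environment_spec(environment)
nets = make_default_networks(environment_spec)
policy_net = nets['policy']            # LayerNormMLP + MultivariateNormalDiagHead
critic_net = nets['critic']            # CriticMultiplexer + GaussianMixtureHead
observation_net = nets['observation']  # batch_concat wrapped as a Sonnet module
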
Example #7
 def __init__(self,
              env_spec: specs.EnvironmentSpec,
              encoder: types.TensorTransformation,
              name: str = 'MDP_normalization_layer'):
     super().__init__(name=name)
     obs_spec = env_spec.observations
     self._obs_mean = tf.Variable(tf.zeros(obs_spec.shape, obs_spec.dtype),
                                  name="obs_mean")
     self._obs_scale = tf.Variable(tf.ones(obs_spec.shape, obs_spec.dtype),
                                   name="obs_scale")
     self._ret_mean = tf.Variable(tf.zeros(1, obs_spec.dtype),
                                  name="ret_mean")
     self._ret_scale = tf.Variable(0.1 * tf.ones(1, obs_spec.dtype),
                                   name="ret_scale")
     self._encoder = tf2_utils.to_sonnet_module(encoder)
Example #8
  def test_multiple_inputs_and_outputs(self):
    def transformation(aa, bb, cc):
      return (tf.concat([aa, bb, cc], axis=-1),
              tf.concat([bb, cc], axis=-1))

    model = tf2_utils.to_sonnet_module(transformation)
    dtype = np.float32
    input_spec = [specs.Array(shape=(2,), dtype=dtype),
                  specs.Array(shape=(3,), dtype=dtype),
                  specs.Array(shape=(4,), dtype=dtype)]
    expected_output_spec = (tf.TensorSpec(shape=(9,), dtype=dtype),
                            tf.TensorSpec(shape=(7,), dtype=dtype))
    output_spec = tf2_utils.create_variables(model, input_spec)
    self.assertEqual(model.variables, ())
    self.assertEqual(output_spec, expected_output_spec)
Example #9
def make_d4pg_agent(env_spec: specs.EnvironmentSpec, logger: Logger, checkpoint_path: str, hyperparams: Dict):
    params = DEFAULT_PARAMS.copy()
    params.update(hyperparams)
    action_size = np.prod(env_spec.actions.shape, dtype=int).item()
    policy_network = snt.Sequential([
        networks.LayerNormMLP(layer_sizes=[*params.pop('policy_layers'), action_size]),
        networks.NearZeroInitializedLinear(output_size=action_size),
        networks.TanhToSpec(env_spec.actions),
    ])

    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=networks.LayerNormMLP(layer_sizes=[*params.pop('critic_layers'), 1])
        ),
        networks.DiscreteValuedHead(vmin=-100.0, vmax=100.0, num_atoms=params.pop('atoms'))
    ])

    observation_network = tf.identity

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    actor = FeedForwardActor(policy_network=snt.Sequential([
        observation_network,
        policy_network
    ]))


    # Create optimizers.
    policy_optimizer = Adam(params.pop('policy_lr'))
    critic_optimizer = Adam(params.pop('critic_lr'))

    # The learner updates the parameters (and initializes them).
    agent = D4PG(
        environment_spec=env_spec,
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        logger=logger,
        checkpoint_path=checkpoint_path,
        **params
    )
    agent.eval_actor = actor
    return agent
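
A hypothetical call to the factory above. The dictionary keys mirror the params.pop(...) calls in the function body; the concrete values, env_spec, logger and the checkpoint path are placeholders, not settings from the original project.

agent = make_d4pg_agent(
    env_spec=env_spec,
    logger=logger,
    checkpoint_path='/tmp/d4pg_checkpoints',
    hyperparams={
        'policy_layers': [256, 256],   # hidden sizes before the action layer
        'critic_layers': [512, 512],   # hidden sizes before the value head
        'atoms': 51,                   # atoms in the DiscreteValuedHead
        'policy_lr': 1e-4,
        'critic_lr': 1e-4,
    },
)
eval_actor = agent.eval_actor  # the evaluation actor attached at the end
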
Example #10
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        num_samples: int,
        target_policy_update_period: int,
        target_critic_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = tf.identity,
        target_observation_network: types.TensorTransformation = tf.identity,
        policy_loss_module: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        dual_optimizer: Optional[snt.Optimizer] = None,
        clipping: bool = True,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module's so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._num_samples = num_samples
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_policy_update_period = target_policy_update_period
        self._target_critic_update_period = target_critic_update_period

        # Batch dataset and create iterator.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self._policy_loss_module = policy_loss_module or losses.MPO(
            epsilon=1e-1,
            epsilon_penalty=1e-3,
            epsilon_mean=1e-3,
            epsilon_stddev=1e-6,
            init_log_temperature=1.,
            init_log_alpha_mean=1.,
            init_log_alpha_stddev=10.)

        # Create the optimizers.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create a checkpointer and snapshotter object.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='dmpo_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'dual_optimizer': self._dual_optimizer,
                    'policy_loss_module': self._policy_loss_module,
                    'num_steps': self._num_steps,
                })

            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'policy':
                    snt.Sequential([
                        self._target_observation_network,
                        self._target_policy_network
                    ]),
                })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
Example #11
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        target_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = lambda x: x,
        target_observation_network: types.TensorTransformation = lambda x: x,
        policy_optimizer: snt.Optimizer = None,
        critic_optimizer: snt.Optimizer = None,
        clipping: bool = True,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
    ):
        """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset: dataset to learn from, whether fixed or from a replay buffer
        (see `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the critic loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module's so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_update_period = target_update_period

        # Create an iterator to go through the dataset.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Create optimizers if they aren't given.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        self._checkpointer = tf2_savers.Checkpointer(
            time_delta_minutes=5,
            objects_to_save={
                'counter': self._counter,
                'policy': self._policy_network,
                'critic': self._critic_network,
                'target_policy': self._target_policy_network,
                'target_critic': self._target_critic_network,
                'policy_optimizer': self._policy_optimizer,
                'critic_optimizer': self._critic_optimizer,
                'num_steps': self._num_steps,
            },
            enable_checkpointing=checkpoint,
        )

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
Example #12
  def __init__(
      self,
      policy_network: snt.Module,
      critic_network: snt.Module,
      target_policy_network: snt.Module,
      target_critic_network: snt.Module,
      discount: float,
      target_update_period: int,
      dataset_iterator: Iterator[reverb.ReplaySample],
      observation_network: types.TensorTransformation = lambda x: x,
      target_observation_network: types.TensorTransformation = lambda x: x,
      policy_optimizer: snt.Optimizer = None,
      critic_optimizer: snt.Optimizer = None,
      clipping: bool = True,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      checkpoint: bool = True,
  ):
    """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset_iterator: dataset to learn from, whether fixed or from a replay
        buffer (see `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

    # Store online and target networks.
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_policy_network = target_policy_network
    self._target_critic_network = target_critic_network

    # Make sure observation networks are snt.Module's so they have variables.
    self._observation_network = tf2_utils.to_sonnet_module(observation_network)
    self._target_observation_network = tf2_utils.to_sonnet_module(
        target_observation_network)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')

    # Other learner parameters.
    self._discount = discount
    self._clipping = clipping

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Batch dataset and create iterator.
    self._iterator = dataset_iterator

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

    # Expose the variables.
    policy_network_to_expose = snt.Sequential(
        [self._target_observation_network, self._target_policy_network])
    self._variables = {
        'critic': self._target_critic_network.variables,
        'policy': policy_network_to_expose.variables,
    }

    # Create a checkpointer and snapshotter objects.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
      self._checkpointer = tf2_savers.Checkpointer(
          subdirectory='d4pg_learner',
          objects_to_save={
              'counter': self._counter,
              'policy': self._policy_network,
              'critic': self._critic_network,
              'observation': self._observation_network,
              'target_policy': self._target_policy_network,
              'target_critic': self._target_critic_network,
              'target_observation': self._target_observation_network,
              'policy_optimizer': self._policy_optimizer,
              'critic_optimizer': self._critic_optimizer,
              'num_steps': self._num_steps,
          })
      critic_mean = snt.Sequential(
          [self._critic_network, acme_nets.StochasticMeanHead()])
      self._snapshotter = tf2_savers.Snapshotter(
          objects_to_save={
              'policy': self._policy_network,
              'critic': critic_mean,
          })

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
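
Examples #10-#12 all keep a _num_steps counter and a target-update period. A minimal eager-mode sketch (the helper name and toy modules are assumptions, not acme code) of the periodic online-to-target copy that this bookkeeping drives:

import sonnet as snt
import tensorflow as tf

def maybe_update_target(online: snt.Module, target: snt.Module,
                        num_steps: tf.Variable, period: int) -> None:
    # Copy online variables into the target every `period` learner steps.
    if tf.math.mod(num_steps, period) == 0:
        for src, dest in zip(online.variables, target.variables):
            dest.assign(src)
    num_steps.assign_add(1)

online_net, target_net = snt.Linear(4), snt.Linear(4)
online_net(tf.zeros([1, 3]))   # call once so the variables exist
target_net(tf.zeros([1, 3]))
steps = tf.Variable(0, dtype=tf.int32)
maybe_update_target(online_net, target_net, steps, period=100)
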
Example #13
    client=reverb.Client(replay_server_address),
    n_step=5,
    discount=0.99)

# This connects to the created Reverb server. Because a transition adder is used
# above, the dataset will yield individual transitions rather than sequences.
dataset = datasets.make_reverb_dataset(
    table=replay_table_name,
    server_address=replay_server_address,
    batch_size=256,
    prefetch_size=True)

# Make sure observation network is a Sonnet Module.
observation_network = tf2_utils.batch_concat
observation_network = tf2_utils.to_sonnet_module(observation_network)

# Create the target networks
target_policy_network = copy.deepcopy(policy_network)
target_critic_network = copy.deepcopy(critic_network)
target_observation_network = copy.deepcopy(observation_network)

# Get observation and action specs.
act_spec = environment_spec.actions
obs_spec = environment_spec.observations
emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

# Create the behavior policy.
behavior_network = snt.Sequential([
    observation_network,
    policy_network,
Example #14
def make_default_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (256, 256, 256),
    critic_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (512, 512, 256),
    shared_weights: bool = True,
    sigma: float = 0.3,
    archecture_type: ArchitectureType = ArchitectureType.feedforward,
) -> Mapping[str, types.TensorTransformation]:
    """Default networks for maddpg.

    Args:
        environment_spec (mava_specs.MAEnvironmentSpec): description of the action and
            observation spaces etc. for each agent in the system.
        policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of policy networks. Defaults to (256, 256, 256).
        critic_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of critic networks. Defaults to (512, 512, 256).
        shared_weights (bool, optional): whether agents should share weights or not.
            Defaults to True.
        sigma (float, optional): hyperparameter used to add Gaussian noise for
            simple exploration. Defaults to 0.3.
        archecture_type (ArchitectureType, optional): architecture used for
            agent networks. Can be feedforward or recurrent. Defaults to
            ArchitectureType.feedforward.

    Returns:
        Mapping[str, types.TensorTransformation]: returned agent networks.
    """

    # Set Policy function and layer size
    if archecture_type == ArchitectureType.feedforward:
        policy_network_func = snt.Sequential
    elif archecture_type == ArchitectureType.recurrent:
        policy_networks_layer_sizes = (128, 128)
        policy_network_func = snt.DeepRNN

    specs = environment_spec.get_agent_specs()

    # Create agent_type specs
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes
            for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes
            for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():
        # TODO (dries): Make specs[key].actions
        #  return a list of specs for hybrid action space
        # Get total number of action dimensions from action spec.
        agent_act_spec = specs[key].actions
        if isinstance(specs[key].actions, DiscreteArray):
            num_actions = agent_act_spec.num_values
            minimum = [-1.0] * num_actions
            maximum = [1.0] * num_actions
            agent_act_spec = BoundedArray(
                shape=(num_actions, ),
                minimum=minimum,
                maximum=maximum,
                dtype="float32",
                name="actions",
            )

        # Get total number of action dimensions from action spec.
        num_dimensions = np.prod(agent_act_spec.shape, dtype=int)

        # An optional network to process observations
        observation_network = tf2_utils.to_sonnet_module(tf.identity)
        # Create the policy network.
        if archecture_type == ArchitectureType.feedforward:
            policy_network = [
                networks.LayerNormMLP(policy_networks_layer_sizes[key],
                                      activate_final=True),
            ]
        elif archecture_type == ArchitectureType.recurrent:
            policy_network = [
                networks.LayerNormMLP(policy_networks_layer_sizes[key][:-1],
                                      activate_final=True),
                snt.LSTM(policy_networks_layer_sizes[key][-1]),
            ]

        policy_network += [
            networks.NearZeroInitializedLinear(num_dimensions),
            networks.TanhToSpec(agent_act_spec),
        ]

        # Add Gaussian noise for simple exploration.
        if sigma and sigma > 0.0:
            policy_network += [
                networks.ClippedGaussian(sigma),
                networks.ClipToSpec(agent_act_spec),
            ]

        policy_network = policy_network_func(policy_network)

        # Create the critic network.
        critic_network = snt.Sequential([
            # The multiplexer concatenates the observations/actions.
            networks.CriticMultiplexer(),
            networks.LayerNormMLP(list(critic_networks_layer_sizes[key]) + [1],
                                  activate_final=False),
        ])
        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
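
A tiny sketch of the shared_weights grouping used above, with hypothetical agent ids: keys such as "agent_0" and "agent_1" collapse to their type prefix, so one network set is created per agent type and the last agent's spec of each type is the one kept.

specs = {'agent_0': 'spec_a0', 'agent_1': 'spec_a1', 'adversary_0': 'spec_adv0'}
type_specs = {key.split('_')[0]: specs[key] for key in specs}
# type_specs == {'agent': 'spec_a1', 'adversary': 'spec_adv0'}
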
Example #15
def make_acme_agent(environment_spec,
                    residual_spec,
                    obs_network_type,
                    crop_frames,
                    full_image_size,
                    crop_margin_size,
                    late_fusion,
                    binary_grip_action=False,
                    input_type=None,
                    counter=None,
                    logdir=None,
                    agent_logger=None):
    """Initialize acme agent based on residual spec and agent flags."""
    # TODO(minttu): Is environment_spec needed or could we use residual_spec?
    del logdir  # Setting logdir for the learner ckpts not currently supported.
    obs_network = None
    if obs_network_type is not None:
        obs_network = agents.ObservationNet(network_type=obs_network_type,
                                            input_type=input_type,
                                            add_linear_layer=False,
                                            crop_frames=crop_frames,
                                            full_image_size=full_image_size,
                                            crop_margin_size=crop_margin_size,
                                            late_fusion=late_fusion)

    eval_policy = None
    if FLAGS.agent == 'MPO':
        agent_networks = networks.make_mpo_networks(
            environment_spec.actions,
            policy_init_std=FLAGS.policy_init_std,
            obs_network=obs_network)

        rl_agent = mpo.MPO(
            environment_spec=residual_spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )
    elif FLAGS.agent == 'DMPO':
        agent_networks = networks.make_dmpo_networks(
            environment_spec.actions,
            policy_layer_sizes=FLAGS.rl_policy_layer_sizes,
            critic_layer_sizes=FLAGS.rl_critic_layer_sizes,
            vmin=FLAGS.critic_vmin,
            vmax=FLAGS.critic_vmax,
            num_atoms=FLAGS.critic_num_atoms,
            policy_init_std=FLAGS.policy_init_std,
            binary_grip_action=binary_grip_action,
            obs_network=obs_network)

        # spec = residual_spec if obs_network is None else environment_spec
        spec = residual_spec
        rl_agent = dmpo.DistributionalMPO(
            environment_spec=spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            counter=counter,
            # logdir=logdir,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )
        # Learned policy without exploration.
        eval_policy = (tf.function(
            snt.Sequential([
                tf_utils.to_sonnet_module(agent_networks['observation']),
                agent_networks['policy'],
                tf_networks.StochasticMeanHead()
            ])))
    elif FLAGS.agent == 'D4PG':
        agent_networks = networks.make_d4pg_networks(
            residual_spec.actions,
            vmin=FLAGS.critic_vmin,
            vmax=FLAGS.critic_vmax,
            num_atoms=FLAGS.critic_num_atoms,
            policy_weights_init_scale=FLAGS.policy_weights_init_scale,
            obs_network=obs_network)

        # TODO(minttu): downscale action space to [-1, 1] to match clipped gaussian.
        rl_agent = d4pg.D4PG(
            environment_spec=residual_spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            sigma=FLAGS.policy_init_std,
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )

        # Learned policy without exploration.
        eval_policy = tf.function(
            snt.Sequential([
                tf_utils.to_sonnet_module(agent_networks['observation']),
                agent_networks['policy']
            ]))

    else:
        raise NotImplementedError('Supported agents: MPO, DMPO, D4PG.')
    return rl_agent, eval_policy
Example #16
    def __init__(
        self,
        reward_objectives: Sequence[RewardObjective],
        qvalue_objectives: Sequence[QValueObjective],
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        num_samples: int,
        target_policy_update_period: int,
        target_critic_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = tf.identity,
        target_observation_network: types.TensorTransformation = tf.identity,
        policy_loss_module: Optional[losses.MultiObjectiveMPO] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        dual_optimizer: Optional[snt.Optimizer] = None,
        clipping: bool = True,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module's so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._num_samples = num_samples
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_policy_update_period = target_policy_update_period
        self._target_critic_update_period = target_critic_update_period

        # Batch dataset and create iterator.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Store objectives
        self._reward_objectives = reward_objectives
        self._qvalue_objectives = qvalue_objectives
        if self._qvalue_objectives is None:
            self._qvalue_objectives = []
        self._num_critic_heads = len(self._reward_objectives)  # C
        self._objective_names = ([x.name for x in self._reward_objectives] +
                                 [x.name for x in self._qvalue_objectives])

        self._policy_loss_module = policy_loss_module or losses.MultiObjectiveMPO(
            epsilons=[
                losses.KLConstraint(name, _DEFAULT_EPSILON)
                for name in self._objective_names
            ],
            epsilon_mean=_DEFAULT_EPSILON_MEAN,
            epsilon_stddev=_DEFAULT_EPSILON_STDDEV,
            init_log_temperature=_DEFAULT_INIT_LOG_TEMPERATURE,
            init_log_alpha_mean=_DEFAULT_INIT_LOG_ALPHA_MEAN,
            init_log_alpha_stddev=_DEFAULT_INIT_LOG_ALPHA_STDDEV)

        # Check that ordering of objectives matches the policy_loss_module's
        if self._objective_names != list(
                self._policy_loss_module.objective_names):
            raise ValueError("Agent's ordering of objectives doesn't match "
                             "the policy loss module's ordering of epsilons.")

        # Create the optimizers.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create a checkpointer and snapshotter object.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='dmpo_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'dual_optimizer': self._dual_optimizer,
                    'policy_loss_module': self._policy_loss_module,
                    'num_steps': self._num_steps,
                })

            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'policy':
                    snt.Sequential([
                        self._target_observation_network,
                        self._target_policy_network
                    ]),
                })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp: Optional[float] = None
Example #17
    def __init__(self,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 target_policy_network: snt.Module,
                 target_critic_network: snt.Module,
                 discount: float,
                 target_update_period: int,
                 dataset: tf.data.Dataset,
                 observation_network: types.TensorTransformation = lambda x: x,
                 target_observation_network: types.
                 TensorTransformation = lambda x: x,
                 policy_optimizer: snt.Optimizer = None,
                 critic_optimizer: snt.Optimizer = None,
                 clipping: bool = True,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint: bool = True,
                 specified_path: str = None):
        """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset: dataset to learn from, whether fixed or from a replay buffer
        (see `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module's so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_update_period = target_update_period

        # Batch dataset and create iterator.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Create optimizers if they aren't given.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create a checkpointer and snapshotter objects.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='d4pg_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'num_steps': self._num_steps,
                })
            self.specified_path = specified_path
            if self.specified_path is not None:
                self._checkpointer._checkpoint.restore(self.specified_path)
                print('\033[92mspecified_path: ', str(self.specified_path),
                      '\033[0m')
            critic_mean = snt.Sequential(
                [self._critic_network,
                 acme_nets.StochasticMeanHead()])
            self._snapshotter = tf2_savers.Snapshotter(objects_to_save={
                'policy': self._policy_network,
                'critic': critic_mean,
            })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
Example #18
def make_default_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (256, 256, 256),
    critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
    shared_weights: bool = True,
) -> Dict[str, snt.Module]:
    """Default networks for mappo.

    Args:
        environment_spec (mava_specs.MAEnvironmentSpec): description of the action and
            observation spaces etc. for each agent in the system.
        policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of policy networks. Defaults to (256, 256, 256).
        critic_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of critic networks. Defaults to (512, 512, 256).
        shared_weights (bool, optional): whether agents should share weights or not.
            Defaults to True.

    Raises:
        ValueError: Unknown action_spec type, if actions aren't DiscreteArray
            or BoundedArray.

    Returns:
        Dict[str, snt.Module]: returned agent networks.
    """

    # Create agent_type specs.
    specs = environment_spec.get_agent_specs()
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():

        # Create the shared observation network; here simply a state-less operation.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Note: The discrete case must be placed first as it inherits from BoundedArray.
        if isinstance(specs[key].actions, dm_env.specs.DiscreteArray):  # discrete
            num_actions = specs[key].actions.num_values
            policy_network = snt.Sequential(
                [
                    networks.LayerNormMLP(
                        tuple(policy_networks_layer_sizes[key]) + (num_actions,),
                        activate_final=False,
                    ),
                    tf.keras.layers.Lambda(
                        lambda logits: tfp.distributions.Categorical(logits=logits)
                    ),
                ]
            )
        elif isinstance(specs[key].actions, dm_env.specs.BoundedArray):  # continuous
            num_actions = np.prod(specs[key].actions.shape, dtype=int)
            policy_network = snt.Sequential(
                [
                    networks.LayerNormMLP(
                        policy_networks_layer_sizes[key], activate_final=True
                    ),
                    networks.MultivariateNormalDiagHead(num_dimensions=num_actions),
                    networks.TanhToSpec(specs[key].actions),
                ]
            )
        else:
            raise ValueError(f"Unknown action_spec type, got {specs[key].actions}.")

        critic_network = snt.Sequential(
            [
                networks.LayerNormMLP(
                    list(critic_networks_layer_sizes[key]) + [1], activate_final=False
                ),
            ]
        )

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network
    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
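
In the discrete branch above the policy network ends in a tfp.distributions.Categorical. A toy sketch of consuming that head (the logits below are placeholders for the network's output):

import tensorflow as tf
import tensorflow_probability as tfp

logits = tf.zeros([1, 4])                     # one observation, four actions
dist = tfp.distributions.Categorical(logits=logits)
action = dist.sample()                        # shape [1], integer action ids
log_prob = dist.log_prob(action)              # useful for policy-gradient updates
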
Example #19
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        policy_network: snt.Module,
        critic_network: snt.Module,
        observation_network: types.TensorTransformation = tf.identity,
        discount: float = 0.99,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_policy_update_period: int = 100,
        target_critic_update_period: int = 100,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        policy_loss_module: snt.Module = None,
        policy_optimizer: snt.Optimizer = None,
        critic_optimizer: snt.Optimizer = None,
        n_step: int = 5,
        num_samples: int = 20,
        clipping: bool = True,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint: bool = True,
        replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_policy_update_period: number of updates to perform before updating
        the target policy network.
      target_critic_update_period: number of updates to perform before updating
        the target critic network.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      policy_loss_module: configured MPO loss function for the policy
        optimization; defaults to sensible values on the control suite.
        See `acme/tf/losses/mpo.py` for more details.
      policy_optimizer: optimizer to be used on the policy.
      critic_optimizer: optimizer to be used on the critic.
      n_step: number of steps to squash into a single transition.
      num_samples: number of actions to sample when doing a Monte Carlo
        integration with respect to the policy.
      clipping: whether to clip gradients by global norm.
      logger: logging object used to write to logs.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """

        # Create a replay server to add data to.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            environment_spec=environment_spec,
            transition_adder=True)

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks before creating online/target network variables.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.StochasticSamplingHead(),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network=behavior_network,
                                        adder=adder)

        # Create optimizers.
        policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.MPOLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_loss_module=policy_loss_module,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            num_samples=num_samples,
            target_policy_update_period=target_policy_update_period,
            target_critic_update_period=target_critic_update_period,
            dataset=dataset,
            logger=logger,
            counter=counter,
            checkpoint=checkpoint)

        # Note: observations_per_step works out to batch_size / samples_per_insert,
        # i.e. (when >= 1) the number of actor observations collected per learner step.
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
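
A minimal usage sketch for the constructor documented above, assuming it belongs to an MPO-style agent class (MPOAgent and my_environment are hypothetical names not taken from the snippet); layer sizes are illustrative and the networks are built with acme's standard helpers.

import numpy as np
import sonnet as snt
from acme import specs
from acme.tf import networks
from acme.tf import utils as tf2_utils

environment_spec = specs.make_environment_spec(my_environment)
action_size = int(np.prod(environment_spec.actions.shape, dtype=int))

agent = MPOAgent(
    environment_spec=environment_spec,
    # Policy outputs a distribution so StochasticSamplingHead can sample from it.
    policy_network=snt.Sequential([
        networks.LayerNormMLP((256, 256)),
        networks.MultivariateNormalDiagHead(num_dimensions=action_size),
    ]),
    # Critic concatenates embedding and action and outputs a scalar Q-value.
    critic_network=networks.CriticMultiplexer(
        critic_network=networks.LayerNormMLP((512, 512, 1))),
    observation_network=tf2_utils.batch_concat,
)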
Example #20
    def learner(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        act_spec = self._environment_spec.actions
        obs_spec = self._environment_spec.observations

        # Create the networks to optimize (online) and target networks.
        online_networks = self._network_factory(act_spec)
        target_networks = self._network_factory(act_spec)

        # Make sure observation network is a Sonnet Module.
        observation_network = online_networks.get('observation', tf.identity)
        target_observation_network = target_networks.get(
            'observation', tf.identity)
        observation_network = tf2_utils.to_sonnet_module(observation_network)
        target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # Get embedding spec and create observation network variables.
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        # Create variables.
        tf2_utils.create_variables(online_networks['policy'], [emb_spec])
        tf2_utils.create_variables(online_networks['critic'],
                                   [emb_spec, act_spec])
        tf2_utils.create_variables(target_networks['policy'], [emb_spec])
        tf2_utils.create_variables(target_networks['critic'],
                                   [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address,
            batch_size=self._batch_size,
            prefetch_size=self._prefetch_size)

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
        critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger('learner',
                                             time_delta=self._log_every,
                                             steps_key='learner_steps')

        # Return the learning agent.
        return learning.DDPGLearner(
            policy_network=online_networks['policy'],
            critic_network=online_networks['critic'],
            observation_network=observation_network,
            target_policy_network=target_networks['policy'],
            target_critic_network=target_networks['critic'],
            target_observation_network=target_observation_network,
            discount=self._discount,
            target_update_period=self._target_update_period,
            dataset=dataset,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=self._clipping,
            counter=counter,
            logger=logger,
        )
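
For reference, a sketch of what self._network_factory is assumed to return for this learner: a dict keyed by 'policy', 'critic' and (optionally) 'observation', matching the lookups above; the layer sizes are placeholders.

import numpy as np
import sonnet as snt
from acme.tf import networks
from acme.tf import utils as tf2_utils

def network_factory(action_spec):
    num_dimensions = int(np.prod(action_spec.shape, dtype=int))
    return {
        # Plain callables are fine here; the learner wraps them with
        # tf2_utils.to_sonnet_module before use.
        'observation': tf2_utils.batch_concat,
        'policy': snt.Sequential([
            networks.LayerNormMLP((256, 256), activate_final=True),
            networks.NearZeroInitializedLinear(num_dimensions),
            networks.TanhToSpec(action_spec),
        ]),
        'critic': networks.CriticMultiplexer(
            critic_network=networks.LayerNormMLP((512, 512, 1))),
    }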
Example #21
    def learner(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        act_spec = self._environment_spec.actions
        obs_spec = self._environment_spec.observations

        # Create online and target networks.
        online_networks = self._network_factory(act_spec)
        target_networks = self._network_factory(act_spec)

        # Make sure observation networks are Sonnet Modules.
        observation_network = online_networks.get('observation', tf.identity)
        observation_network = tf2_utils.to_sonnet_module(observation_network)
        online_networks['observation'] = observation_network
        target_observation_network = target_networks.get(
            'observation', tf.identity)
        target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)
        target_networks['observation'] = target_observation_network

        # Get embedding spec and create observation network variables.
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        tf2_utils.create_variables(online_networks['policy'], [emb_spec])
        tf2_utils.create_variables(online_networks['critic'],
                                   [emb_spec, act_spec])
        tf2_utils.create_variables(target_networks['observation'], [obs_spec])
        tf2_utils.create_variables(target_networks['policy'], [emb_spec])
        tf2_utils.create_variables(target_networks['critic'],
                                   [emb_spec, act_spec])

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address)
        dataset = dataset.batch(self._batch_size, drop_remainder=True)
        dataset = dataset.prefetch(self._prefetch_size)

        # Create a counter and logger for bookkeeping steps and performance.
        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger('learner',
                                             time_delta=self._log_every,
                                             steps_key='learner_steps')

        # Create policy loss module if a factory is passed.
        if self._policy_loss_factory:
            policy_loss_module = self._policy_loss_factory()
        else:
            policy_loss_module = None

        # Return the learning agent.
        return learning.MPOLearner(
            policy_network=online_networks['policy'],
            critic_network=online_networks['critic'],
            observation_network=observation_network,
            target_policy_network=target_networks['policy'],
            target_critic_network=target_networks['critic'],
            target_observation_network=target_observation_network,
            discount=self._additional_discount,
            num_samples=self._num_samples,
            target_policy_update_period=self._target_policy_update_period,
            target_critic_update_period=self._target_critic_update_period,
            policy_loss_module=policy_loss_module,
            dataset=dataset,
            counter=counter,
            logger=logger)
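
The policy_loss_factory used above is only assumed to exist; one plausible factory returns the MPO loss referenced in the docstrings of these agents (acme/tf/losses/mpo.py). Argument names and values below are illustrative and should be checked against that file.

from acme.tf import losses

def policy_loss_factory():
    # Constraint thresholds and initial dual values are illustrative defaults.
    return losses.MPO(
        epsilon=1e-1,           # KL budget for the nonparametric policy
        epsilon_mean=1e-3,      # KL budget for the Gaussian mean
        epsilon_stddev=1e-6,    # KL budget for the Gaussian stddev
        init_log_temperature=1.,
        init_log_alpha_mean=1.,
        init_log_alpha_stddev=10.)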
Example #22
def make_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (
        256,
        256,
        256,
    ),
    critic_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (512, 512, 256),
    shared_weights: bool = True,
) -> Dict[str, snt.Module]:
    """Creates networks used by the agents."""

    # Create agent_type specs.
    specs = environment_spec.get_agent_specs()
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes
            for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes
            for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():

        # Create the shared observation network; here simply a state-less operation.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Note: The discrete case must be placed first as it inherits from BoundedArray.
        if isinstance(specs[key].actions,
                      dm_env.specs.DiscreteArray):  # discrete
            num_actions = specs[key].actions.num_values
            policy_network = snt.Sequential([
                networks.LayerNormMLP(
                    tuple(policy_networks_layer_sizes[key]) + (num_actions, ),
                    activate_final=False,
                ),
                tf.keras.layers.Lambda(lambda logits: tfp.distributions.
                                       Categorical(logits=logits)),
            ])
        elif isinstance(specs[key].actions,
                        dm_env.specs.BoundedArray):  # continuous
            num_actions = np.prod(specs[key].actions.shape, dtype=int)
            policy_network = snt.Sequential([
                networks.LayerNormMLP(policy_networks_layer_sizes[key],
                                      activate_final=True),
                networks.MultivariateNormalDiagHead(
                    num_dimensions=num_actions),
                networks.TanhToSpec(specs[key].actions),
            ])
        else:
            raise ValueError(
                f"Unknown action_spec type, got {specs[key].actions}.")

        critic_network = snt.Sequential([
            networks.LayerNormMLP(critic_networks_layer_sizes[key],
                                  activate_final=True),
            networks.NearZeroInitializedLinear(1),
        ])

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
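
A usage sketch for make_networks: layer sizes may be passed either as a single sequence shared by every agent type or as a per-type dict; ma_environment_spec and the "agent" key are placeholders that would have to match environment_spec.get_agent_specs().

nets = make_networks(
    environment_spec=ma_environment_spec,
    policy_networks_layer_sizes={"agent": (128, 128)},  # per-type override
    critic_networks_layer_sizes=(512, 512, 256),        # shared across types
    shared_weights=True,
)
policy_net = nets["policies"]["agent"]
critic_net = nets["critics"]["agent"]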
Example #23
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 observation_network: types.TensorTransformation = tf.identity,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 prefetch_size: int = 4,
                 target_update_period: int = 100,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0,
                 n_step: int = 5,
                 sigma: float = 0.3,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])  # pytype: disable=wrong-arg-types

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.ClippedGaussian(sigma),
            networks.ClipToSpec(act_spec),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(behavior_network, adder=adder)

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
        critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.DDPGLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            target_update_period=target_update_period,
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=checkpoint,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
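
A minimal sketch of networks that fit the constructor above, assuming a continuous (BoundedArray) action spec. Because the behavior policy already appends ClippedGaussian(sigma) and ClipToSpec, the policy network passed in should be deterministic and noise-free. DDPGAgent and my_environment are hypothetical names.

import numpy as np
import sonnet as snt
from acme import environment_loop, specs
from acme.tf import networks

environment_spec = specs.make_environment_spec(my_environment)
num_dimensions = int(np.prod(environment_spec.actions.shape, dtype=int))

agent = DDPGAgent(
    environment_spec=environment_spec,
    policy_network=snt.Sequential([
        networks.LayerNormMLP((256, 256), activate_final=True),
        networks.NearZeroInitializedLinear(num_dimensions),
        networks.TanhToSpec(environment_spec.actions),
    ]),
    critic_network=networks.CriticMultiplexer(
        critic_network=networks.LayerNormMLP((512, 512, 1))),
    sigma=0.3,
)

# The Agent interface lets an EnvironmentLoop drive acting and learning together.
environment_loop.EnvironmentLoop(my_environment, agent).run(num_episodes=100)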
Example #24
def make_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (
        256,
        256,
        256,
    ),
    critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
    shared_weights: bool = True,
    sigma: float = 0.3,
) -> Mapping[str, types.TensorTransformation]:
    """Creates networks used by the agents."""
    specs = environment_spec.get_agent_specs()

    # Create agent_type specs
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():

        # Get total number of action dimensions from action spec.
        num_dimensions = np.prod(specs[key].actions.shape, dtype=int)

        # Create the shared observation network; here simply a state-less operation.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Create the policy network.
        policy_network = snt.Sequential(
            [
                networks.LayerNormMLP(
                    policy_networks_layer_sizes[key], activate_final=True
                ),
                networks.NearZeroInitializedLinear(num_dimensions),
                networks.TanhToSpec(specs[key].actions),
                networks.ClippedGaussian(sigma),
                networks.ClipToSpec(specs[key].actions),
            ]
        )

        # Create the critic network.
        critic_network = snt.Sequential(
            [
                # The multiplexer concatenates the observations/actions.
                networks.CriticMultiplexer(),
                networks.LayerNormMLP(
                    critic_networks_layer_sizes[key], activate_final=False
                ),
                snt.Linear(1),
            ]
        )
        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
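
Design note: unlike the single-agent DDPG example above, exploration noise (ClippedGaussian) is baked into each policy network here rather than added in a behavior wrapper. A noise-free head for evaluation could be assembled from the same pieces, as in this hypothetical sketch (action_spec and the layer sizes are placeholders).

eval_policy_network = snt.Sequential([
    networks.LayerNormMLP((256, 256, 256), activate_final=True),
    networks.NearZeroInitializedLinear(int(np.prod(action_spec.shape, dtype=int))),
    networks.TanhToSpec(action_spec),
])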