def load_policy_net(
    task_name: str,
    noise_level: float,
    dataset_path: str,
    environment_spec: specs.EnvironmentSpec,
    near_policy_dataset: bool = False,
):
  dataset_path = Path(dataset_path)
  if task_name.startswith("bsuite"):
    # BSuite tasks.
    bsuite_id = task_name[len("bsuite_"):] + "/0"
    path = bsuite_policy_path(
        bsuite_id, noise_level, near_policy_dataset, dataset_path)
    logging.info("Policy path: %s", path)
    policy_net = tf.saved_model.load(path)
    policy_noise_level = 0.1  # params["policy_noise_level"]
    observation_network = tf2_utils.to_sonnet_module(functools.partial(
        tf.reshape, shape=(-1,) + environment_spec.observations.shape))
    policy_net = snt.Sequential([
        observation_network,
        policy_net,
        # Adds epsilon-greedy action noise to the target policy; remove this
        # line for a noiseless policy.
        lambda q: trfl.epsilon_greedy(q, epsilon=policy_noise_level).sample(),
    ])
  elif task_name.startswith("dm_control"):
    # DM Control tasks.
    if near_policy_dataset:
      raise ValueError(
          "Near-policy dataset is not available for dm_control tasks.")
    dm_control_task = task_name[len("dm_control_"):]
    path = dm_control_policy_path(dm_control_task, noise_level, dataset_path)
    logging.info("Policy path: %s", path)
    policy_net = tf.saved_model.load(path)
    policy_noise_level = 0.2  # params["policy_noise_level"]
    observation_network = tf2_utils.to_sonnet_module(tf2_utils.batch_concat)
    policy_net = snt.Sequential([
        observation_network,
        policy_net,
        # These two lines add clipped Gaussian action noise to the target
        # policy; remove them for a noiseless policy.
        acme_utils.GaussianNoise(policy_noise_level),
        networks.ClipToSpec(environment_spec.actions),
    ])
  else:
    raise ValueError(f"Task name {task_name} is unsupported.")
  return policy_net
def test_scalar_output(self):
  model = tf2_utils.to_sonnet_module(tf.reduce_sum)
  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  expected_spec = tf.TensorSpec(shape=(), dtype=tf.float32)
  output_spec = tf2_utils.create_variables(model, [input_spec])
  self.assertEqual(model.variables, ())
  self.assertEqual(output_spec, expected_spec)
def test_none_output(self):
  model = tf2_utils.to_sonnet_module(lambda x: None)
  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  expected_spec = None
  output_spec = tf2_utils.create_variables(model, [input_spec])
  self.assertEqual(model.variables, ())
  self.assertEqual(output_spec, expected_spec)
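The two tests above exercise the core contract of `to_sonnet_module`: the wrapped callable becomes a parameterless `snt.Module`, and `tf2_utils.create_variables` traces it once to recover the output spec. A minimal standalone sketch of that flow, assuming the usual Acme import path (`acme.tf.utils` imported as `tf2_utils`):

import numpy as np
import tensorflow as tf
from acme import specs
from acme.tf import utils as tf2_utils

# Wrap a plain TF function; the result is a snt.Module with no variables.
model = tf2_utils.to_sonnet_module(lambda x: tf.reduce_mean(x, axis=-1))
input_spec = specs.Array(shape=(10,), dtype=np.float32)

# create_variables calls the module on a dummy batched input and returns the
# per-element output spec: a scalar here, since the last axis is reduced.
output_spec = tf2_utils.create_variables(model, [input_spec])
assert model.variables == ()
assert output_spec == tf.TensorSpec(shape=(), dtype=tf.float32)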
def make_lstm_mpo_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                        hyperparams: Dict, checkpoint_path: str):
  params = DEFAULT_PARAMS.copy()
  params.update(hyperparams)
  action_size = np.prod(env_spec.actions.shape, dtype=int).item()
  policy_network = snt.Sequential([
      networks.LayerNormMLP(
          layer_sizes=[*params.pop('policy_layers'), action_size]),
      networks.MultivariateNormalDiagHead(num_dimensions=action_size)
  ])
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
          layer_sizes=[*params.pop('critic_layers'), 1]))
  ])
  observation_network = networks.DeepRNN([
      networks.LayerNormMLP(layer_sizes=params.pop('observation_layers')),
      networks.LSTM(hidden_size=200)
  ])

  # Collect the loss hyperparameters (keys prefixed with 'loss_') and build
  # the MPO loss from them.
  loss_param_keys = list(
      filter(lambda key: key.startswith('loss_'), params.keys()))
  loss_params = dict(
      [(k.replace('loss_', ''), params.pop(k)) for k in loss_param_keys])
  policy_loss_module = losses.MPO(**loss_params)

  # Make sure observation network is a Sonnet Module.
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  # Create optimizers.
  policy_optimizer = Adam(params.pop('policy_lr'))
  critic_optimizer = Adam(params.pop('critic_lr'))

  actor = RecurrentActor(
      networks.DeepRNN([
          observation_network,
          policy_network,
          networks.StochasticModeHead()
      ]))

  # The learner updates the parameters (and initializes them).
  return RecurrentMPO(
      environment_spec=env_spec,
      policy_network=policy_network,
      critic_network=critic_network,
      observation_network=observation_network,
      policy_loss_module=policy_loss_module,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      logger=logger,
      checkpoint_path=checkpoint_path,
      **params), actor
def __init__(
    self,
    policy_network: snt.Module,
    critic_network: snt.Module,
    observation_network: types.TensorTransformation,
):
  # This method is implemented (rather than added by the dataclass decorator)
  # in order to allow observation network to be passed as an arbitrary tensor
  # transformation rather than as a snt Module.
  # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation.
  self.policy_network = policy_network
  self.critic_network = critic_network
  self.observation_network = utils.to_sonnet_module(observation_network)
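A hedged usage sketch of why this `__init__` is written by hand: callers may pass a bare tensor transformation for the observation network and still end up with a `snt.Module` exposing `.variables`. The container name `ActorCriticNetworks` below is hypothetical, standing in for whatever dataclass this `__init__` belongs to:

import sonnet as snt
import tensorflow as tf

# `ActorCriticNetworks` is a hypothetical stand-in for the dataclass above.
nets = ActorCriticNetworks(
    policy_network=snt.nets.MLP([256, 256, 4]),
    critic_network=snt.nets.MLP([256, 256, 1]),
    # A plain callable is fine: __init__ wraps it via utils.to_sonnet_module.
    observation_network=lambda obs: tf.reshape(obs, [-1, 8]),
)
assert isinstance(nets.observation_network, snt.Module)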
def make_default_networks(
    environment_spec: specs.EnvironmentSpec,
    *,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    policy_init_scale: float = 0.7,
    critic_init_scale: float = 1e-3,
    critic_num_components: int = 5,
) -> Mapping[str, snt.Module]:
  """Creates networks used by the agent."""
  # Unpack the environment spec to get appropriate shapes, dtypes, etc.
  act_spec = environment_spec.actions
  obs_spec = environment_spec.observations
  num_dimensions = np.prod(act_spec.shape, dtype=int)

  # Create the observation network and make sure it's a Sonnet module.
  observation_network = tf2_utils.batch_concat
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  # Create the policy network.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          num_dimensions,
          init_scale=policy_init_scale,
          use_tfd_independent=True)
  ])

  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          action_network=networks.ClipToSpec(act_spec)),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.GaussianMixtureHead(
          num_dimensions=1,
          num_components=critic_num_components,
          init_scale=critic_init_scale)
  ])

  # Create network variables.
  # Get embedding spec by creating observation network variables.
  emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
  tf2_utils.create_variables(policy_network, [emb_spec])
  tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': observation_network,
  }
def __init__(self,
             env_spec: specs.EnvironmentSpec,
             encoder: types.TensorTransformation,
             name: str = 'MDP_normalization_layer'):
  super().__init__(name=name)
  obs_spec = env_spec.observations
  self._obs_mean = tf.Variable(
      tf.zeros(obs_spec.shape, obs_spec.dtype), name="obs_mean")
  self._obs_scale = tf.Variable(
      tf.ones(obs_spec.shape, obs_spec.dtype), name="obs_scale")
  self._ret_mean = tf.Variable(
      tf.zeros(1, obs_spec.dtype), name="ret_mean")
  self._ret_scale = tf.Variable(
      0.1 * tf.ones(1, obs_spec.dtype), name="ret_scale")
  self._encoder = tf2_utils.to_sonnet_module(encoder)
def test_multiple_inputs_and_outputs(self):
  def transformation(aa, bb, cc):
    return (tf.concat([aa, bb, cc], axis=-1),
            tf.concat([bb, cc], axis=-1))

  model = tf2_utils.to_sonnet_module(transformation)
  dtype = np.float32
  input_spec = [
      specs.Array(shape=(2,), dtype=dtype),
      specs.Array(shape=(3,), dtype=dtype),
      specs.Array(shape=(4,), dtype=dtype)
  ]
  expected_output_spec = (tf.TensorSpec(shape=(9,), dtype=dtype),
                          tf.TensorSpec(shape=(7,), dtype=dtype))
  output_spec = tf2_utils.create_variables(model, input_spec)
  self.assertEqual(model.variables, ())
  self.assertEqual(output_spec, expected_output_spec)
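Because the wrapper's `__call__` forwards positional arguments to the underlying callable, the module from this test can also be invoked directly on batched tensors. A small sketch under the same import assumptions as earlier:

import tensorflow as tf
from acme.tf import utils as tf2_utils

concat = tf2_utils.to_sonnet_module(lambda a, b: tf.concat([a, b], axis=-1))
# Batched inputs of shapes (4, 2) and (4, 3) concatenate to shape (4, 5).
out = concat(tf.zeros([4, 2]), tf.ones([4, 3]))
assert out.shape == (4, 5)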
def make_d4pg_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                    checkpoint_path: str, hyperparams: Dict):
  params = DEFAULT_PARAMS.copy()
  params.update(hyperparams)
  action_size = np.prod(env_spec.actions.shape, dtype=int).item()
  policy_network = snt.Sequential([
      networks.LayerNormMLP(
          layer_sizes=[*params.pop('policy_layers'), action_size]),
      networks.NearZeroInitializedLinear(output_size=action_size),
      networks.TanhToSpec(env_spec.actions),
  ])
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
          layer_sizes=[*params.pop('critic_layers'), 1])),
      networks.DiscreteValuedHead(
          vmin=-100.0, vmax=100.0, num_atoms=params.pop('atoms'))
  ])
  observation_network = tf.identity
  # Make sure observation network is a Sonnet Module.
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  actor = FeedForwardActor(
      policy_network=snt.Sequential([observation_network, policy_network]))

  # Create optimizers.
  policy_optimizer = Adam(params.pop('policy_lr'))
  critic_optimizer = Adam(params.pop('critic_lr'))

  # The learner updates the parameters (and initializes them).
  agent = D4PG(
      environment_spec=env_spec,
      policy_network=policy_network,
      critic_network=critic_network,
      observation_network=observation_network,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      logger=logger,
      checkpoint_path=checkpoint_path,
      **params)
  agent.eval_actor = actor
  return agent
def __init__(
    self,
    policy_network: snt.Module,
    critic_network: snt.Module,
    target_policy_network: snt.Module,
    target_critic_network: snt.Module,
    discount: float,
    num_samples: int,
    target_policy_update_period: int,
    target_critic_update_period: int,
    dataset: tf.data.Dataset,
    observation_network: types.TensorTransformation = tf.identity,
    target_observation_network: types.TensorTransformation = tf.identity,
    policy_loss_module: Optional[snt.Module] = None,
    policy_optimizer: Optional[snt.Optimizer] = None,
    critic_optimizer: Optional[snt.Optimizer] = None,
    dual_optimizer: Optional[snt.Optimizer] = None,
    clipping: bool = True,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    checkpoint: bool = True,
):
  # Store online and target networks.
  self._policy_network = policy_network
  self._critic_network = critic_network
  self._target_policy_network = target_policy_network
  self._target_critic_network = target_critic_network

  # Make sure observation networks are snt.Module's so they have variables.
  self._observation_network = tf2_utils.to_sonnet_module(observation_network)
  self._target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')

  # Other learner parameters.
  self._discount = discount
  self._num_samples = num_samples
  self._clipping = clipping

  # Necessary to track when to update target networks.
  self._num_steps = tf.Variable(0, dtype=tf.int32)
  self._target_policy_update_period = target_policy_update_period
  self._target_critic_update_period = target_critic_update_period

  # Create an iterator over the dataset.
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  self._policy_loss_module = policy_loss_module or losses.MPO(
      epsilon=1e-1,
      epsilon_penalty=1e-3,
      epsilon_mean=1e-3,
      epsilon_stddev=1e-6,
      init_log_temperature=1.,
      init_log_alpha_mean=1.,
      init_log_alpha_stddev=10.)

  # Create the optimizers.
  self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
  self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
  self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

  # Expose the variables.
  policy_network_to_expose = snt.Sequential(
      [self._target_observation_network, self._target_policy_network])
  self._variables = {
      'critic': self._target_critic_network.variables,
      'policy': policy_network_to_expose.variables,
  }

  # Create checkpointer and snapshotter objects.
  self._checkpointer = None
  self._snapshotter = None
  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='dmpo_learner',
        objects_to_save={
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'observation': self._observation_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'target_observation': self._target_observation_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'dual_optimizer': self._dual_optimizer,
            'policy_loss_module': self._policy_loss_module,
            'num_steps': self._num_steps,
        })
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={
            'policy': snt.Sequential([
                self._target_observation_network,
                self._target_policy_network
            ]),
        })

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
def __init__(
    self,
    policy_network: snt.Module,
    critic_network: snt.Module,
    target_policy_network: snt.Module,
    target_critic_network: snt.Module,
    discount: float,
    target_update_period: int,
    dataset: tf.data.Dataset,
    observation_network: types.TensorTransformation = lambda x: x,
    target_observation_network: types.TensorTransformation = lambda x: x,
    policy_optimizer: snt.Optimizer = None,
    critic_optimizer: snt.Optimizer = None,
    clipping: bool = True,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    checkpoint: bool = True,
):
  """Initializes the learner.

  Args:
    policy_network: the online (optimized) policy.
    critic_network: the online critic.
    target_policy_network: the target policy (which lags behind the online
      policy).
    target_critic_network: the target critic.
    discount: discount to use for TD updates.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    dataset: dataset to learn from, whether fixed or from a replay buffer
      (see `acme.datasets.reverb.make_dataset` documentation).
    observation_network: an optional online network to process observations
      before the policy and the critic.
    target_observation_network: the target observation network.
    policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
    critic_optimizer: the optimizer to be applied to the critic loss.
    clipping: whether to clip gradients by global norm.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
    checkpoint: boolean indicating whether to checkpoint the learner.
  """

  # Store online and target networks.
  self._policy_network = policy_network
  self._critic_network = critic_network
  self._target_policy_network = target_policy_network
  self._target_critic_network = target_critic_network

  # Make sure observation networks are snt.Module's so they have variables.
  self._observation_network = tf2_utils.to_sonnet_module(observation_network)
  self._target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')

  # Other learner parameters.
  self._discount = discount
  self._clipping = clipping

  # Necessary to track when to update target networks.
  self._num_steps = tf.Variable(0, dtype=tf.int32)
  self._target_update_period = target_update_period

  # Create an iterator to go through the dataset.
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  # Create optimizers if they aren't given.
  self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
  self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

  # Expose the variables.
  policy_network_to_expose = snt.Sequential(
      [self._target_observation_network, self._target_policy_network])
  self._variables = {
      'critic': target_critic_network.variables,
      'policy': policy_network_to_expose.variables,
  }

  self._checkpointer = tf2_savers.Checkpointer(
      time_delta_minutes=5,
      objects_to_save={
          'counter': self._counter,
          'policy': self._policy_network,
          'critic': self._critic_network,
          'target_policy': self._target_policy_network,
          'target_critic': self._target_critic_network,
          'policy_optimizer': self._policy_optimizer,
          'critic_optimizer': self._critic_optimizer,
          'num_steps': self._num_steps,
      },
      enable_checkpointing=checkpoint,
  )

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
def __init__(
    self,
    policy_network: snt.Module,
    critic_network: snt.Module,
    target_policy_network: snt.Module,
    target_critic_network: snt.Module,
    discount: float,
    target_update_period: int,
    dataset_iterator: Iterator[reverb.ReplaySample],
    observation_network: types.TensorTransformation = lambda x: x,
    target_observation_network: types.TensorTransformation = lambda x: x,
    policy_optimizer: snt.Optimizer = None,
    critic_optimizer: snt.Optimizer = None,
    clipping: bool = True,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    checkpoint: bool = True,
):
  """Initializes the learner.

  Args:
    policy_network: the online (optimized) policy.
    critic_network: the online critic.
    target_policy_network: the target policy (which lags behind the online
      policy).
    target_critic_network: the target critic.
    discount: discount to use for TD updates.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    dataset_iterator: dataset to learn from, whether fixed or from a replay
      buffer (see `acme.datasets.reverb.make_dataset` documentation).
    observation_network: an optional online network to process observations
      before the policy and the critic.
    target_observation_network: the target observation network.
    policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
    critic_optimizer: the optimizer to be applied to the distributional
      Bellman loss.
    clipping: whether to clip gradients by global norm.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
    checkpoint: boolean indicating whether to checkpoint the learner.
  """

  # Store online and target networks.
  self._policy_network = policy_network
  self._critic_network = critic_network
  self._target_policy_network = target_policy_network
  self._target_critic_network = target_critic_network

  # Make sure observation networks are snt.Module's so they have variables.
  self._observation_network = tf2_utils.to_sonnet_module(observation_network)
  self._target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')

  # Other learner parameters.
  self._discount = discount
  self._clipping = clipping

  # Necessary to track when to update target networks.
  self._num_steps = tf.Variable(0, dtype=tf.int32)
  self._target_update_period = target_update_period

  # Store the iterator used to sample batches from the dataset.
  self._iterator = dataset_iterator

  # Create optimizers if they aren't given.
  self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
  self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

  # Expose the variables.
  policy_network_to_expose = snt.Sequential(
      [self._target_observation_network, self._target_policy_network])
  self._variables = {
      'critic': self._target_critic_network.variables,
      'policy': policy_network_to_expose.variables,
  }

  # Create checkpointer and snapshotter objects.
  self._checkpointer = None
  self._snapshotter = None
  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='d4pg_learner',
        objects_to_save={
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'observation': self._observation_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'target_observation': self._target_observation_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'num_steps': self._num_steps,
        })
    critic_mean = snt.Sequential(
        [self._critic_network, acme_nets.StochasticMeanHead()])
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={
            'policy': self._policy_network,
            'critic': critic_mean,
        })

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
    client=reverb.Client(replay_server_address),
    n_step=5,
    discount=0.99)

# This connects to the created reverb server; since we used a transition adder
# above, we also tell the dataset function so it knows the type of data that
# is coming out.
dataset = datasets.make_reverb_dataset(
    table=replay_table_name,
    server_address=replay_server_address,
    batch_size=256,
    prefetch_size=True)

# Make sure observation network is a Sonnet Module.
observation_network = tf2_utils.batch_concat
observation_network = tf2_utils.to_sonnet_module(observation_network)

# Create the target networks.
target_policy_network = copy.deepcopy(policy_network)
target_critic_network = copy.deepcopy(critic_network)
target_observation_network = copy.deepcopy(observation_network)

# Get observation and action specs.
act_spec = environment_spec.actions
obs_spec = environment_spec.observations
emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

# Create the behavior policy.
behavior_network = snt.Sequential([
    observation_network,
    policy_network,
def make_default_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (
        256,
        256,
        256,
    ),
    critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
    shared_weights: bool = True,
    sigma: float = 0.3,
    archecture_type: ArchitectureType = ArchitectureType.feedforward,
) -> Mapping[str, types.TensorTransformation]:
    """Default networks for maddpg.

    Args:
        environment_spec (mava_specs.MAEnvironmentSpec): description of the action
            and observation spaces etc. for each agent in the system.
        policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of policy networks. Defaults to (256, 256, 256).
        critic_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of critic networks. Defaults to (512, 512, 256).
        shared_weights (bool, optional): whether agents should share weights or not.
            Defaults to True.
        sigma (float, optional): hyperparameter used to add Gaussian noise for
            simple exploration. Defaults to 0.3.
        archecture_type (ArchitectureType, optional): architecture used for agent
            networks. Can be feedforward or recurrent.
            Defaults to ArchitectureType.feedforward.

    Returns:
        Mapping[str, types.TensorTransformation]: returned agent networks.
    """

    # Set the policy constructor and layer sizes.
    if archecture_type == ArchitectureType.feedforward:
        policy_network_func = snt.Sequential
    elif archecture_type == ArchitectureType.recurrent:
        policy_networks_layer_sizes = (128, 128)
        policy_network_func = snt.DeepRNN

    specs = environment_spec.get_agent_specs()

    # Create agent_type specs.
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():

        # TODO (dries): Make specs[key].actions
        #  return a list of specs for hybrid action space.

        # Get the agent's action spec, converting discrete actions into a
        # bounded continuous box.
        agent_act_spec = specs[key].actions
        if type(specs[key].actions) == DiscreteArray:
            num_actions = agent_act_spec.num_values
            minimum = [-1.0] * num_actions
            maximum = [1.0] * num_actions
            agent_act_spec = BoundedArray(
                shape=(num_actions,),
                minimum=minimum,
                maximum=maximum,
                dtype="float32",
                name="actions",
            )

        # Get total number of action dimensions from action spec.
        num_dimensions = np.prod(agent_act_spec.shape, dtype=int)

        # An optional network to process observations.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Create the policy network.
        if archecture_type == ArchitectureType.feedforward:
            policy_network = [
                networks.LayerNormMLP(
                    policy_networks_layer_sizes[key], activate_final=True
                ),
            ]
        elif archecture_type == ArchitectureType.recurrent:
            policy_network = [
                networks.LayerNormMLP(
                    policy_networks_layer_sizes[key][:-1], activate_final=True
                ),
                snt.LSTM(policy_networks_layer_sizes[key][-1]),
            ]

        policy_network += [
            networks.NearZeroInitializedLinear(num_dimensions),
            networks.TanhToSpec(agent_act_spec),
        ]

        # Add Gaussian noise for simple exploration.
        if sigma and sigma > 0.0:
            policy_network += [
                networks.ClippedGaussian(sigma),
                networks.ClipToSpec(agent_act_spec),
            ]

        policy_network = policy_network_func(policy_network)

        # Create the critic network.
        critic_network = snt.Sequential(
            [
                # The multiplexer concatenates the observations/actions.
                networks.CriticMultiplexer(),
                networks.LayerNormMLP(
                    list(critic_networks_layer_sizes[key]) + [1],
                    activate_final=False,
                ),
            ]
        )

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
def make_acme_agent(environment_spec,
                    residual_spec,
                    obs_network_type,
                    crop_frames,
                    full_image_size,
                    crop_margin_size,
                    late_fusion,
                    binary_grip_action=False,
                    input_type=None,
                    counter=None,
                    logdir=None,
                    agent_logger=None):
  """Initialize acme agent based on residual spec and agent flags."""
  # TODO(minttu): Is environment_spec needed or could we use residual_spec?
  del logdir  # Setting logdir for the learner ckpts not currently supported.
  obs_network = None
  if obs_network_type is not None:
    obs_network = agents.ObservationNet(
        network_type=obs_network_type,
        input_type=input_type,
        add_linear_layer=False,
        crop_frames=crop_frames,
        full_image_size=full_image_size,
        crop_margin_size=crop_margin_size,
        late_fusion=late_fusion)

  eval_policy = None
  if FLAGS.agent == 'MPO':
    agent_networks = networks.make_mpo_networks(
        environment_spec.actions,
        policy_init_std=FLAGS.policy_init_std,
        obs_network=obs_network)
    rl_agent = mpo.MPO(
        environment_spec=residual_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        discount=FLAGS.discount,
        batch_size=FLAGS.rl_batch_size,
        min_replay_size=FLAGS.min_replay_size,
        max_replay_size=FLAGS.max_replay_size,
        policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
        critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
        counter=counter,
        logger=agent_logger,
        checkpoint=FLAGS.write_acme_checkpoints,
    )
  elif FLAGS.agent == 'DMPO':
    agent_networks = networks.make_dmpo_networks(
        environment_spec.actions,
        policy_layer_sizes=FLAGS.rl_policy_layer_sizes,
        critic_layer_sizes=FLAGS.rl_critic_layer_sizes,
        vmin=FLAGS.critic_vmin,
        vmax=FLAGS.critic_vmax,
        num_atoms=FLAGS.critic_num_atoms,
        policy_init_std=FLAGS.policy_init_std,
        binary_grip_action=binary_grip_action,
        obs_network=obs_network)
    # spec = residual_spec if obs_network is None else environment_spec
    spec = residual_spec
    rl_agent = dmpo.DistributionalMPO(
        environment_spec=spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        discount=FLAGS.discount,
        batch_size=FLAGS.rl_batch_size,
        min_replay_size=FLAGS.min_replay_size,
        max_replay_size=FLAGS.max_replay_size,
        policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
        critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
        counter=counter,
        # logdir=logdir,
        logger=agent_logger,
        checkpoint=FLAGS.write_acme_checkpoints,
    )
    # Learned policy without exploration.
    eval_policy = tf.function(
        snt.Sequential([
            tf_utils.to_sonnet_module(agent_networks['observation']),
            agent_networks['policy'],
            tf_networks.StochasticMeanHead()
        ]))
  elif FLAGS.agent == 'D4PG':
    agent_networks = networks.make_d4pg_networks(
        residual_spec.actions,
        vmin=FLAGS.critic_vmin,
        vmax=FLAGS.critic_vmax,
        num_atoms=FLAGS.critic_num_atoms,
        policy_weights_init_scale=FLAGS.policy_weights_init_scale,
        obs_network=obs_network)
    # TODO(minttu): downscale action space to [-1, 1] to match clipped gaussian.
    rl_agent = d4pg.D4PG(
        environment_spec=residual_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        discount=FLAGS.discount,
        batch_size=FLAGS.rl_batch_size,
        min_replay_size=FLAGS.min_replay_size,
        max_replay_size=FLAGS.max_replay_size,
        policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
        critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
        sigma=FLAGS.policy_init_std,
        counter=counter,
        logger=agent_logger,
        checkpoint=FLAGS.write_acme_checkpoints,
    )
    # Learned policy without exploration.
    eval_policy = tf.function(
        snt.Sequential([
            tf_utils.to_sonnet_module(agent_networks['observation']),
            agent_networks['policy']
        ]))
  else:
    raise NotImplementedError('Supported agents: MPO, DMPO, D4PG.')
  return rl_agent, eval_policy
def __init__(
    self,
    reward_objectives: Sequence[RewardObjective],
    qvalue_objectives: Sequence[QValueObjective],
    policy_network: snt.Module,
    critic_network: snt.Module,
    target_policy_network: snt.Module,
    target_critic_network: snt.Module,
    discount: float,
    num_samples: int,
    target_policy_update_period: int,
    target_critic_update_period: int,
    dataset: tf.data.Dataset,
    observation_network: types.TensorTransformation = tf.identity,
    target_observation_network: types.TensorTransformation = tf.identity,
    policy_loss_module: Optional[losses.MultiObjectiveMPO] = None,
    policy_optimizer: Optional[snt.Optimizer] = None,
    critic_optimizer: Optional[snt.Optimizer] = None,
    dual_optimizer: Optional[snt.Optimizer] = None,
    clipping: bool = True,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    checkpoint: bool = True,
):
  # Store online and target networks.
  self._policy_network = policy_network
  self._critic_network = critic_network
  self._target_policy_network = target_policy_network
  self._target_critic_network = target_critic_network

  # Make sure observation networks are snt.Module's so they have variables.
  self._observation_network = tf2_utils.to_sonnet_module(observation_network)
  self._target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')

  # Other learner parameters.
  self._discount = discount
  self._num_samples = num_samples
  self._clipping = clipping

  # Necessary to track when to update target networks.
  self._num_steps = tf.Variable(0, dtype=tf.int32)
  self._target_policy_update_period = target_policy_update_period
  self._target_critic_update_period = target_critic_update_period

  # Create an iterator over the dataset.
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  # Store objectives.
  self._reward_objectives = reward_objectives
  self._qvalue_objectives = qvalue_objectives
  if self._qvalue_objectives is None:
    self._qvalue_objectives = []
  self._num_critic_heads = len(self._reward_objectives)  # C
  self._objective_names = ([x.name for x in self._reward_objectives] +
                           [x.name for x in self._qvalue_objectives])

  self._policy_loss_module = policy_loss_module or losses.MultiObjectiveMPO(
      epsilons=[
          losses.KLConstraint(name, _DEFAULT_EPSILON)
          for name in self._objective_names
      ],
      epsilon_mean=_DEFAULT_EPSILON_MEAN,
      epsilon_stddev=_DEFAULT_EPSILON_STDDEV,
      init_log_temperature=_DEFAULT_INIT_LOG_TEMPERATURE,
      init_log_alpha_mean=_DEFAULT_INIT_LOG_ALPHA_MEAN,
      init_log_alpha_stddev=_DEFAULT_INIT_LOG_ALPHA_STDDEV)

  # Check that the ordering of objectives matches the policy_loss_module's.
  if self._objective_names != list(self._policy_loss_module.objective_names):
    raise ValueError("Agent's ordering of objectives doesn't match "
                     "the policy loss module's ordering of epsilons.")

  # Create the optimizers.
  self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
  self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
  self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

  # Expose the variables.
  policy_network_to_expose = snt.Sequential(
      [self._target_observation_network, self._target_policy_network])
  self._variables = {
      'critic': self._target_critic_network.variables,
      'policy': policy_network_to_expose.variables,
  }

  # Create checkpointer and snapshotter objects.
  self._checkpointer = None
  self._snapshotter = None
  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='dmpo_learner',
        objects_to_save={
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'observation': self._observation_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'target_observation': self._target_observation_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'dual_optimizer': self._dual_optimizer,
            'policy_loss_module': self._policy_loss_module,
            'num_steps': self._num_steps,
        })
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={
            'policy': snt.Sequential([
                self._target_observation_network,
                self._target_policy_network
            ]),
        })

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp: Optional[float] = None
def __init__(self,
             policy_network: snt.Module,
             critic_network: snt.Module,
             target_policy_network: snt.Module,
             target_critic_network: snt.Module,
             discount: float,
             target_update_period: int,
             dataset: tf.data.Dataset,
             observation_network: types.TensorTransformation = lambda x: x,
             target_observation_network: types.TensorTransformation = lambda x: x,
             policy_optimizer: snt.Optimizer = None,
             critic_optimizer: snt.Optimizer = None,
             clipping: bool = True,
             counter: counting.Counter = None,
             logger: loggers.Logger = None,
             checkpoint: bool = True,
             specified_path: str = None):
  """Initializes the learner.

  Args:
    policy_network: the online (optimized) policy.
    critic_network: the online critic.
    target_policy_network: the target policy (which lags behind the online
      policy).
    target_critic_network: the target critic.
    discount: discount to use for TD updates.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    dataset: dataset to learn from, whether fixed or from a replay buffer
      (see `acme.datasets.reverb.make_dataset` documentation).
    observation_network: an optional online network to process observations
      before the policy and the critic.
    target_observation_network: the target observation network.
    policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
    critic_optimizer: the optimizer to be applied to the distributional
      Bellman loss.
    clipping: whether to clip gradients by global norm.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
    checkpoint: boolean indicating whether to checkpoint the learner.
    specified_path: optional checkpoint path from which to restore the
      learner.
  """

  # Store online and target networks.
  self._policy_network = policy_network
  self._critic_network = critic_network
  self._target_policy_network = target_policy_network
  self._target_critic_network = target_critic_network

  # Make sure observation networks are snt.Module's so they have variables.
  self._observation_network = tf2_utils.to_sonnet_module(observation_network)
  self._target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')

  # Other learner parameters.
  self._discount = discount
  self._clipping = clipping

  # Necessary to track when to update target networks.
  self._num_steps = tf.Variable(0, dtype=tf.int32)
  self._target_update_period = target_update_period

  # Create an iterator over the dataset.
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  # Create optimizers if they aren't given.
  self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
  self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

  # Expose the variables.
  policy_network_to_expose = snt.Sequential(
      [self._target_observation_network, self._target_policy_network])
  self._variables = {
      'critic': self._target_critic_network.variables,
      'policy': policy_network_to_expose.variables,
  }

  # Create checkpointer and snapshotter objects.
  self._checkpointer = None
  self._snapshotter = None
  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='d4pg_learner',
        objects_to_save={
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'observation': self._observation_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'target_observation': self._target_observation_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'num_steps': self._num_steps,
        })

    # Optionally restore the learner from an explicitly specified checkpoint.
    self.specified_path = specified_path
    if self.specified_path is not None:
      self._checkpointer._checkpoint.restore(self.specified_path)
      print('\033[92mspecified_path: ', str(self.specified_path), '\033[0m')

    critic_mean = snt.Sequential(
        [self._critic_network, acme_nets.StochasticMeanHead()])
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={
            'policy': self._policy_network,
            'critic': critic_mean,
        })

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
def make_default_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (
        256,
        256,
        256,
    ),
    critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
    shared_weights: bool = True,
) -> Dict[str, snt.Module]:
    """Default networks for mappo.

    Args:
        environment_spec (mava_specs.MAEnvironmentSpec): description of the action
            and observation spaces etc. for each agent in the system.
        policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of policy networks. Defaults to (256, 256, 256).
        critic_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of critic networks. Defaults to (512, 512, 256).
        shared_weights (bool, optional): whether agents should share weights or not.
            Defaults to True.

    Raises:
        ValueError: Unknown action_spec type, if actions aren't DiscreteArray
            or BoundedArray.

    Returns:
        Dict[str, snt.Module]: returned agent networks.
    """

    # Create agent_type specs.
    specs = environment_spec.get_agent_specs()
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():

        # Create the shared observation network; here simply a state-less operation.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Note: The discrete case must be placed first as it inherits from
        # BoundedArray.
        if isinstance(specs[key].actions, dm_env.specs.DiscreteArray):  # discrete
            num_actions = specs[key].actions.num_values
            policy_network = snt.Sequential(
                [
                    networks.LayerNormMLP(
                        tuple(policy_networks_layer_sizes[key]) + (num_actions,),
                        activate_final=False,
                    ),
                    tf.keras.layers.Lambda(
                        lambda logits: tfp.distributions.Categorical(logits=logits)
                    ),
                ]
            )
        elif isinstance(specs[key].actions, dm_env.specs.BoundedArray):  # continuous
            num_actions = np.prod(specs[key].actions.shape, dtype=int)
            policy_network = snt.Sequential(
                [
                    networks.LayerNormMLP(
                        policy_networks_layer_sizes[key], activate_final=True
                    ),
                    networks.MultivariateNormalDiagHead(num_dimensions=num_actions),
                    networks.TanhToSpec(specs[key].actions),
                ]
            )
        else:
            raise ValueError(f"Unknown action_spec type, got {specs[key].actions}.")

        critic_network = snt.Sequential(
            [
                networks.LayerNormMLP(
                    list(critic_networks_layer_sizes[key]) + [1],
                    activate_final=False,
                ),
            ]
        )

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy_network: snt.Module,
    critic_network: snt.Module,
    observation_network: types.TensorTransformation = tf.identity,
    discount: float = 0.99,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_policy_update_period: int = 100,
    target_critic_update_period: int = 100,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    policy_loss_module: snt.Module = None,
    policy_optimizer: snt.Optimizer = None,
    critic_optimizer: snt.Optimizer = None,
    n_step: int = 5,
    num_samples: int = 20,
    clipping: bool = True,
    logger: loggers.Logger = None,
    counter: counting.Counter = None,
    checkpoint: bool = True,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
):
  """Initialize the agent.

  Args:
    environment_spec: description of the actions, observations, etc.
    policy_network: the online (optimized) policy.
    critic_network: the online critic.
    observation_network: optional network to transform the observations
      before they are fed into any network.
    discount: discount to use for TD updates.
    batch_size: batch size for updates.
    prefetch_size: size to prefetch from replay.
    target_policy_update_period: number of updates to perform before updating
      the target policy network.
    target_critic_update_period: number of updates to perform before updating
      the target critic network.
    min_replay_size: minimum replay size before updating.
    max_replay_size: maximum replay size.
    samples_per_insert: number of samples to take from replay for every
      insert that is made.
    policy_loss_module: configured MPO loss function for the policy
      optimization; defaults to sensible values on the control suite. See
      `acme/tf/losses/mpo.py` for more details.
    policy_optimizer: optimizer to be used on the policy.
    critic_optimizer: optimizer to be used on the critic.
    n_step: number of steps to squash into a single transition.
    num_samples: number of actions to sample when doing a Monte Carlo
      integration with respect to the policy.
    clipping: whether to clip gradients by global norm.
    logger: logging object used to write to logs.
    counter: counter object used to keep track of steps.
    checkpoint: boolean indicating whether to checkpoint the learner.
    replay_table_name: string indicating what name to give the replay table.
  """
  # Create a replay server to add data to.
  replay_table = reverb.Table(
      name=replay_table_name,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
      signature=adders.NStepTransitionAdder.signature(environment_spec))
  self._server = reverb.Server([replay_table], port=None)

  # The adder is used to insert observations into replay.
  address = f'localhost:{self._server.port}'
  adder = adders.NStepTransitionAdder(
      client=reverb.Client(address), n_step=n_step, discount=discount)

  # The dataset object to learn from.
  dataset = datasets.make_reverb_dataset(
      table=replay_table_name,
      client=reverb.TFClient(address),
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      environment_spec=environment_spec,
      transition_adder=True)

  # Make sure observation network is a Sonnet Module.
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  # Create target networks before creating online/target network variables.
  target_policy_network = copy.deepcopy(policy_network)
  target_critic_network = copy.deepcopy(critic_network)
  target_observation_network = copy.deepcopy(observation_network)

  # Get observation and action specs.
  act_spec = environment_spec.actions
  obs_spec = environment_spec.observations
  emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

  # Create the behavior policy.
  behavior_network = snt.Sequential([
      observation_network,
      policy_network,
      networks.StochasticSamplingHead(),
  ])

  # Create variables.
  tf2_utils.create_variables(policy_network, [emb_spec])
  tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
  tf2_utils.create_variables(target_policy_network, [emb_spec])
  tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
  tf2_utils.create_variables(target_observation_network, [obs_spec])

  # Create the actor which defines how we take actions.
  actor = actors.FeedForwardActor(policy_network=behavior_network, adder=adder)

  # Create optimizers.
  policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
  critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

  # The learner updates the parameters (and initializes them).
  learner = learning.MPOLearner(
      policy_network=policy_network,
      critic_network=critic_network,
      observation_network=observation_network,
      target_policy_network=target_policy_network,
      target_critic_network=target_critic_network,
      target_observation_network=target_observation_network,
      policy_loss_module=policy_loss_module,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      clipping=clipping,
      discount=discount,
      num_samples=num_samples,
      target_policy_update_period=target_policy_update_period,
      target_critic_update_period=target_critic_update_period,
      dataset=dataset,
      logger=logger,
      counter=counter,
      checkpoint=checkpoint)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
def learner(
    self,
    replay: reverb.Client,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  act_spec = self._environment_spec.actions
  obs_spec = self._environment_spec.observations

  # Create the networks to optimize (online) and target networks.
  online_networks = self._network_factory(act_spec)
  target_networks = self._network_factory(act_spec)

  # Make sure observation network is a Sonnet Module.
  observation_network = online_networks.get('observation', tf.identity)
  target_observation_network = target_networks.get('observation', tf.identity)
  observation_network = tf2_utils.to_sonnet_module(observation_network)
  target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # Get embedding spec and create observation network variables.
  emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

  # Create variables.
  tf2_utils.create_variables(online_networks['policy'], [emb_spec])
  tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec])
  tf2_utils.create_variables(target_networks['policy'], [emb_spec])
  tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec])
  tf2_utils.create_variables(target_observation_network, [obs_spec])

  # The dataset object to learn from.
  dataset = datasets.make_reverb_dataset(
      server_address=replay.server_address,
      batch_size=self._batch_size,
      prefetch_size=self._prefetch_size)

  # Create optimizers.
  policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
  critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

  counter = counting.Counter(counter, 'learner')
  logger = loggers.make_default_logger(
      'learner', time_delta=self._log_every, steps_key='learner_steps')

  # Return the learning agent.
  return learning.DDPGLearner(
      policy_network=online_networks['policy'],
      critic_network=online_networks['critic'],
      observation_network=observation_network,
      target_policy_network=target_networks['policy'],
      target_critic_network=target_networks['critic'],
      target_observation_network=target_observation_network,
      discount=self._discount,
      target_update_period=self._target_update_period,
      dataset=dataset,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      clipping=self._clipping,
      counter=counter,
      logger=logger,
  )
def learner(
    self,
    replay: reverb.Client,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  act_spec = self._environment_spec.actions
  obs_spec = self._environment_spec.observations

  # Create online and target networks.
  online_networks = self._network_factory(act_spec)
  target_networks = self._network_factory(act_spec)

  # Make sure observation networks are Sonnet Modules.
  observation_network = online_networks.get('observation', tf.identity)
  observation_network = tf2_utils.to_sonnet_module(observation_network)
  online_networks['observation'] = observation_network
  target_observation_network = target_networks.get('observation', tf.identity)
  target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)
  target_networks['observation'] = target_observation_network

  # Get embedding spec and create observation network variables.
  emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

  tf2_utils.create_variables(online_networks['policy'], [emb_spec])
  tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec])
  tf2_utils.create_variables(target_networks['observation'], [obs_spec])
  tf2_utils.create_variables(target_networks['policy'], [emb_spec])
  tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec])

  # The dataset object to learn from.
  dataset = datasets.make_reverb_dataset(server_address=replay.server_address)
  dataset = dataset.batch(self._batch_size, drop_remainder=True)
  dataset = dataset.prefetch(self._prefetch_size)

  # Create a counter and logger for bookkeeping steps and performance.
  counter = counting.Counter(counter, 'learner')
  logger = loggers.make_default_logger(
      'learner', time_delta=self._log_every, steps_key='learner_steps')

  # Create policy loss module if a factory is passed.
  if self._policy_loss_factory:
    policy_loss_module = self._policy_loss_factory()
  else:
    policy_loss_module = None

  # Return the learning agent.
  return learning.MPOLearner(
      policy_network=online_networks['policy'],
      critic_network=online_networks['critic'],
      observation_network=observation_network,
      target_policy_network=target_networks['policy'],
      target_critic_network=target_networks['critic'],
      target_observation_network=target_observation_network,
      discount=self._additional_discount,
      num_samples=self._num_samples,
      target_policy_update_period=self._target_policy_update_period,
      target_critic_update_period=self._target_critic_update_period,
      policy_loss_module=policy_loss_module,
      dataset=dataset,
      counter=counter,
      logger=logger)
def make_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (
        256,
        256,
        256,
    ),
    critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
    shared_weights: bool = True,
) -> Dict[str, snt.Module]:
    """Creates networks used by the agents."""

    # Create agent_type specs.
    specs = environment_spec.get_agent_specs()
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():

        # Create the shared observation network; here simply a state-less operation.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Note: The discrete case must be placed first as it inherits from
        # BoundedArray.
        if isinstance(specs[key].actions, dm_env.specs.DiscreteArray):  # discrete
            num_actions = specs[key].actions.num_values
            policy_network = snt.Sequential(
                [
                    networks.LayerNormMLP(
                        tuple(policy_networks_layer_sizes[key]) + (num_actions,),
                        activate_final=False,
                    ),
                    tf.keras.layers.Lambda(
                        lambda logits: tfp.distributions.Categorical(logits=logits)
                    ),
                ]
            )
        elif isinstance(specs[key].actions, dm_env.specs.BoundedArray):  # continuous
            num_actions = np.prod(specs[key].actions.shape, dtype=int)
            policy_network = snt.Sequential(
                [
                    networks.LayerNormMLP(
                        policy_networks_layer_sizes[key], activate_final=True
                    ),
                    networks.MultivariateNormalDiagHead(num_dimensions=num_actions),
                    networks.TanhToSpec(specs[key].actions),
                ]
            )
        else:
            raise ValueError(f"Unknown action_spec type, got {specs[key].actions}.")

        critic_network = snt.Sequential(
            [
                networks.LayerNormMLP(
                    critic_networks_layer_sizes[key], activate_final=True
                ),
                networks.NearZeroInitializedLinear(1),
            ]
        )

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             policy_network: snt.Module,
             critic_network: snt.Module,
             observation_network: types.TensorTransformation = tf.identity,
             discount: float = 0.99,
             batch_size: int = 256,
             prefetch_size: int = 4,
             target_update_period: int = 100,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0,
             n_step: int = 5,
             sigma: float = 0.3,
             clipping: bool = True,
             logger: loggers.Logger = None,
             counter: counting.Counter = None,
             checkpoint: bool = True,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
  """Initialize the agent.

  Args:
    environment_spec: description of the actions, observations, etc.
    policy_network: the online (optimized) policy.
    critic_network: the online critic.
    observation_network: optional network to transform the observations
      before they are fed into any network.
    discount: discount to use for TD updates.
    batch_size: batch size for updates.
    prefetch_size: size to prefetch from replay.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    min_replay_size: minimum replay size before updating.
    max_replay_size: maximum replay size.
    samples_per_insert: number of samples to take from replay for every
      insert that is made.
    n_step: number of steps to squash into a single transition.
    sigma: standard deviation of zero-mean, Gaussian exploration noise.
    clipping: whether to clip gradients by global norm.
    logger: logger object to be used by learner.
    counter: counter object used to keep track of steps.
    checkpoint: boolean indicating whether to checkpoint the learner.
    replay_table_name: string indicating what name to give the replay table.
  """
  # Create a replay server to add data to. This uses no limiter behavior in
  # order to allow the Agent interface to handle it.
  replay_table = reverb.Table(
      name=replay_table_name,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=adders.NStepTransitionAdder.signature(environment_spec))
  self._server = reverb.Server([replay_table], port=None)

  # The adder is used to insert observations into replay.
  address = f'localhost:{self._server.port}'
  adder = adders.NStepTransitionAdder(
      priority_fns={replay_table_name: lambda x: 1.},
      client=reverb.Client(address),
      n_step=n_step,
      discount=discount)

  # The dataset provides an interface to sample from replay.
  dataset = datasets.make_reverb_dataset(
      table=replay_table_name,
      client=reverb.TFClient(address),
      environment_spec=environment_spec,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      transition_adder=True)

  # Make sure observation network is a Sonnet Module before building
  # variables with it.
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  # Get observation and action specs.
  act_spec = environment_spec.actions
  obs_spec = environment_spec.observations
  emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

  # Create target networks.
  target_policy_network = copy.deepcopy(policy_network)
  target_critic_network = copy.deepcopy(critic_network)
  target_observation_network = copy.deepcopy(observation_network)

  # Create the behavior policy.
  behavior_network = snt.Sequential([
      observation_network,
      policy_network,
      networks.ClippedGaussian(sigma),
      networks.ClipToSpec(act_spec),
  ])

  # Create variables.
  tf2_utils.create_variables(policy_network, [emb_spec])
  tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
  tf2_utils.create_variables(target_policy_network, [emb_spec])
  tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
  tf2_utils.create_variables(target_observation_network, [obs_spec])

  # Create the actor which defines how we take actions.
  actor = actors.FeedForwardActor(behavior_network, adder=adder)

  # Create optimizers.
  policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
  critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

  # The learner updates the parameters (and initializes them).
  learner = learning.DDPGLearner(
      policy_network=policy_network,
      critic_network=critic_network,
      observation_network=observation_network,
      target_policy_network=target_policy_network,
      target_critic_network=target_critic_network,
      target_observation_network=target_observation_network,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      clipping=clipping,
      discount=discount,
      target_update_period=target_update_period,
      dataset=dataset,
      counter=counter,
      logger=logger,
      checkpoint=checkpoint,
  )

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
def make_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (
        256,
        256,
        256,
    ),
    critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
    shared_weights: bool = True,
    sigma: float = 0.3,
) -> Mapping[str, types.TensorTransformation]:
    """Creates networks used by the agents."""

    specs = environment_spec.get_agent_specs()

    # Create agent_type specs.
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():

        # Get total number of action dimensions from action spec.
        num_dimensions = np.prod(specs[key].actions.shape, dtype=int)

        # Create the shared observation network; here simply a state-less operation.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Create the policy network.
        policy_network = snt.Sequential(
            [
                networks.LayerNormMLP(
                    policy_networks_layer_sizes[key], activate_final=True
                ),
                networks.NearZeroInitializedLinear(num_dimensions),
                networks.TanhToSpec(specs[key].actions),
                networks.ClippedGaussian(sigma),
                networks.ClipToSpec(specs[key].actions),
            ]
        )

        # Create the critic network.
        critic_network = snt.Sequential(
            [
                # The multiplexer concatenates the observations/actions.
                networks.CriticMultiplexer(),
                networks.LayerNormMLP(
                    critic_networks_layer_sizes[key], activate_final=False
                ),
                snt.Linear(1),
            ]
        )

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
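All of the Mava network factories above return the same three-key mapping, so downstream code can wire each agent's networks uniformly. A hedged consumption sketch; `env_spec` stands in for a real `MAEnvironmentSpec` instance built elsewhere:

import sonnet as snt

# `env_spec` is an assumed, pre-built mava MAEnvironmentSpec.
nets = make_networks(env_spec)
for agent_key, policy in nets["policies"].items():
    obs_net = nets["observations"][agent_key]
    critic = nets["critics"][agent_key]
    # Typical wiring: observations flow through obs_net, then the policy.
    behaviour_net = snt.Sequential([obs_net, policy])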