DEFAULT_CONFIG["aux_loss_start_after_num_steps"] = 0 DEFAULT_CONFIG["adapt_policy_parameters"] = False DEFAULT_CONFIG["adapt_policy_parameters_options"] = { "unique_fraction_threshold": 0.9, "alternate_entropy_coeff": 0.025 } DataAugmentingTorchPolicy = build_torch_policy(name="DataAugmentingTorchPolicy", get_default_config=lambda: DEFAULT_CONFIG, loss_fn=data_augmenting_loss, stats_fn=data_augmenting_stats, extra_action_out_fn=vf_preds_fetches, postprocess_fn=postprocess_sample_batch, extra_grad_process_fn=my_apply_grad_clipping, before_init=setup_config, after_init=after_init_fn, mixins=[ KLCoeffMixin, ValueNetworkMixin, EntropyCoeffSchedule, LearningRateSchedule, ]) # Well, this is a bit of a hack, but oh well. def get_optimizer(policy, config={"opt_type": "adam"}): lr = policy.config["lr"] if hasattr(policy.model, "optimizer_options"): config = policy.model.optimizer_options
        target_state_dict = self.target_model.state_dict()
        model_state_dict = {
            k: tau * model_state_dict[k] + (1 - tau) * v
            for k, v in target_state_dict.items()
        }
        self.target_model.load_state_dict(model_state_dict)


def setup_late_mixins(policy, obs_space, action_space, config):
    policy.target_model = policy.target_model.to(policy.device)
    policy.model.log_alpha = policy.model.log_alpha.to(policy.device)
    policy.model.target_entropy = policy.model.target_entropy.to(policy.device)
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy)


SACTorchPolicy = build_torch_policy(
    name="SACTorchPolicy",
    loss_fn=actor_critic_loss,
    get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
    stats_fn=stats,
    postprocess_fn=postprocess_trajectory,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=optimizer_fn,
    validate_spaces=validate_spaces,
    after_init=setup_late_mixins,
    make_model_and_action_dist=build_sac_model_and_action_dist,
    mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
    action_distribution_fn=action_distribution_fn,
)
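# Standalone illustration (toy modules, not part of the original snippet) of
# the Polyak/soft update performed in the target-network code above: with a
# small tau the target network only slowly tracks the live model.
import torch.nn as nn

live, target = nn.Linear(4, 2), nn.Linear(4, 2)
tau = 0.005
blended = {
    k: tau * live.state_dict()[k] + (1.0 - tau) * v
    for k, v in target.state_dict().items()
}
target.load_state_dict(blended)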
        {SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]})
    action_dist = policy.dist_class(logits)
    log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
    # Save the error in the policy object.
    policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
        log_probs.reshape(-1))
    return policy.pi_err


def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    return compute_advantages(
        sample_batch, 0.0, policy.config["gamma"], use_gae=False)


def pg_loss_stats(policy, batch_tensors):
    # The error is recorded when computing the loss.
    return {"policy_loss": policy.pi_err.item()}


PGTorchPolicy = build_torch_policy(
    name="PGTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=pg_torch_loss,
    stats_fn=pg_loss_stats,
    postprocess_fn=postprocess_advantages)
#######################################################################################################
##################################### Policy #####################################################
#######################################################################################################
import algorithms.drq.ppo.ppo_trainer

NoAugPPOTorchPolicy = build_torch_policy(
    name="NoAugPPOTorchPolicy",
    loss_fn=ppo_surrogate_loss,
    postprocess_fn=postprocess_ppo_gae,
    make_model_and_action_dist=build_ppo_model_and_action_dist,
    # shared
    # get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
    get_default_config=lambda: algorithms.drq.ppo.ppo_trainer.PPO_CONFIG,
    stats_fn=kl_and_loss_stats,
    extra_action_out_fn=vf_preds_fetches,
    extra_grad_process_fn=apply_grad_clipping,
    before_init=setup_config,
    after_init=setup_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin
    ])

DrqPPOTorchPolicy = build_torch_policy(
    name="DrqPPOTorchPolicy",
    loss_fn=drq_ppo_surrogate_loss,
    postprocess_fn=postprocess_drq_ppo_gae,
    make_model_and_action_dist=build_ppo_model_and_action_dist,
    # shared
        return self.kl_coeff_val


def maml_optimizer_fn(policy, config):
    """
    Workers use simple SGD for inner adaptation.
    Meta-Policy (worker_index 0, the driver) uses the Adam optimizer for the
    meta-update.
    """
    if not config["worker_index"]:
        return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
    return torch.optim.SGD(policy.model.parameters(), lr=config["inner_lr"])


def setup_mixins(policy, obs_space, action_space, config):
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)


MAMLTorchPolicy = build_torch_policy(
    name="MAMLTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
    loss_fn=maml_loss,
    stats_fn=maml_stats,
    optimizer_fn=maml_optimizer_fn,
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=postprocess_ppo_gae,
    extra_grad_process_fn=apply_grad_clipping,
    before_init=setup_config,
    after_init=setup_mixins,
    mixins=[KLCoeffMixin])
def dreamer_stats(policy, train_batch):
    return policy.stats_dict


def dreamer_optimizer_fn(policy, config):
    model = policy.model
    encoder_weights = list(model.encoder.parameters())
    decoder_weights = list(model.decoder.parameters())
    reward_weights = list(model.reward.parameters())
    dynamics_weights = list(model.dynamics.parameters())
    actor_weights = list(model.actor.parameters())
    critic_weights = list(model.value.parameters())
    model_opt = torch.optim.Adam(
        encoder_weights + decoder_weights + reward_weights + dynamics_weights,
        lr=config["td_model_lr"])
    actor_opt = torch.optim.Adam(actor_weights, lr=config["actor_lr"])
    critic_opt = torch.optim.Adam(critic_weights, lr=config["critic_lr"])

    return (model_opt, actor_opt, critic_opt)


DreamerTorchPolicy = build_torch_policy(
    name="DreamerTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
    action_sampler_fn=action_sampler_fn,
    loss_fn=dreamer_loss,
    stats_fn=dreamer_stats,
    make_model_and_action_dist=build_dreamer_model,
    optimizer_fn=dreamer_optimizer_fn,
    extra_grad_process_fn=apply_grad_clipping)
def choose_optimizer(policy, config):
    if policy.config["opt_type"] == "adam":
        return torch.optim.Adam(
            params=policy.model.parameters(), lr=policy.cur_lr)
    else:
        # Note: the class is spelled RMSprop in torch.optim.
        return torch.optim.RMSprop(
            params=policy.model.parameters(),
            lr=policy.cur_lr,
            weight_decay=config["decay"],
            momentum=config["momentum"],
            eps=config["epsilon"])


def setup_mixins(policy, obs_space, action_space, config):
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


VTraceTorchPolicy = build_torch_policy(
    name="VTraceTorchPolicy",
    loss_fn=build_vtrace_loss,
    get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
    stats_fn=stats,
    postprocess_fn=postprocess_trajectory,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=choose_optimizer,
    before_init=setup_mixins,
    mixins=[LearningRateSchedule, EntropyCoeffSchedule],
    get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
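# For reference (illustrative values, not from the original snippet): a config
# that would take the RMSprop branch of choose_optimizer above. The actual
# defaults live in ray.rllib.agents.impala.impala.DEFAULT_CONFIG.
rmsprop_config_example = {
    "opt_type": "rmsprop",
    "lr": 5e-4,
    "decay": 0.99,
    "momentum": 0.0,
    "epsilon": 0.1,
}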
                           dim=-1), 0.5)
    batch_size = int(loss.shape[0])
    train_set_size = int(batch_size * policy.config["train_set_ratio"])
    train_loss, validation_loss = \
        torch.split(loss, (train_set_size, batch_size - train_set_size), dim=0)
    policy.dynamics_train_loss = torch.mean(train_loss)
    policy.dynamics_validation_loss = torch.mean(validation_loss)
    return policy.dynamics_train_loss


def stats_fn(policy, train_batch):
    return {
        "dynamics_train_loss": policy.dynamics_train_loss,
        "dynamics_validation_loss": policy.dynamics_validation_loss,
    }


def torch_optimizer(policy, config):
    return torch.optim.Adam(
        policy.dynamics_model.parameters(), lr=config["lr"])


DYNATorchPolicy = build_torch_policy(
    name="DYNATorchPolicy",
    loss_fn=dyna_torch_loss,
    get_default_config=lambda: ray.rllib.agents.dyna.dyna.DEFAULT_CONFIG,
    stats_fn=stats_fn,
    optimizer_fn=torch_optimizer,
    make_model_and_action_dist=make_model_and_dist,
)
            lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)

    # Not called automatically by any RLlib logic; call this in your training
    # script or a trainer callback.
    def update_lr(self, timesteps_total):
        print(f"cur lr {self.cur_lr}")
        self.cur_lr = self.lr_schedule.value(timesteps_total)
        for opt in self._optimizers:
            for p in opt.param_groups:
                p["lr"] = self.cur_lr


def setup_mixins(policy, obs_space, action_space, config):
    ManualLearningRateSchedule.__init__(policy, config["lr"],
                                        config["lr_schedule"])


NFSPTorchAveragePolicy = build_torch_policy(
    name="NFSPAveragePolicy",
    extra_action_out_fn=behaviour_logits_fetches,
    loss_fn=build_supervised_learning_loss,
    get_default_config=lambda: grl.algos.nfsp_rllib.nfsp.DEFAULT_CONFIG,
    make_model_and_action_dist=build_avg_model_and_distribution,
    action_sampler_fn=action_sampler,
    before_init=setup_mixins,
    extra_learn_fetches_fn=lambda policy: {"sl_loss": policy.loss},
    optimizer_fn=sgd_optimizer,
    stats_fn=build_avg_policy_stats,
    mixins=[ManualLearningRateSchedule, SafeSetWeightsPolicyMixin],
    # action_distribution_fn=get_distribution_inputs_and_class,
)
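# A minimal sketch (not part of the original snippet) of wiring update_lr()
# into a trainer callback, since nothing in RLlib calls it automatically.
# The DefaultCallbacks / foreach_policy usage below assumes the RLlib 1.x
# APIs that these snippets are written against.
from ray.rllib.agents.callbacks import DefaultCallbacks


class ManualLRUpdateCallbacks(DefaultCallbacks):
    def on_train_result(self, *, trainer, result, **kwargs):
        # Push the scheduled learning rate onto every policy that has the
        # ManualLearningRateSchedule mixin after each training iteration.
        timesteps = result["timesteps_total"]
        trainer.workers.local_worker().foreach_policy(
            lambda policy, pid: policy.update_lr(timesteps)
            if hasattr(policy, "update_lr") else None)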
        self._value = value


def setup_mixins(policy, obs_space, action_space, config):
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def optimizer(policy, config):
    """Custom PyTorch optimizer to use."""
    return torch.optim.Adam(
        policy.model.parameters(), lr=config["lr"], eps=1e-5)


CustomPPOTorchPolicy = build_torch_policy(
    name="CustomPPOTorchPolicy",
    get_default_config=lambda: DEFAULT_CONFIG,
    loss_fn=ppo_surrogate_loss,
    stats_fn=kl_and_loss_stats,
    optimizer_fn=optimizer,
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=postprocess_ppo_gae,
    extra_grad_process_fn=apply_grad_clipping,
    before_init=setup_config,
    after_init=setup_mixins,
    mixins=[KLCoeffMixin, ValueNetworkMixin])
    actions, logp = \
        policy.exploration.get_exploration_action(
            action_distribution=action_dist,
            timestep=timestep,
            explore=explore)

    if policy.requires_tupling:
        actions = actions.unsqueeze(1).tolist()
        logp = logp.unsqueeze(1)
    return actions, logp, state


def extra_grad_process(policy, opt, loss):
    if policy.log_stats:
        return {**{"classification_loss": loss.item()}, **policy.stats_dict}
    else:
        return {"classification_loss": loss.item()}


AVGPolicy = build_torch_policy(
    name="AVGTorchPolicy",
    loss_fn=build_loss,
    action_sampler_fn=action_sampler_fn,
    make_model_and_action_dist=make_model_and_action_dist,
    get_default_config=lambda: DEFAULT_CONFIG,
    extra_grad_process_fn=extra_grad_process,
    optimizer_fn=optimizer_fn,
)
#######################################################################################################
##################################### Policy #####################################################
#######################################################################################################
# hack to avoid cycle imports
import algorithms.curl.rainbow.rainbow_trainer

CurlRainbowTorchPolicy = build_torch_policy(
    name="CurlRainbowTorchPolicy",
    # loss updates shifted to policy.learn_on_batch
    loss_fn=None,
    make_model_and_action_dist=build_q_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    optimizer_fn=optimizer_fn,
    # shared
    # get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    get_default_config=lambda: algorithms.curl.rainbow.rainbow_trainer.RAINBOW_CONFIG,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    extra_grad_process_fn=grad_process_and_td_error_fn,
    extra_action_out_fn=extra_action_out_fn,
    before_init=setup_early_mixins,
    # added curl mixin
    after_init=after_init,
    mixins=[
        TargetNetworkMixin, ComputeTDErrorMixin, LearningRateSchedule, CurlMixin
    ])
def pg_loss_stats(policy: Policy,
                  train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns the calculated loss in a stats dict.

    Args:
        policy (Policy): The Policy object.
        train_batch (SampleBatch): The data used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """

    return {
        # `pi_err` (the loss) is stored inside `pg_torch_loss()`.
        "policy_loss": policy.pi_err.item(),
    }


# Build a child class of `TorchPolicy`, given the extra options:
# - trajectory post-processing function (to calculate advantages)
# - PG loss function
PGTorchPolicy = build_torch_policy(
    name="PGTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
    loss_fn=pg_torch_loss,
    stats_fn=pg_loss_stats,
    postprocess_fn=post_process_advantages,
    view_requirements_fn=view_requirements_fn,
)
def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
                                              batch: SampleBatch,
                                              other_agent=None,
                                              episode=None) -> SampleBatch:
    """Add next_actions to SampleBatch for SARSA training."""
    if policy.config["slateq_strategy"] == "SARSA":
        if not batch["dones"][-1]:
            raise RuntimeError(
                "Expected a complete episode in each sample batch. "
                f"But this batch is not: {batch}.")
        batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)
    return batch


SlateQTorchPolicy = build_torch_policy(
    name="SlateQTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,
    # build model, loss functions, and optimizers
    make_model_and_action_dist=build_slateq_model_and_distribution,
    optimizer_fn=build_slateq_optimizers,
    loss_fn=build_slateq_losses,
    # define how to act
    action_sampler_fn=action_sampler_fn,
    # post processing batch sampled data
    postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
)
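# Quick worked example (toy data, not part of the original snippet) of what
# the np.roll call above produces: each step is paired with the action taken
# at the next timestep; the wrap-around in the last entry is harmless because
# the batch is required to end with done=True.
import numpy as np

actions = np.array([0, 2, 1, 3])
next_actions = np.roll(actions, -1, axis=0)  # -> array([2, 1, 3, 0])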
    }
    policy.num_params = sum(np.prod(s) for s in policy.param_shapes.values())


def make_model_and_action_dist(policy, observation_space, action_space,
                               config):
    # Policy network.
    dist_class, dist_dim = ModelCatalog.get_action_dist(
        action_space,
        config["model"],  # model_options
        dist_type="deterministic",
        framework="torch")
    model = ModelCatalog.get_model_v2(
        observation_space,
        action_space,
        num_outputs=dist_dim,
        model_config=config["model"],
        framework="torch")
    # Make all model params not require any gradients.
    for p in model.parameters():
        p.requires_grad = False
    return model, dist_class


ESTorchPolicy = build_torch_policy(
    name="ESTorchPolicy",
    loss_fn=None,
    get_default_config=lambda: ray.rllib.agents.es.es.DEFAULT_CONFIG,
    before_init=before_init,
    after_init=after_init,
    make_model_and_action_dist=make_model_and_action_dist)
        train_batch[SampleBatch.EPS_ID], policy.global_timestep)


def episode_adversarial_stats(policy, train_batch):
    stats = kl_and_loss_stats(policy, train_batch)
    stats.update(policy.model.metrics())
    return stats


EpisodeAdversarialTorchPolicy = build_torch_policy(
    name="EpisodeAdversarialPolicy",
    get_default_config=lambda: DEFAULT_CONFIG,
    loss_fn=episode_adversarial_loss,
    stats_fn=episode_adversarial_stats,
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=postprocess_ppo_gae,
    extra_grad_process_fn=apply_grad_clipping,
    before_init=setup_config,
    after_init=setup_mixins,
    mixins=[KLCoeffMixin, ValueNetworkMixin])


def get_policy_class(config):
    return EpisodeAdversarialTorchPolicy


EpisodeAdversarialTrainer = build_trainer(
    name="episode_adversarial",
    default_config=DEFAULT_CONFIG,
    default_policy=EpisodeAdversarialTorchPolicy,
    return model, dist_class


# Sample actions
def option_critic_action_sampler_fn(policy, model, input_dict, obs_space,
                                    action_space, config):
    action = action_logp = None
    return action, action_logp


# After rest of policy is setup, start with initial option and delib cost
def option_critic_after_init(policy, observation_space, action_space, config):
    pass


A2OCTorchPolicy = build_torch_policy(
    name="A2OCTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=option_critic_loss,
    stats_fn=option_critic_stats,
    extra_action_out_fn=option_critic_extra_action_out,
    extra_grad_process_fn=option_critic_gradient_process,
    optimizer_fn=option_critic_optimizer,
    make_model_and_action_dist=option_critic_make_model_and_action_dist,
)


class A2OCTorchPolicyClass(TorchPolicy):
    def __init__(self,
                 obs_space,
                 action_space,
                 config,
                 model,
                 loss=option_critic_loss,
                 action_distribution_class=None,
                 action_sampler_fn=option_critic_action_sampler_fn):
        super().__init__(
            obs_space,
            action_space,
            config,
            model=model,
            loss=loss,
    # Compute the error (Square/Huber).
    td_error = q_t_selected - q_t_selected_target.detach()
    loss = torch.mean(huber_loss(td_error))

    # Save TD error as an attribute for outside access.
    policy.td_error = td_error

    return loss


def extra_action_out_fn(policy, input_dict, state_batches, model, action_dist):
    """Adds q-values to action out dict."""
    return {"q_values": policy.q_values}


def setup_late_mixins(policy, obs_space, action_space, config):
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


SimpleQTorchPolicy = build_torch_policy(
    name="SimpleQPolicy",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    extra_action_out_fn=extra_action_out_fn,
    after_init=setup_late_mixins,
    make_model_and_action_dist=build_q_model_and_distribution,
    mixins=[TargetNetworkMixin],
    action_distribution_fn=get_distribution_inputs_and_class,
    stats_fn=lambda policy, config: {"td_error": policy.td_error},
)
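# For reference, a minimal sketch (assuming the usual delta=1.0 definition,
# not the original helper) of the Huber loss applied to the TD error above;
# the snippet itself uses RLlib's own huber_loss utility.
import torch


def huber_loss_sketch(x, delta=1.0):
    # Quadratic inside the delta band, linear outside of it.
    return torch.where(
        torch.abs(x) < delta,
        0.5 * torch.pow(x, 2.0),
        delta * (torch.abs(x) - 0.5 * delta))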
from ray.rllib.policy.torch_policy_template import build_torch_policy

parser = argparse.ArgumentParser()
parser.add_argument("--iters", type=int, default=200)


def policy_gradient_loss(policy, batch_tensors):
    logits, _ = policy.model(
        {SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]})
    action_dist = policy.dist_class(logits, policy.model)
    log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
    return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)


# <class 'ray.rllib.policy.torch_policy_template.MyTorchPolicy'>
MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy", loss_fn=policy_gradient_loss)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
    name="MyCustomTrainer",
    default_policy=MyTorchPolicy,
)

if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()
    tune.run(
        MyTrainer,
        stop={"training_iteration": args.iters},
        config={
            "env": "CartPole-v0",
            "num_workers": 2,
if policy.config["grad_clip"]: for param_group in optimizer.param_groups: # Make sure we only pass params with grad != None into torch # clip_grad_norm_. Would fail otherwise. params = list( filter(lambda p: p.grad is not None, param_group["params"])) if params: grad_gnorm = nn.utils.clip_grad_norm_( params, policy.config["grad_clip"]) if isinstance(grad_gnorm, torch.Tensor): grad_gnorm = grad_gnorm.cpu().numpy() info["grad_gnorm"] = grad_gnorm return info COMATorchPolicy = build_torch_policy( name="COMATorchPolicy", make_model_and_action_dist=make_model_and_action_dist, postprocess_fn=compute_target, optimizer_fn=make_coma_optimizers, validate_spaces=validate_spaces, get_default_config=lambda: coma.trainer.DEFAULT_CONFIG, view_requirements_fn=view_requirements_fn, mixins=[ TargetNetworkMixin, ], after_init=setup_late_mixins, extra_grad_process_fn=apply_grad_clipping, stats_fn=stats, loss_fn=loss_fn)
def grad_process_and_td_error_fn(policy, optimizer, loss):
    # Clip grads if configured.
    info = apply_grad_clipping(policy, optimizer, loss)
    # Add td-error to info dict.
    info["td_error"] = policy.q_loss.td_error
    return info


def extra_action_out_fn(policy, input_dict, state_batches, model, action_dist):
    return {"q_values": policy.q_values}


DQNTorchPolicy = build_torch_policy(
    name="DQNTorchPolicy",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    make_model_and_action_dist=build_q_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    extra_grad_process_fn=grad_process_and_td_error_fn,
    extra_action_out_fn=extra_action_out_fn,
    before_init=setup_early_mixins,
    after_init=after_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])
def choose_optimizer(policy, config):
    if policy.config["opt_type"] == "adam":
        return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
    else:
        return torch.optim.RMSprop(
            policy.model.parameters(),
            lr=config["lr"],
            eps=config["epsilon"],
            weight_decay=config["decay"],
            momentum=config["momentum"])


class ValueNetworkMixin(object):
    def _value(self, obs):
        with self.lock:
            obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
            vf = self.model({"obs": obs}, [])
            return vf[2][0].detach().cpu().numpy()


VTraceTorchPolicy = build_torch_policy(
    name="VTraceTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
    loss_fn=build_vtrace_loss,
    stats_fn=stats,
    postprocess_fn=postprocess_trajectory,
    optimizer_fn=choose_optimizer,
    extra_action_out_fn=add_behaviour_logits,
    extra_grad_process_fn=apply_grad_clipping,
    mixins=[ValueNetworkMixin])
    out[SampleBatch.VF_PREDS] = policy.model.value_function()
    return out


def setup_early_mixins(policy, obs_space, action_space, config):
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def setup_late_mixins(policy, obs_space, action_space, config):
    KLCoeffMixin.__init__(policy, config)
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


AsyncPPOTorchPolicy = build_torch_policy(
    name="AsyncPPOTorchPolicy",
    loss_fn=build_appo_surrogate_loss,
    stats_fn=stats,
    postprocess_fn=postprocess_trajectory,
    extra_action_out_fn=add_values,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=choose_optimizer,
    before_init=setup_early_mixins,
    after_init=setup_late_mixins,
    make_model=build_appo_model,
    mixins=[
        LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
        ValueNetworkMixin
    ],
    get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
    return policy.total_loss


def stats(policy, train_batch):
    return {
        "policy_loss": policy.p_loss,
        "vf_loss": policy.v_loss,
        "total_loss": policy.total_loss,
        "vf_explained_var": policy.explained_variance,
    }


def setup_mixins(policy, obs_space, action_space, config):
    # Create a var.
    policy.ma_adv_norm = torch.tensor(
        [100.0], dtype=torch.float32, requires_grad=False)
    # Setup Value branch of our NN.
    ValueNetworkMixin.__init__(policy)


MARWILTorchPolicy = build_torch_policy(
    name="MARWILTorchPolicy",
    loss_fn=marwil_loss,
    get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
    stats_fn=stats,
    postprocess_fn=postprocess_advantages,
    after_init=setup_mixins,
    mixins=[ValueNetworkMixin])
    info = {}
    if policy.config["grad_clip"]:
        total_norm = nn.utils.clip_grad_norm_(policy.model.parameters(),
                                              policy.config["grad_clip"])
        info["grad_gnorm"] = total_norm
    return info


def torch_optimizer(policy, config):
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])


class ValueNetworkMixin(object):
    def _value(self, obs):
        with self.lock:
            obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
            _ = self.model({"obs": obs}, [], [1])
            return self.model.value_function().detach().cpu().numpy().squeeze()


A3CTorchPolicy = build_torch_policy(
    name="A3CTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=loss_and_entropy_stats,
    postprocess_fn=add_advantages,
    extra_action_out_fn=model_value_predictions,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=torch_optimizer,
    mixins=[ValueNetworkMixin])
        target_model_vars = self.target_model.variables()
        assert len(model_vars) == len(target_model_vars), \
            (model_vars, target_model_vars)
        for var, var_target in zip(model_vars, target_model_vars):
            var_target.data = tau * var.data + \
                (1.0 - tau) * var_target.data


def setup_late_mixins(policy, obs_space, action_space, config):
    ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
    TargetNetworkMixin.__init__(policy)


DDPGTorchPolicy = build_torch_policy(
    name="DDPGTorchPolicy",
    loss_fn=ddpg_actor_critic_loss,
    get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
    stats_fn=build_ddpg_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    extra_grad_process_fn=gradients_fn,
    optimizer_fn=make_ddpg_optimizers,
    before_init=before_init_fn,
    after_init=setup_late_mixins,
    action_distribution_fn=get_distribution_inputs_and_class,
    make_model_and_action_dist=build_ddpg_models_and_action_dist,
    apply_gradients_fn=apply_gradients_fn,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
    ])
    policy.model.target_entropy = policy.model.target_entropy.to(policy.device)
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy)
    CurlMixin.__init__(policy)


#######################################################################################################
##################################### Policy #####################################################
#######################################################################################################
# hack to avoid cycle imports
import algorithms.curl.sac.sac_trainer

CurlSACTorchPolicy = build_torch_policy(
    name="CurlSACTorchPolicy",
    # loss updates shifted to policy.learn_on_batch
    loss_fn=None,
    # get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
    get_default_config=lambda: algorithms.curl.sac.sac_trainer.SAC_CONFIG,
    stats_fn=stats,
    # called in a torch.no_grad scope, calls loss func again to update td error
    postprocess_fn=postprocess_trajectory,
    # will clip grad in learn_on_batch if grad_clip is not None in config
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=optimizer_fn,
    after_init=setup_late_mixins,
    make_model_and_action_dist=build_curl_sac_model_and_action_dist,
    mixins=[TargetNetworkMixin, ComputeTDErrorMixin, CurlMixin],
    action_distribution_fn=curl_action_distribution_fn,
)
        model_cls = DiscreteLinearModelThompsonSampling
    elif exploration_config["type"] == UCB_PATH:
        if isinstance(original_space, spaces.Dict):
            assert "item" in original_space.spaces, \
                "Cannot find 'item' key in observation space"
            model_cls = ParametricLinearModelUCB
        else:
            model_cls = DiscreteLinearModelUCB

    model = model_cls(
        obs_space, action_space, logit_dim, config["model"], name="LinearModel")
    return model, dist_class


def init_cum_regret(policy, *args):
    policy.regrets = []


BanditPolicy = build_torch_policy(
    name="BanditPolicy",
    get_default_config=lambda: DEFAULT_CONFIG,
    loss_fn=None,
    after_init=init_cum_regret,
    make_model_and_action_dist=make_model_and_action_dist,
    optimizer_fn=lambda policy, config: None,  # Pass a dummy optimizer
    mixins=[BanditPolicyOverrides])
        else:

            def value(ob, prev_action, prev_reward, *state):
                return 0.0

        self._value = value


def setup_mixins(policy, obs_space, action_space, config):
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


PPOTorchPolicy = build_torch_policy(
    name="PPOTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
    loss_fn=ppo_surrogate_loss,
    stats_fn=kl_and_loss_stats,
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=postprocess_ppo_gae,
    extra_grad_process_fn=apply_grad_clipping,
    before_init=setup_config,
    after_init=setup_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin
    ])
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors before SimpleQTorchPolicy
    initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


SimpleQTorchPolicyPatched = build_torch_policy(
    name="SimpleQPolicy",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    after_init=setup_late_mixins,
    extra_action_out_fn=_simple_dqn_extra_action_out_fn,
    make_model_and_action_dist=_build_q_model_and_distribution,
    mixins=[TargetNetworkMixin, SafeSetWeightsPolicyMixin],
    action_distribution_fn=get_distribution_inputs_and_class,
    extra_learn_fetches_fn=lambda policy: {
        "td_error": policy.td_error,
        "loss": policy.loss
    },
)