def ppo_init(policy: Policy,
             obs_space: gym.spaces.Space,
             action_space: gym.spaces.Space,
             config: TrainerConfigDict) -> None:
    """Extend the default PPO setup with the view requirements and loss
    statistics needed for the CAPS and symmetry regularization terms.
    """
    # Call base implementation
    setup_mixins(policy, obs_space, action_space, config)

    # Add previous observation to the view requirements for CAPS loss computation
    # TODO: Remove update of `policy.model.view_requirements` after ray fix
    caps_view_requirements = {
        "_prev_obs": ViewRequirement(
            data_col="obs",
            space=obs_space,
            shift=-1,
            used_for_compute_actions=False)
    }
    policy.model.view_requirements.update(caps_view_requirements)
    policy.view_requirements.update(caps_view_requirements)

    # Initialize extra loss statistics
    policy._mean_symmetric_policy_loss = 0.0
    policy._mean_temporal_caps_loss = 0.0
    policy._mean_spatial_caps_loss = 0.0
    policy._mean_global_caps_loss = 0.0
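# A minimal sketch of how `ppo_init` might be wired in as the policy's
# `before_loss_init` hook, assuming the ray 1.x RLlib policy-template API
# (`PPOTorchPolicy.with_updates`); the derived class name is hypothetical.
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy

CAPSPPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="CAPSPPOTorchPolicy",
    before_loss_init=ppo_init,  # runs setup_mixins plus the CAPS-specific setup
)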
def setup_mixins_override(policy: Policy,
                          obs_space: gym.spaces.Space,
                          action_space: gym.spaces.Space,
                          config: TrainerConfigDict) -> None:
    """Run the default mixin setup, then initialize the custom
    ValueNetworkMixin on top of it.
    """
    setup_mixins(policy, obs_space, action_space, config)
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
def setup_mixins_and_mcts(policy: Policy,
                          obs_space: gym.spaces.Space,
                          action_space: gym.spaces.Space,
                          config: TrainerConfigDict) -> None:
    """Run the default mixin setup and attach an MCTS instance to the policy.

    Assumes a discrete action space, since the number of actions is taken
    from `action_space.n`.
    """
    setup_mixins(policy, obs_space, action_space, config)
    policy.mcts = MCTS(policy.model,
                       policy.config["mcts_param"],
                       action_space.n,
                       policy.device)
def __init__(self, obs_space, action_space, config, model, loss, action_distribution_class):
    # Update policy attributes used for the loss calculation.
    # self.framework = config['framework'] = 'torch'
    # self.kl_coeff = config['kl_coeff']
    # self.kl_target = config['kl_target']
    # self.entropy_coeff = config['entropy_coeff']
    # self.cur_lr = config['lr']
    # Setup `._value()` for GAE computation.
    # self.setup_value(config)
    # self.dist_class, logit_dim = ModelCatalog.get_action_dist(
    #     action_space, config["model"], framework='torch')
    assert getattr(self, 'model', None), \
        "The agent's model has to be initialized before this."
    # self.model = ModelCatalog.get_model_v2(
    #     obs_space=obs_space,
    #     action_space=action_space,
    #     num_outputs=logit_dim,
    #     model_config=config["model"],
    #     framework='torch')
    super().__init__(
        obs_space,
        action_space,
        config,
        model=self.model,
        loss=ppo_surrogate_loss,
        action_distribution_class=self.dist_class,
        max_seq_len=config['model']['max_seq_len'])
    # Merge the model's view requirements into the policy's.
    self.view_requirements.update(self.model.view_requirements)
    # Initialize the mixins.
    setup_mixins(self, obs_space, action_space, config)
    # Perform test runs through the postprocessing and loss functions.
    self._initialize_loss_from_dummy_batch(
        auto_remove_unneeded_view_reqs=True,
        stats_fn=kl_and_loss_stats,
    )
    self.global_timestep = 0
def after_init_fn(policy, obs_space, action_space, config):
    setup_mixins(policy, obs_space, action_space, config)

    # Reward normalization.
    rew_norm_opt = policy.config["reward_normalization_options"]
    if rew_norm_opt["mode"] == "running_mean_std":
        policy.reward_norm_stats = RunningStat(max_count=1000)
    elif rew_norm_opt["mode"] == "running_return":
        policy.reward_norm_stats = ExpWeightedMovingAverageStat(
            alpha=rew_norm_opt["alpha"])
    elif rew_norm_opt["mode"] == "none":
        pass
    else:
        raise ValueError(
            f"Unsupported reward norm mode: {rew_norm_opt['mode']}")

    # Gradient clipping.
    grad_clip_opt = policy.config["grad_clip_options"]
    if grad_clip_opt["mode"] == "constant":
        pass
    elif grad_clip_opt["mode"] == "adaptive":
        policy.prev_gradient_norms = collections.deque(
            maxlen=grad_clip_opt["adaptive_buffer_size"])
    else:
        raise ValueError(
            f"Unsupported grad clip mode: {grad_clip_opt['mode']}")

    # Automatic data augmentation.
    auto_drac_opt = policy.config["auto_drac_options"]
    if auto_drac_opt["active"]:
        policy.always_use_transforms = auto_drac_opt["always_use_transforms"]
        policy.choose_between_transforms = []
        for t in auto_drac_opt["choose_between_transforms"]:
            if isinstance(t, list):
                t = tuple(t)
            policy.choose_between_transforms.append(t)
        if auto_drac_opt["learner_class"] == "ucb":
            policy.transform_selector = UCBLearner(
                policy.choose_between_transforms,
                **auto_drac_opt["ucb_options"],
                verbose=False,
            )
        else:
            raise NotImplementedError(
                f"Learner not implemented: {auto_drac_opt['learner_class']}")
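# A minimal sketch of the custom config entries that `after_init_fn` reads.
# The keys and allowed modes come from the function above; the concrete
# values (alpha, buffer size, transform names) are illustrative assumptions,
# not defaults taken from the repository.
example_custom_config = {
    "reward_normalization_options": {
        # One of: "running_mean_std", "running_return", "none".
        "mode": "running_return",
        # Only read in "running_return" mode (assumed value).
        "alpha": 0.05,
    },
    "grad_clip_options": {
        # One of: "constant", "adaptive".
        "mode": "adaptive",
        # Size of the window of past gradient norms (assumed value).
        "adaptive_buffer_size": 100,
    },
    "auto_drac_options": {
        "active": True,
        # Transforms that are always applied (assumed names).
        "always_use_transforms": ["random_crop"],
        # Candidates handed to the UCB learner; list entries are converted to tuples.
        "choose_between_transforms": ["cutout", ["color_jitter", 0.5]],
        # Only "ucb" is currently supported.
        "learner_class": "ucb",
        # Forwarded to UCBLearner as keyword arguments.
        "ucb_options": {},
    },
}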