Example #1
def ppo_init(policy: Policy, obs_space: gym.spaces.Space,
             action_space: gym.spaces.Space,
             config: TrainerConfigDict) -> None:
    """ TODO: Write documentation.
    """
    # Call base implementation
    setup_mixins(policy, obs_space, action_space, config)

    # Add the previous observation to the view requirements for CAPS loss computation
    # TODO: Remove update of `policy.model.view_requirements` after ray fix
    caps_view_requirements = {
        "_prev_obs":
        ViewRequirement(data_col="obs",
                        space=obs_space,
                        shift=-1,
                        used_for_compute_actions=False)
    }
    policy.model.view_requirements.update(caps_view_requirements)
    policy.view_requirements.update(caps_view_requirements)

    # Initialize extra loss
    policy._mean_symmetric_policy_loss = 0.0
    policy._mean_temporal_caps_loss = 0.0
    policy._mean_spatial_caps_loss = 0.0
    policy._mean_global_caps_loss = 0.0
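
The initializer above only takes effect once it is registered with a policy class. A minimal sketch of that wiring is shown below; it assumes Ray 1.x, where policies created by `build_torch_policy` expose `with_updates` and accept a `before_loss_init` hook, and the policy name `CAPSPPOTorchPolicy` is purely illustrative.

from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy

# Hedged sketch (not from the source): register `ppo_init` as the
# before-loss-init hook of a custom PPO policy. The module path and
# keyword names follow Ray 1.x and may differ in other versions.
CAPSPPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="CAPSPPOTorchPolicy",
    before_loss_init=ppo_init,
)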
Example #2
def setup_mixins_override(policy: Policy, obs_space: gym.spaces.Space,
                          action_space: gym.spaces.Space,
                          config: TrainerConfigDict) -> None:
    """Have to initialize the custom ValueNetworkMixin
    """
    setup_mixins(policy, obs_space, action_space, config)
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
Example #3
def setup_mixins_and_mcts(policy: Policy, obs_space: gym.spaces.Space,
                          action_space: gym.spaces.Space,
                          config: TrainerConfigDict) -> None:
    """Run the default mixin setup, then attach an MCTS instance to the policy.
    """
    setup_mixins(policy, obs_space, action_space, config)

    # assumed discrete action space
    policy.mcts = MCTS(policy.model, policy.config["mcts_param"],
                       action_space.n, policy.device)
Example #4
    def __init__(self, obs_space, action_space, config, model, loss,
                 action_distribution_class):
        # update policy attr for loss calculation
        # self.framework = config['framework'] = 'torch'
        # self.kl_coeff = config['kl_coeff']
        # self.kl_target = config['kl_target']
        # self.entropy_coeff = config['entropy_coeff']
        # self.cur_lr = config['lr']
        # setup ._value() for gae computation
        # self.setup_value(config)
        # self.dist_class, logit_dim = ModelCatalog.get_action_dist(
        #             action_space, config["model"], framework='torch')
        assert getattr(self, 'model', None) is not None, (
            "The agent's model has to be initialized before this.")
        # self.model = ModelCatalog.get_model_v2(
        #             obs_space=obs_space,
        #             action_space=action_space,
        #             num_outputs=logit_dim,
        #             model_config=config["model"],
        #             framework='torch')

        super().__init__(obs_space,
                         action_space,
                         config,
                         model=self.model,
                         loss=ppo_surrogate_loss,
                         action_distribution_class=self.dist_class,
                         max_seq_len=config['model']['max_seq_len'])

        # Merge Model's view requirements into Policy's.
        self.view_requirements.update(self.model.view_requirements)
        # init mixins
        setup_mixins(self, obs_space, action_space, config)
        # Perform test runs through postprocessing- and loss functions.
        self._initialize_loss_from_dummy_batch(
            auto_remove_unneeded_view_reqs=True,
            stats_fn=kl_and_loss_stats,
        )
        self.global_timestep = 0
def after_init_fn(policy, obs_space, action_space, config):
    """Run the default mixin setup, then configure reward normalization,
    gradient clipping, and automatic data augmentation from the policy config.
    """
    setup_mixins(policy, obs_space, action_space, config)

    # Reward normalization.
    rew_norm_opt = policy.config["reward_normalization_options"]
    if rew_norm_opt["mode"] == "running_mean_std":
        policy.reward_norm_stats = RunningStat(max_count=1000)
    elif rew_norm_opt["mode"] == "running_return":
        policy.reward_norm_stats = ExpWeightedMovingAverageStat(alpha=rew_norm_opt["alpha"])
    elif rew_norm_opt["mode"] == "none":
        pass
    else:
        raise ValueError(f"Unsupported reward norm mode: {rew_norm_opt['mode']}")

    # Gradient clipping.
    grad_clip_opt = policy.config["grad_clip_options"]
    if grad_clip_opt["mode"] == "constant":
        pass
    elif grad_clip_opt["mode"] == "adaptive":
        policy.prev_gradient_norms = collections.deque(maxlen=grad_clip_opt["adaptive_buffer_size"])
    else:
        raise ValueError(f"Unsupported grad clip mode: {grad_clip_opt['mode']}")

    # Automatic data augmentation.
    auto_drac_opt = policy.config["auto_drac_options"]
    if auto_drac_opt["active"]:
        policy.always_use_transforms = auto_drac_opt["always_use_transforms"]
        policy.choose_between_transforms = []
        for t in auto_drac_opt["choose_between_transforms"]:
            if isinstance(t, list):
                t = tuple(t)
            policy.choose_between_transforms.append(t)
        if auto_drac_opt["learner_class"] == "ucb":
            policy.transform_selector = UCBLearner(
                policy.choose_between_transforms,
                **auto_drac_opt["ucb_options"],
                verbose=False,
            )
        else:
            raise NotImplementedError(f"Learner not implemented: {auto_drac_opt['learner_class']}")
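
For reference, a hedged example of the custom config keys that `after_init_fn` reads is sketched below; all values are illustrative placeholders, and `ucb_options` is left empty because the UCBLearner constructor arguments are not shown above.

# Illustrative config fragment (placeholder values, not from the source).
custom_policy_config = {
    "reward_normalization_options": {"mode": "running_return", "alpha": 0.99},
    "grad_clip_options": {"mode": "adaptive", "adaptive_buffer_size": 100},
    "auto_drac_options": {
        "active": True,
        "always_use_transforms": [],
        # Inner lists are converted to tuples by `after_init_fn`.
        "choose_between_transforms": [["crop", "rotate"]],
        "learner_class": "ucb",
        "ucb_options": {},  # UCBLearner kwargs (signature not shown above)
    },
}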