Code Example #1
File: SAC.py Project: henry-prior/jax-rl
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        discount: float = 0.99,
        tau: float = 0.005,
        policy_freq: int = 2,
        lr: float = 3e-4,
        entropy_tune: bool = True,
        seed: int = 0,
    ):
        self.rng = PRNGSequence(seed)

        # Gaussian policy (actor) network and its Adam optimizer.
        actor_input_dim = (1, state_dim)

        actor_params = build_gaussian_policy_model(actor_input_dim, action_dim,
                                                   max_action, next(self.rng))
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        # Double critic and its target network. The same init key is reused
        # so the target starts from the same weights as the online critic
        # (standard practice for target networks).
        init_rng = next(self.rng)

        critic_input_dim = [(1, state_dim), (1, action_dim)]

        critic_params = build_double_critic_model(critic_input_dim, init_rng)
        self.critic_target_params = build_double_critic_model(
            critic_input_dim, init_rng)
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        self.entropy_tune = entropy_tune

        # Learnable temperature, parameterized as log(alpha) for stability.
        # The target entropy follows the standard SAC heuristic of -|A|.
        log_alpha_params = build_constant_model(-3.5, next(self.rng))
        log_alpha_optimizer = optim.Adam(
            learning_rate=lr).create(log_alpha_params)
        self.log_alpha_optimizer = jax.device_put(log_alpha_optimizer)
        self.target_entropy = -action_dim

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_freq = policy_freq

        self.action_dim = action_dim

        self.total_it = 0  # training-step counter; drives the delayed (policy_freq) updates
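Both constructors draw initialization keys from a PRNGSequence helper via next(self.rng). A minimal sketch of what such a helper could look like, assuming it simply splits a single seeded JAX key on demand (the project's actual implementation may differ):

import jax

class PRNGSequence:
    """Endless iterator yielding fresh JAX PRNG keys from one seed."""

    def __init__(self, seed: int):
        self._key = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Split the running key; keep one half, hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey

Under this assumption, each call to next(self.rng) in the constructors yields an independent key, so every network is initialized with its own randomness unless a key is deliberately reused.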
Code Example #2
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        discount: float = 0.99,
        lr: float = 3e-4,
        eps_eta: float = 0.1,
        eps_mu: float = 5e-4,
        eps_sig: float = 1e-5,
        target_freq: int = 250,
        seed: int = 0,
    ):
        self.rng = PRNGSequence(seed)

        # Actor and target actor are built from the same init key so they
        # start with identical weights.
        init_rng = next(self.rng)

        actor_input_dim = (1, state_dim)

        actor_params = build_gaussian_policy_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        self.actor_target_params = build_gaussian_policy_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        # Double critic and its target, again sharing one init key.
        init_rng = next(self.rng)

        critic_input_dim = [(1, state_dim), (1, action_dim)]

        critic_params = build_double_critic_model(critic_input_dim, init_rng)
        self.critic_target_params = build_double_critic_model(
            critic_input_dim, init_rng
        )
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        # Lagrange multipliers for the MPO-style KL constraints on the policy
        # mean (eps_mu) and covariance (eps_sig); absolute=True presumably
        # keeps the multipliers positive.
        mu_lagrange_params = build_constant_model(
            1.0, absolute=True, init_rng=next(self.rng)
        )
        mu_lagrange_optimizer = optim.Adam(learning_rate=lr).create(mu_lagrange_params)
        self.mu_lagrange_optimizer = jax.device_put(mu_lagrange_optimizer)

        sig_lagrange_params = build_constant_model(
            100.0, absolute=True, init_rng=next(self.rng)
        )
        sig_lagrange_optimizer = optim.Adam(learning_rate=lr).create(
            sig_lagrange_params
        )
        self.sig_lagrange_optimizer = jax.device_put(sig_lagrange_optimizer)

        # MPO-style constraint hyperparameters: the temperature for the
        # non-parametric E-step and the KL bounds (eps_eta for the E-step,
        # eps_mu/eps_sig for the parametric mean and covariance updates).
        self.temp = 1.0
        self.eps_eta = eps_eta
        self.eps_mu = eps_mu
        self.eps_sig = eps_sig

        self.max_action = max_action
        self.discount = discount
        self.target_freq = target_freq

        self.state_dim = state_dim
        self.action_dim = action_dim

        self.total_it = 0  # training-step counter; drives the periodic (target_freq) target sync
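For orientation, a hypothetical way these constructors might be invoked. The class names SAC and MPO are assumptions inferred from the file name and the eps_eta/eps_mu/eps_sig parameters (the snippets show only __init__), and the environment dimensions are made up:

# Hypothetical usage sketch: SAC, MPO, and the dimensions below are
# assumptions, not shown in the excerpts above.
state_dim, action_dim, max_action = 17, 6, 1.0  # e.g. a MuJoCo-style task

sac_agent = SAC(state_dim, action_dim, max_action, seed=42)
mpo_agent = MPO(
    state_dim,
    action_dim,
    max_action,
    eps_eta=0.1,  # E-step KL bound (the default shown above)
    seed=42,
)

Using the same seed for both agents makes their PRNG key streams reproducible run to run, which is the usual motivation for threading an explicit seed through a JAX training setup.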