def __init__(
    self,
    state_dim: int,
    action_dim: int,
    max_action: float,
    discount: float = 0.99,
    tau: float = 0.005,
    policy_freq: int = 2,
    lr: float = 3e-4,
    entropy_tune: bool = True,
    seed: int = 0,
):
    """Build the actor, twin critics (plus target copy), and the learnable
    log-temperature, each with its own Adam optimizer.

    NOTE(review): the critic and its target are constructed from the *same*
    RNG key, so the target network starts identical to the online critic —
    presumably the standard target-network initialization; confirm.
    """
    self.rng = PRNGSequence(seed)

    # Squashed-Gaussian policy network and its optimizer.
    policy_input_shape = (1, state_dim)
    policy_params = build_gaussian_policy_model(
        policy_input_shape, action_dim, max_action, next(self.rng)
    )
    self.actor_optimizer = jax.device_put(
        optim.Adam(learning_rate=lr).create(policy_params)
    )

    # Twin critics; the target copy reuses the same key so both start equal.
    shared_key = next(self.rng)
    q_input_shapes = [(1, state_dim), (1, action_dim)]
    q_params = build_double_critic_model(q_input_shapes, shared_key)
    self.critic_target_params = build_double_critic_model(
        q_input_shapes, shared_key
    )
    self.critic_optimizer = jax.device_put(
        optim.Adam(learning_rate=lr).create(q_params)
    )

    # Learnable log-alpha (entropy temperature), tuned when entropy_tune is set.
    self.entropy_tune = entropy_tune
    alpha_params = build_constant_model(-3.5, next(self.rng))
    self.log_alpha_optimizer = jax.device_put(
        optim.Adam(learning_rate=lr).create(alpha_params)
    )
    # Target entropy heuristic: negative of the action dimensionality.
    self.target_entropy = -action_dim

    # Plain hyperparameters and bookkeeping.
    self.max_action = max_action
    self.discount = discount
    self.tau = tau
    self.policy_freq = policy_freq
    self.action_dim = action_dim
    self.total_it = 0
def __init__(
    self,
    state_dim: int,
    action_dim: int,
    max_action: float,
    discount: float = 0.99,
    lr: float = 3e-4,
    eps_eta: float = 0.1,
    eps_mu: float = 5e-4,
    eps_sig: float = 1e-5,
    target_freq: int = 250,
    seed: int = 0,
):
    """Build the actor and twin critics with target copies, plus two
    Lagrange-multiplier models for the mean/stddev constraints, each with
    its own Adam optimizer.

    NOTE(review): the eps_* constraint thresholds and the temperature
    suggest an MPO-style agent — confirm against the update code. Each
    target network is built from the same RNG key as its online network,
    so the pair starts identical.
    """
    self.rng = PRNGSequence(seed)

    # Actor and its target, initialized from one shared key.
    key = next(self.rng)
    policy_input_shape = (1, state_dim)
    policy_params = build_gaussian_policy_model(
        policy_input_shape, action_dim, max_action, key
    )
    self.actor_target_params = build_gaussian_policy_model(
        policy_input_shape, action_dim, max_action, key
    )
    self.actor_optimizer = jax.device_put(
        optim.Adam(learning_rate=lr).create(policy_params)
    )

    # Twin critics and their target, likewise sharing one key.
    key = next(self.rng)
    q_input_shapes = [(1, state_dim), (1, action_dim)]
    q_params = build_double_critic_model(q_input_shapes, key)
    self.critic_target_params = build_double_critic_model(q_input_shapes, key)
    self.critic_optimizer = jax.device_put(
        optim.Adam(learning_rate=lr).create(q_params)
    )

    # Positive (absolute=True) scalar multipliers for the mu/sig constraints.
    mu_params = build_constant_model(1.0, absolute=True, init_rng=next(self.rng))
    self.mu_lagrange_optimizer = jax.device_put(
        optim.Adam(learning_rate=lr).create(mu_params)
    )
    sig_params = build_constant_model(
        100.0, absolute=True, init_rng=next(self.rng)
    )
    self.sig_lagrange_optimizer = jax.device_put(
        optim.Adam(learning_rate=lr).create(sig_params)
    )

    # Temperature and constraint thresholds; plain hyperparameters below.
    self.temp = 1.0
    self.eps_eta = eps_eta
    self.eps_mu = eps_mu
    self.eps_sig = eps_sig
    self.max_action = max_action
    self.discount = discount
    self.target_freq = target_freq
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.total_it = 0