def __init__(
    self,
    obs_dim,
    act_dim,
    action_max=2,
    memory_size=int(1e6),
    batch_size=256,
    td_steps=1,
    discount=0.99,
    LR=3e-4,
    tau=0.005,
    update_interval=1,
    grad_steps_per_update=1,
    seed=0,
    alpha_0=0.2,
    reward_standardizer=None,
    state_transformer=None,
):
    self.obs_dim = obs_dim
    self.act_dim = act_dim
    self.action_max = action_max
    self.memory_size = memory_size
    self.batch_size = batch_size
    self.td_steps = td_steps
    self.gamma = discount
    self.LR = LR
    self.tau = tau
    self.update_interval = update_interval
    self.grad_steps_per_update = grad_steps_per_update

    # One root PRNG key; new_key() splits off a fresh subkey for each init below.
    self.seed = seed
    self.rngkey = jax.random.PRNGKey(seed)

    # Q-network and a copy that serves as the Polyak-averaged target.
    q_params, q_fn = create_q_net(obs_dim, act_dim, self.new_key())
    self.q_fn = q_fn
    self.q = q_params
    self.q_targ = stu.copy_network(q_params)

    # Policy network.
    pi_params, pi_fn = create_pi_net(obs_dim, act_dim, self.new_key())
    self.pi_fn = pi_fn
    self.pi = pi_params

    # SAC entropy target: the usual -|A| heuristic.
    self.H_target = -act_dim

    self.memory_train = ReplayBuffer(
        obs_dim, act_dim, memory_size, reward_steps=td_steps, batch_size=batch_size
    )

    # Default the standardizer here rather than in the signature, so instances
    # don't share one mutable default object.
    self.reward_standardizer = (
        reward_standardizer if reward_standardizer is not None else jnn.RewardStandardizer()
    )

    self.state_transformer = state_transformer
    # jax.vmap(None) would raise, so only build the batched version when a
    # transformer is supplied.
    self.state_transformer_batched = (
        jax.jit(jax.vmap(state_transformer)) if state_transformer is not None else None
    )

    # Entropy temperature (fixed at alpha_0 unless tuned elsewhere).
    self.alpha = alpha_0
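# A minimal sketch of the new_key() helper the constructor relies on; the real
# implementation lives elsewhere in the repo, so treat this as an assumption:
# it splits the root key so every consumer gets an independent PRNG stream.
def new_key(self):
    self.rngkey, subkey = jax.random.split(self.rngkey)
    return subkey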
LR_0 = 0.00003  # best ~1e-5?
decay = 0.993
num_obs_rewards = 4
layers = [3, 80, 80, 1]

# Twin value networks (He-initialized) plus target copies.
vn1 = jnn.init_network_params_He(layers)
vn1_targ = jnn.copy_network(vn1)
vn2 = jnn.init_network_params_He(layers)
vn2_targ = jnn.copy_network(vn2)

plotter = PP.PendulumValuePlotter2(n_grid=100, jupyter=True)
# Grid of (cos theta, sin theta, theta_dot) states for plotting the value surface.
plotX = np.vstack([np.cos(plotter.plotX1), np.sin(plotter.plotX1), plotter.plotX2])

stdizer = jnn.RewardStandardizer()
L = SL.Logger()

# One gradient log channel per weight/bias tensor in each value net:
# "dw0_vn1", "dw0_vn2", "db0_vn1", "db0_vn2", "dw1_vn1", ...
grad_log_names = []
for l_num in range(len(layers) - 1):
    for l_type in ["w", "b"]:
        for model_num in ["1", "2"]:
            grad_log_names.append(f"d{l_type}{l_num}_vn{model_num}")

# num_epochs, episodes_per_epoch, episode_len, and params are assumed to be
# defined earlier in the script.
for i in range(num_epochs):
    epoch_memory = None
    for j in range(episodes_per_epoch):
        # Roll out one n-step (cos, sin) trajectory episode with standardized rewards.
        episode = PU.make_n_step_sin_cos_traj_episode(
            episode_len, num_obs_rewards, stdizer, params
        )
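# The loop body is truncated here in this excerpt. A hedged sketch of the
# target-network refresh that the twin value nets above typically receive after
# each gradient step (Polyak averaging with a small tau). soft_update is
# illustrative, not from the source, and assumes params are lists of (w, b)
# arrays as init_network_params_He returns:
def soft_update(online, target, tau=0.005):
    # Leafwise: target <- tau * online + (1 - tau) * target.
    return [
        (tau * w + (1.0 - tau) * wt, tau * b + (1.0 - tau) * bt)
        for (w, b), (wt, bt) in zip(online, target)
    ]

# e.g. after each update:
#   vn1_targ = soft_update(vn1, vn1_targ)
#   vn2_targ = soft_update(vn2, vn2_targ)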