Example #1
import jax

# jnn, stu, create_q_net, create_pi_net, ReplayBuffer and the new_key() helper are
# project-local pieces defined elsewhere in the original source, not in this excerpt.


class SACAgent:  # class name assumed here; the original excerpt begins at __init__
    def __init__(
        self,
        obs_dim,
        act_dim,
        action_max=2,
        memory_size=int(1e6),  # kept as an int so it can size buffer arrays directly
        batch_size=256,
        td_steps=1,
        discount=0.99,
        LR=3e-4,
        tau=0.005,
        update_interval=1,
        grad_steps_per_update=1,
        seed=0,
        alpha_0=0.2,
        reward_standardizer=None,
        state_transformer=None,
    ):

        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.action_max = action_max
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.td_steps = td_steps
        self.gamma = discount
        self.LR = LR
        self.tau = tau
        self.update_interval = update_interval
        self.grad_steps_per_update = grad_steps_per_update
        self.seed = seed

        self.rngkey = jax.random.PRNGKey(seed)

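        # Q-network parameters plus a separate target copy (presumably soft-updated via tau).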
        q_params, q_fn = create_q_net(obs_dim, act_dim, self.new_key())
        self.q_fn = q_fn
        self.q = q_params
        self.q_targ = stu.copy_network(q_params)

        pi_params, pi_fn = create_pi_net(obs_dim, act_dim, self.new_key())
        self.pi_fn = pi_fn
        self.pi = pi_params

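        # SAC target entropy heuristic for continuous actions: H_target = -|A|.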
        self.H_target = -act_dim
        self.memory_train = ReplayBuffer(obs_dim,
                                         act_dim,
                                         memory_size,
                                         reward_steps=td_steps,
                                         batch_size=batch_size)
        # Default is created here rather than in the signature to avoid sharing one
        # mutable RewardStandardizer instance across agent instances.
        self.reward_standardizer = (reward_standardizer if reward_standardizer is not None
                                    else jnn.RewardStandardizer())
        self.state_transformer = state_transformer
        # jax.vmap(None) would fail, so only build the batched transformer when one is given.
        self.state_transformer_batched = (jax.jit(jax.vmap(state_transformer))
                                          if state_transformer is not None else None)

        self.alpha = alpha_0
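
# Hedged sketch (not the project's stu/jnn helpers): `tau` above is presumably used
# for Polyak soft target updates; a minimal, self-contained JAX version could be:
import jax
import jax.numpy as jnp

def soft_update(target_params, online_params, tau=0.005):
    """Leaf-wise update: new_target = (1 - tau) * target + tau * online."""
    return jax.tree_util.tree_map(
        lambda t, o: (1.0 - tau) * t + tau * o, target_params, online_params)

# Toy example on small parameter pytrees:
online = {"w": jnp.ones((2, 2)), "b": jnp.ones(2)}
target = {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)}
target = soft_update(target, online)
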
# The code below is a separate script-level excerpt that trains pendulum value networks;
# np (numpy), jnn, PP, SL and PU are modules imported elsewhere in the original source.
LR_0 = 3e-5  # learning rate; ~1e-5 may be better
decay = 0.993

num_obs_rewards = 4

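# Two value networks with target copies; input dim 3 matches [cos(theta), sin(theta), theta_dot].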
layers = [3, 80, 80, 1]
vn1 = jnn.init_network_params_He(layers)
vn1_targ = jnn.copy_network(vn1)
vn2 = jnn.init_network_params_He(layers)
vn2_targ = jnn.copy_network(vn2)
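
# Hedged sketch (not the project's jnn module): He/Kaiming initialization for an MLP
# with layer sizes such as [3, 80, 80, 1] could look roughly like this in JAX.
import jax
import jax.numpy as jnp

def init_mlp_he(layer_sizes, key):
    """Return a list of (weight, bias) pairs with He-initialized weights."""
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, wkey = jax.random.split(key)
        w = jax.random.normal(wkey, (n_out, n_in)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

example_params = init_mlp_he([3, 80, 80, 1], jax.random.PRNGKey(0))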

plotter = PP.PendulumValuePlotter2(n_grid=100, jupyter=True)
plotX = np.vstack(
    [np.cos(plotter.plotX1),
     np.sin(plotter.plotX1), plotter.plotX2])
stdizer = jnn.RewardStandardizer()

L = SL.Logger()

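# Per-layer gradient log names for both value networks, e.g. "dw0_vn1", ..., "db2_vn2".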
grad_log_names = []
for l_num in range(len(layers) - 1):
    for l_type in ["w", "b"]:
        for model_num in ["1", "2"]:
            grad_log_names.append(f"d{l_type}{l_num}_vn{model_num}")

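# num_epochs, episodes_per_epoch, episode_len and params are assumed to be defined
# earlier in the original script; they are not shown in this excerpt.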
for i in range(num_epochs):
    epoch_memory = None
    for j in range(episodes_per_epoch):
        episode = PU.make_n_step_sin_cos_traj_episode(episode_len,
                                                      num_obs_rewards, stdizer,
                                                      params)
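        # The loop body is truncated in the original excerpt; a plausible (assumed)
        # continuation would accumulate the episode into epoch_memory, e.g.
        # epoch_memory = episode if epoch_memory is None else np.vstack([epoch_memory, episode]).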