samples_per_epoch = episode_len * episodes_per_epoch
batch_size = 100
update_every = 3
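# episode_len and episodes_per_epoch are assumed to be defined in an earlier
# cell; batch_size and update_every feed the sampling/update loop, which is
# not shown in this section.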

discount = 0.97

tau = 0.002  # best: ~0.005?
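
# tau is the Polyak-averaging coefficient for the target networks below. A
# minimal sketch of the soft update it parameterizes, assuming params are a
# list of (W, b) layer tuples as returned by jnn.init_network_params_He; the
# helper name soft_update is illustrative, not part of jnn:
def soft_update(online, target, tau):
    """theta_targ <- tau * theta + (1 - tau) * theta_targ, layer by layer."""
    return [(tau * w + (1 - tau) * wt, tau * b + (1 - tau) * bt)
            for (w, b), (wt, bt) in zip(online, target)]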

LR_0 = 0.00003  # best ~1e-5?
decay = 0.993
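
# LR_0 and decay give an exponentially decaying learning rate; a sketch of the
# schedule, assuming the decay is applied once per epoch:
def lr_at(epoch):
    return LR_0 * decay**epoch  # e.g. lr_at(100) is roughly 1.5e-5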

num_obs_rewards = 4  # rewards observed per sample when forming value targets

layers = [3, 80, 80, 1]  # input (cos th, sin th, th_dot) -> scalar state value
vn1 = jnn.init_network_params_He(layers)
vn1_targ = jnn.copy_network(vn1)
vn2 = jnn.init_network_params_He(layers)
vn2_targ = jnn.copy_network(vn2)
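
# Two value networks, each with a Polyak-averaged target copy, in the style of
# clipped double estimation: targets can use the pointwise minimum of the two
# target networks' outputs. A sketch (jnn.predict is an assumed name for
# jnn's forward pass):
# def min_target_value(x):
#     return np.minimum(jnn.predict(vn1_targ, x), jnn.predict(vn2_targ, x))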

plotter = PP.PendulumValuePlotter2(n_grid=100, jupyter=True)
plotX = np.vstack(
    [np.cos(plotter.plotX1),
     np.sin(plotter.plotX1), plotter.plotX2])  # (cos th, sin th, th_dot) grid
stdizer = jnn.RewardStandardizer()
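# (RewardStandardizer tracks running reward statistics, presumably used to
#  normalize rewards before targets are formed in the training loop.)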

L = SL.Logger()

# One gradient-log channel per weight/bias tensor of each value network.
grad_log_names = []
for l_num in range(len(layers) - 1):
    for l_type in ["w", "b"]:
        for model_num in ["1", "2"]:
            grad_log_names.append(f"vn{model_num}_d{l_type}{l_num}")
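
# Second setup: a Q-network variant of the same experiment.
# Epsilon-greedy exploration schedule; greed_eps_0, greed_eps_min, and
# greed_eps_decay are assumed to be defined in an earlier cell.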
def greed_eps(t):
    return max(greed_eps_min, greed_eps_0 * greed_eps_decay**t)
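
# Sketch of how the schedule could drive action selection; greedy_action and
# the torque range are assumptions, not taken from this section:
# def select_action(qn, obs, t):
#     if np.random.rand() < greed_eps(t):
#         return np.random.uniform(-2.0, 2.0)  # random exploratory torque
#     return greedy_action(qn, obs)            # greedy w.r.t. the Q-network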


num_lookahead = 4  # steps of lookahead per Q target (n-step style)

layers = [4, 64, 64, 1]  # input: 3 state features + 1 action -> scalar Q-value
qn1 = jnn.init_network_params_He(layers)
qn1_targ = jnn.copy_network(qn1)
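# A single Q-network with a target copy; the twin (clipped double) variant is
# left disabled below.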
# qn2 = jnn.init_network_params_He(layers)
# qn2_targ = jnn.copy_network(qn2)


grad_log_names = []  # one log channel per weight/bias gradient of qn1
for l_num in range(len(layers) - 1):
    for l_type in ["w", "b"]:
        grad_log_names.append(f"qn1_d{l_type}{l_num}")