def main():
    results = {
        "results": [],
        "measure_reg": [],
        "measure_td": [],
        "measure_mc": [],
    }
    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }
    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = hps.get("mbsize", 32)
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")
    clone_interval = hps.get("clone_interval", 10_000)
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    target = hps.get("target", "last")  # self, last, clones
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)
    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = hps.get("buffer_size", 250_000)

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    def make_opt(theta):
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(
                theta, lr, momentum=hps.get("beta", 0.99),
                weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    # Define model
    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )
    clone_theta_q = lambda: [i.detach().clone().requires_grad_() for i in theta_q]

    # Pretrained parameters
    theta_target = load_parameters_from_checkpoint()

    # (Same) Random parameters
    theta_regress = clone_theta_q()
    theta_qlearn = clone_theta_q()
    theta_mc = clone_theta_q()

    opt_regress = make_opt(theta_regress)
    opt_qlearn = make_opt(theta_qlearn)
    opt_mc = make_opt(theta_mc)

    # Define loss
    def sl1(a, b):
        # Smoothed L1: quadratic where |a - b| <= 1, linear elsewhere
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    # One-step TD error of Q(s, a) against a (target-network) bootstrap
    td = lambda s, a, r, sp, t, w, tw=theta_q: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
        Qf(s, w)[np.arange(len(a)), a.long()],
    )

    obs = env.reset()
    replay_buffer = ReplayBuffer(seed, buffer_size, near_strategy=sample_near)

    total_reward = 0
    last_end = 0
    num_fill = buffer_size
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]

    print("Filling buffer")
    epsilon = final_epsilon
    replay_buffer.new_episode(obs, env.enumber % 2)
    while replay_buffer.idx < replay_buffer.size - 10:
        # Epsilon-greedy action from the pretrained target network
        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0), theta_target).argmax().item()
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        obs = obsp
        if done:
            obs = env.reset()
            replay_buffer.new_episode(obs, env.enumber % 2)

    # Remove the last episode from the replay buffer, as it didn't end
    it = replay_buffer.idx
    curp = replay_buffer.p[it]
    while replay_buffer.p[it] == curp:
        replay_buffer._sumtree.set(it, 0)
        it -= 1
    print(f'went from {replay_buffer.idx} to {it} when deleting states')

    print("Computing returns")
    replay_buffer.compute_values(lambda s: Qf(s, theta_regress), num_act)
    replay_buffer.compute_returns(gamma)
    replay_buffer.compute_reward_distances()

    print("Training regressions")
    losses_reg, losses_td, losses_mc = [], [], []
    # Direct regression onto the pretrained network's Q-values
    loss_reg_f = lambda x, w: sl1(Qf(x[0], w), Qf(x[0], theta_target))
    # One-step TD loss, bootstrapping from the pretrained network
    loss_td_f = lambda x, w: td(*x[:-1], w, theta_target)
    # Monte Carlo regression onto the stored returns g
    loss_mc_f = lambda x, w: sl1(
        Qf(x[0], w)[np.arange(len(x[1])), x[1].long()],
        replay_buffer.g[x[-1]])
    losses = {
        "reg": loss_reg_f,
        "td": loss_td_f,
        "mc": loss_mc_f,
    }
    measure_reg = Measures(theta_regress, losses, replay_buffer,
                           results["measure_reg"], mbsize)
    measure_mc = Measures(theta_mc, losses, replay_buffer,
                          results["measure_mc"], mbsize)
    measure_td = Measures(theta_qlearn, losses, replay_buffer,
                          results["measure_td"], mbsize)

    for i in range(100_000):
        sample = replay_buffer.sample(mbsize)
        replay_buffer.compute_value_difference(sample, Qf(sample[0], theta_regress))

        if i and not i % num_measure:
            measure_reg.pre(sample)
            measure_mc.pre(sample)
            measure_td.pre(sample)

        loss_reg = loss_reg_f(sample, theta_regress).mean()
        loss_reg.backward()
        losses_reg.append(loss_reg.item())
        opt_regress.step()
        opt_regress.zero_grad()

        loss_td = loss_td_f(sample, theta_qlearn).mean()
        loss_td.backward()
        losses_td.append(loss_td.item())
        opt_qlearn.step()
        opt_qlearn.zero_grad()

        loss_mc = loss_mc_f(sample, theta_mc).mean()
        loss_mc.backward()
        losses_mc.append(loss_mc.item())
        opt_mc.step()
        opt_mc.zero_grad()

        replay_buffer.update_values(sample, Qf(sample[0], theta_regress))

        if i and not i % num_measure:
            measure_reg.post()
            measure_td.post()
            measure_mc.post()

        if not i % 1000:
            print(i, loss_reg.item(), loss_td.item(), loss_mc.item())
def main():
    device = torch.device(ARGS.device)
    mm.set_device(device)
    results = {
        "measure": [],
        "parameters": [],
    }

    seed = ARGS.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(ARGS.env_name)
    mbsize = ARGS.mbsize
    Lambda = ARGS.Lambda
    nhid = 32
    num_measure = 1000
    gamma = 0.99
    clone_interval = ARGS.clone_interval
    num_iterations = ARGS.num_iterations
    num_Q_outputs = 1

    # Model
    _Qarch, theta_q, Qf, _Qsemi = mm.build(
        mm.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        mm.conv2d(nhid, nhid * 2, 4, stride=2),
        mm.conv2d(nhid * 2, nhid * 2, 3),
        mm.flatten(),
        mm.hidden(nhid * 2 * 12 * 12, nhid * 16),
        mm.linear(nhid * 16, num_Q_outputs),
    )
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]
    theta_target = clone_theta_q()
    opt = make_opt(ARGS.opt, theta_q, ARGS.learning_rate, ARGS.weight_decay)

    # Replay Buffer
    replay_buffer = ReplayBuffer(seed, ARGS.buffer_size)

    # Losses
    # One-step TD error against the target network tw
    td = lambda s, a, r, sp, t, idx, w, tw: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw)[:, 0].detach(),
        Qf(s, w)[:, 0],
    )
    # Regression onto the precomputed lambda-returns LG
    tdL = lambda s, a, r, sp, t, idx, w, tw: sl1(
        Qf(s, w)[:, 0],
        replay_buffer.LG[idx])
    # Monte Carlo regression onto the stored returns g
    mc = lambda s, a, r, sp, t, idx, w, tw: sl1(
        Qf(s, w)[:, 0],
        replay_buffer.g[idx])

    # Define metrics
    measure = Measures(
        theta_q, {
            "td": lambda x, w: td(*x, w, theta_target),
            "tdL": lambda x, w: tdL(*x, w, theta_target),
            "mc": lambda x, w: mc(*x, w, theta_target),
        },
        replay_buffer, results["measure"], 32)

    # Get expert trajectories
    rand_classes = fill_buffer_with_expert(env, replay_buffer)

    # Compute initial values
    replay_buffer.compute_values(lambda s: Qf(s, theta_q), num_Q_outputs)
    replay_buffer.compute_returns(gamma)
    replay_buffer.compute_reward_distances()
    replay_buffer.compute_episode_boundaries()
    replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

    # Run policy evaluation
    for it in range(num_iterations):
        do_measure = not it % num_measure
        sample = replay_buffer.sample(mbsize)

        if do_measure:
            measure.pre(sample)
        replay_buffer.compute_value_difference(sample, Qf(sample[0], theta_q))

        loss = tdL(*sample, theta_q, theta_target)
        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        replay_buffer.update_values(sample, Qf(sample[0], theta_q))
        if do_measure:
            measure.post()

        if it and clone_interval and it % clone_interval == 0:
            # Refresh the target network and recompute lambda-returns under it
            theta_target = clone_theta_q()
            replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

        if (it and it % clone_interval == 0) or it == num_iterations - 1:
            # Snapshot the current parameters and checkpoint the results
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)
            with open(f'results/td_lambda_{ARGS.run}.pkl', 'wb') as f:
                pickle.dump(results, f)
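# Both training scripts above read their hyperparameters from a module-level
# `ARGS` object whose definition is not part of this listing. The block below
# is only a minimal sketch of what such an entry point might look like, with
# flag names inferred from the attributes accessed above (ARGS.opt,
# ARGS.env_name, ARGS.learning_rate, ARGS.weight_decay, ARGS.run, ARGS.device,
# ARGS.mbsize, ARGS.Lambda, ARGS.clone_interval, ARGS.num_iterations,
# ARGS.buffer_size). The defaults are illustrative placeholders, not the
# original experiment settings.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--opt", type=str, default="adam")  # sgd, msgd, rmsprop, adam
parser.add_argument("--env_name", type=str, default="ms_pacman")
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument("--run", type=int, default=0)  # run index, offsets the seed
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--mbsize", type=int, default=32)
parser.add_argument("--Lambda", type=float, default=0.9)
parser.add_argument("--clone_interval", type=int, default=10_000)
parser.add_argument("--num_iterations", type=int, default=100_000)
parser.add_argument("--buffer_size", type=int, default=250_000)

if __name__ == "__main__":
    ARGS = parser.parse_args()
    main()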