def fill_buffer_with_expert(replay_buffer, env_name, epsilon=0.01):
    mbsize = ARGS.mbsize
    envs = [AtariEnv(env_name) for i in range(mbsize)]
    num_act = envs[0].num_actions
    nhid = 32
    _, theta_q, Qf, _ = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )
    theta_q_trained = load_parameters_from_checkpoint()
    if ARGS.expert_is_self:
        theta_expert = theta_q_trained
    else:
        expert_id = {'ms_pacman': 457, 'asterix': 403, 'seaquest': 428}[env_name]
        with open(f'checkpoints/dqn_model_{expert_id}.pkl', "rb") as f:
            theta_expert = pickle.load(f)
        theta_expert = [tf(i) for i in theta_expert]
    obs = [i.reset() for i in envs]
    trajs = [list() for i in range(mbsize)]
    enumbers = list(range(mbsize))
    replay_buffer.ram = torch.zeros([replay_buffer.size, 128],
                                    dtype=torch.uint8,
                                    device=replay_buffer.device)
    while True:
        mbobs = tf(obs) / 255
        greedy_actions = Qf(mbobs, theta_expert).argmax(1)
        random_actions = np.random.randint(0, num_act, mbsize)
        actions = [
            j if np.random.random() < epsilon else i
            for i, j in zip(greedy_actions, random_actions)
        ]
        for i, (e, a) in enumerate(zip(envs, actions)):
            obsp, r, done, _ = e.step(a)
            trajs[i].append([obs[i], int(a), float(r), int(done), e.getRAM() + 0])
            obs[i] = obsp
            if done:  # assumed episode-end guard; the flush below resets the env and starts a new episode
                if replay_buffer.idx + len(trajs[i]) + 4 >= replay_buffer.size:
                    # We're done!
                    return Qf, theta_q_trained
                replay_buffer.new_episode(trajs[i][0][0], enumbers[i] % 2)
                for s, a, r, d, ram in trajs[i]:
                    replay_buffer.ram[replay_buffer.idx] = tint(ram)
                    replay_buffer.add(s, a, r, d, enumbers[i] % 2)
                trajs[i] = []
                obs[i] = envs[i].reset()
                enumbers[i] = max(enumbers) + 1
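# The loop above mixes expert and random actions independently per environment.
# Below is a minimal, self-contained sketch of that batched epsilon-greedy
# mixing using only NumPy; the function name and arguments are illustrative and
# not part of the original code.
import numpy as np

def mix_epsilon_greedy(greedy_actions, num_act, epsilon, rng):
    """With probability epsilon per entry, replace the greedy action by a random one."""
    n = len(greedy_actions)
    random_actions = rng.randint(0, num_act, n)
    take_random = rng.uniform(0, 1, n) < epsilon
    return np.where(take_random, random_actions, greedy_actions)

# Example: three envs, nine actions, 1% exploration.
# mix_epsilon_greedy(np.array([3, 1, 2]), num_act=9, epsilon=0.01,
#                    rng=np.random.RandomState(0))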
def main():
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }
    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }
    start_step = ARGS.start_step

    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = hps.get("mbsize", 32)
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")
    clone_interval = hps.get("clone_interval", 10_000)
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)
    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = hps.get("buffer_size", 250_000)

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model
    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q, lr, momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()

    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda s, a, r, sp, t, w, tw=theta_q: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
        Qf(s, w)[np.arange(len(a)), a.long()],
    )

    obs = env.reset()

    if replay_type == "normal":
        replay_buffer = ReplayBuffer(seed, buffer_size, near_strategy=sample_near)
    elif replay_type == "prioritized":
        replay_buffer = PrioritizedExperienceReplay(seed, buffer_size,
                                                    near_strategy=sample_near)

    total_reward = 0
    last_end = 0
    num_fill = 200000
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]
    measure = Measures()

    print("Filling buffer")
    if start_step < num_exploration_steps:
        epsilon = 1 - (start_step / num_exploration_steps) * (1 - final_epsilon)
    else:
        epsilon = final_epsilon
    for it in range(num_fill):
        if start_step == 0:
            action = rng.randint(0, num_act)
        else:
            if rng.uniform(0, 1) < epsilon:
                action = rng.randint(0, num_act)
            else:
                action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        if replay_type == "prioritized":
            replay_buffer.set_last_priority(
                td(
                    tf(obs / 255.0).unsqueeze(0),
                    tint([action]),
                    r,
                    tf(obsp / 255.0).unsqueeze(0),
                    tf([done]),
                    theta_q,
                    theta_q,
                ))
        obs = obsp
        if done:
            obs = env.reset()

    past_theta = [clone_theta_q()]

    for it in range(start_step, num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60
        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)
        if it % 10_000 == 0:
            print(
                it,
                f"{(t1 - t0)*1000:.2f}, {(t2 - t1)*1000:.2f}, {(t3 - t2)*1000:.2f}, {(t4 - t3)*1000:.2f},",
                f"{(tm1 - tm0)*1000:.2f}, {(tm3 - tm2)*1000:.2f},",
                f"{int(eta//60):2d}h{int(eta%60):02d}m left",
                f":: {ema_loss:.5f}, last 10 rewards: {np.mean(last_rewards):.2f}",
            )
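# The sl1 loss above is piecewise: quadratic (d**2) for |d| <= 1 and linear (|d|)
# for larger errors, so large TD errors are not squared. It is Huber-like but not
# identical to torch.nn.functional.smooth_l1_loss (which uses 0.5 * d**2 on the
# quadratic branch). A small self-contained check:
import torch

def sl1(a, b):
    d = a - b
    u, s = abs(d), d ** 2
    m = (u < s).float()        # 1 exactly where |d| > 1, i.e. the linear regime
    return u * m + s * (1 - m)

print(sl1(torch.tensor([0.5, 1.0, 3.0]), torch.zeros(3)))
# tensor([0.2500, 1.0000, 3.0000])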
def main():
    device = torch.device(ARGS.device)
    nn.set_device(device)
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }
    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }

    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = ARGS.mbsize
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")
    clone_interval = ARGS.clone_interval
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)
    Lambda = ARGS.Lambda
    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = ARGS.buffer_size

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model
    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q, lr, momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()

    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda x: sl1(
        x.r + (1 - x.t.float()) * gamma * Qf(x.sp, past_theta[0]).max(1)[0].detach(),
        Qf(x.s, theta_q)[np.arange(len(x.a)), x.a.long()],
    )

    tdQL = lambda x: sl1(
        Qf(x.s, theta_q)[np.arange(len(x.a)), x.a.long()],
        x.lg)

    mc = lambda x: sl1(Qf(x.s, theta_q).max(1)[0], x.g)

    past_theta = [clone_theta_q()]

    replay_buffer = ReplayBufferV2(seed, buffer_size,
                                   lambda s: Qf(s, theta_q),
                                   lambda s: Qf(s, past_theta[0]).max(1)[0],
                                   Lambda, gamma)

    total_reward = 0
    last_end = 0
    num_fill = buffer_size // 2
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]

    measure = Measures(theta_q, {
        "td": td,
        "tdQL": tdQL,
        "mc": mc,
    }, replay_buffer, results["measure"], 32)

    obs = env.reset()
    for it in range(num_fill):
        action = rng.randint(0, num_act)
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done)
        obs = obsp
        if done:
            print(it)
            obs = env.reset()

    for it in range(num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60
        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it < num_exploration_steps:
            epsilon = 1 - (it / num_exploration_steps) * (1 - final_epsilon)
        else:
            epsilon = final_epsilon

        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()

        obsp, r, done, info = env.step(action)
        total_reward += r
        replay_buffer.add(obs, action, r, done)
        obs = obsp
        if done:
            obs = env.reset()
            results["episode"].append({
                "end": it,
                "start": last_end,
                "total_reward": total_reward
            })
            last_end = it
            last_rewards = [total_reward] + last_rewards[:10]
            total_reward = 0

        sample = replay_buffer.sample(mbsize)
        with torch.no_grad():
            v_before = Qf(sample.s, theta_q).detach()

        loss = tdQL(sample)

        if do_measure:
            tm0 = time.time()
            measure.pre(sample)
            tm1 = time.time()
        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        with torch.no_grad():
            v_after = Qf(sample.s, theta_q).detach()
        replay_buffer.compute_value_difference(sample, v_before, v_after)

        if do_measure:
            tm2 = time.time()
            measure.post()
            tm3 = time.time()
        t4 = time.time()

        if it and clone_interval and it % clone_interval == 0:
            past_theta = [clone_theta_q()]  # + past_theta[:max_clones - 1]
            replay_buffer.recompute_lambda_returns()

        # exp_results["loss"].append(loss.item())
        ema_loss = 0.999 * ema_loss + 0.001 * loss.item()
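# The tdQL loss above regresses Q(s, a) toward lambda-return targets (x.lg) that
# ReplayBufferV2 maintains and recomputes after every target clone. The buffer's
# implementation is not part of this listing; the sketch below is a
# self-contained reference for the standard backward recursion it is assumed to
# follow,
#   G_t^lambda = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}^lambda),
# with no bootstrapping across episode boundaries.
import numpy as np

def lambda_returns(rewards, next_values, dones, gamma, lam):
    # rewards[t]: reward of transition t; next_values[t]: V(s_{t+1}) under the
    # frozen network; dones[t]: whether transition t ended the episode.
    g = np.zeros(len(rewards))
    next_g = next_values[-1]  # bootstrap past the truncation point
    for t in reversed(range(len(rewards))):
        if dones[t]:
            g[t] = rewards[t]  # terminal transition: no bootstrapping
        else:
            g[t] = rewards[t] + gamma * ((1 - lam) * next_values[t] + lam * next_g)
        next_g = g[t]
    return g

# Example:
# lambda_returns(np.array([0., 0., 1.]), np.array([.5, .5, 0.]),
#                np.array([False, False, True]), gamma=0.99, lam=0.9)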
def main():
    device = torch.device(ARGS.device)
    mm.set_device(device)
    results = {
        "measure": [],
        "parameters": [],
    }

    seed = ARGS.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(ARGS.env_name)
    mbsize = ARGS.mbsize
    Lambda = ARGS.Lambda
    nhid = 32
    num_measure = 1000
    gamma = 0.99
    clone_interval = ARGS.clone_interval
    num_iterations = ARGS.num_iterations

    num_Q_outputs = 1

    # Model
    _Qarch, theta_q, Qf, _Qsemi = mm.build(
        mm.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        mm.conv2d(nhid, nhid * 2, 4, stride=2),
        mm.conv2d(nhid * 2, nhid * 2, 3),
        mm.flatten(),
        mm.hidden(nhid * 2 * 12 * 12, nhid * 16),
        mm.linear(nhid * 16, num_Q_outputs),
    )
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]
    theta_target = clone_theta_q()
    # sl1 and make_opt are assumed to be module-level versions of the helpers
    # defined in the listings above.
    opt = make_opt(ARGS.opt, theta_q, ARGS.learning_rate, ARGS.weight_decay)

    # Replay Buffer
    replay_buffer = ReplayBuffer(seed, ARGS.buffer_size)

    # Losses
    td = lambda s, a, r, sp, t, idx, w, tw: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw)[:, 0].detach(),
        Qf(s, w)[:, 0],
    )
    tdL = lambda s, a, r, sp, t, idx, w, tw: sl1(
        Qf(s, w)[:, 0],
        replay_buffer.LG[idx])
    mc = lambda s, a, r, sp, t, idx, w, tw: sl1(
        Qf(s, w)[:, 0],
        replay_buffer.g[idx])

    # Define metrics
    measure = Measures(
        theta_q, {
            "td": lambda x, w: td(*x, w, theta_target),
            "tdL": lambda x, w: tdL(*x, w, theta_target),
            "mc": lambda x, w: mc(*x, w, theta_target),
        }, replay_buffer, results["measure"], 32)

    # Get expert trajectories
    rand_classes = fill_buffer_with_expert(env, replay_buffer)
    # Compute initial values
    replay_buffer.compute_values(lambda s: Qf(s, theta_q), num_Q_outputs)
    replay_buffer.compute_returns(gamma)
    replay_buffer.compute_reward_distances()
    replay_buffer.compute_episode_boundaries()
    replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

    # Run policy evaluation
    for it in range(num_iterations):
        do_measure = not it % num_measure
        sample = replay_buffer.sample(mbsize)
        if do_measure:
            measure.pre(sample)
        replay_buffer.compute_value_difference(sample, Qf(sample[0], theta_q))

        loss = tdL(*sample, theta_q, theta_target)
        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        replay_buffer.update_values(sample, Qf(sample[0], theta_q))
        if do_measure:
            measure.post()

        if it and clone_interval and it % clone_interval == 0:
            theta_target = clone_theta_q()
            replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

        if it and it % clone_interval == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

    with open(f'results/td_lambda_{ARGS.run}.pkl', 'wb') as f:
        pickle.dump(results, f)
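# A minimal sketch of reading back the saved results file; the run id (0) is an
# example only, and the entries of results["measure"] depend on the Measures
# class, which is not shown in this listing.
import pickle

with open('results/td_lambda_0.pkl', 'rb') as f:
    results = pickle.load(f)
print(len(results["measure"]), len(results["parameters"]))
print(results["parameters"][-1]["step"])  # last step at which parameters were saved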