def main():
    device = torch.device(ARGS.device)
    mm.set_device(device)
    results = {
        "measure": [],
        "parameters": [],
        "args": ARGS,
    }
    print(ARGS)

    seed = ARGS.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(ARGS.env_name)
    mbsize = ARGS.mbsize
    nhid = 32
    num_measure = 1000
    gamma = 0.99
    tau = 0.01
    clone_interval = ARGS.clone_interval
    num_iterations = ARGS.num_iterations
    num_Q_outputs = (env.num_actions if ARGS.loss_func != 'rand'
                     else ARGS.num_rand_classes)

    # Model
    act = torch.nn.LeakyReLU()
    Qf = torch.nn.Sequential(
        torch.nn.Conv2d(4, nhid, 8, stride=4, padding=4), act,
        torch.nn.Conv2d(nhid, nhid * 2, 4, stride=2, padding=2), act,
        torch.nn.Conv2d(nhid * 2, nhid * 2, 3, padding=1), act,
        torch.nn.Flatten(),
        torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16), act,
        torch.nn.Linear(nhid * 16, num_Q_outputs))
    Qf.to(device)
    Qf.apply(init_weights)

    if ARGS.loss_func == 'nfdqn':
        Qf_target = Qf
    else:
        Qf_target = copy.deepcopy(Qf)
    Qf = extend(Qf)

    opt = make_opt(ARGS.opt, Qf.parameters(), ARGS.learning_rate, ARGS.weight_decay)

    # Replay Buffer
    replay_buffer = ReplayBufferV2(seed, ARGS.buffer_size,
                                   value_callback=lambda s: Qf(s), Lambda=0)

    td = lambda x: sl1(
        x.r + (1 - x.t.float()) * gamma * Qf_target(x.sp).max(1)[0].detach(),
        Qf(x.s)[np.arange(len(x.a)), x.a.long()],
    )

    sarsa = lambda x: sl1(
        x.r + ((1 - x.t.float()) * gamma *
               Qf_target(x.sp)[np.arange(len(x.ap)), x.ap.long()].detach()),
        Qf(x.s)[np.arange(len(x.a)), x.a.long()],
    )

    mc = lambda x: sl1(Qf(x.s).max(1)[0], x.g)

    if ARGS.loss_func == 'rand':
        raise ValueError('fixme Qf')

        def rand_nll(s, a, r, sp, t, idx, w, tw):
            return F.cross_entropy(Qf(s, w), tint(rand_classes[idx]), reduction='none')

        def rand_acc(s, a, r, sp, t, idx, w, tw):
            return (Qf(s, w).argmax(1) != tint(rand_classes[idx])).float()

        # Define metrics
        measure = Measures(
            theta_q, {
                "rand_nll": lambda x, w: rand_nll(*x, w, theta_target),
                "rand_acc": lambda x, w: rand_acc(*x, w, theta_target),
            }, replay_buffer, results["measure"], 32)
        loss_func = rand_nll
    else:
        # Define metrics
        measure = Measures(
            list(Qf.parameters()), {
                "td": td,
                "func": lambda x: Qf(x.s).max(1).values,
                #"sarsa": sarsa,
                #"mc": mc,
            }, replay_buffer, results["measure"], 32,
            lambda x: Qf(x.s), Qf)
        loss_func = {
            "sarsa": sarsa,
            "qlearn": td,
            "mc": mc,
            'ddqn': td,
            'nfdqn': td,
        }[ARGS.loss_func]

    # Get expert trajectories
    fill_buffer_with_expert(env, replay_buffer)

    # Run policy evaluation
    for it in tqdm(range(num_iterations), smoothing=0):
        do_measure = not it % num_measure
        sample = replay_buffer.sample(mbsize)

        if do_measure:
            measure.pre(sample)
        #v_before = Qf(sample[0])
        opt.zero_grad()
        loss = loss_func(sample)
        loss = loss.mean()
        loss.backward()
        opt.step()
        #replay_buffer.update_values(sample, v_before, Qf(sample[0], theta_q))
        if do_measure:
            measure.post()

        if it and clone_interval and it % clone_interval == 0:
            if ARGS.loss_func in ['td']:
                Qf_target = copy.deepcopy(Qf)
            if ARGS.loss_func in ['ddqn']:
                for target_param, param in zip(Qf_target.parameters(), Qf.parameters()):
                    target_param.data.copy_(tau * param + (1 - tau) * target_param)

        if it and it % clone_interval == 0 and False or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(Qf.parameters())}
            ps.update({"step": it})
            results["parameters"].append(ps)

    with open(f'results/pol_eval_{ARGS.run}.pkl', 'wb') as f:
        pickle.dump(results, f)
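# Note: `Qf.apply(init_weights)` above relies on an `init_weights` helper defined elsewhere
# in the repo. The sketch below is an assumption (Xavier-uniform weights, zero biases), not
# the repo's actual initializer; it only illustrates the signature expected by nn.Module.apply.
import torch


def init_weights(module):
    # Initialize conv and linear layers; leave everything else untouched.
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)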
def main():
    device = torch.device(ARGS.device)
    nn.set_device(device)
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }
    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }

    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = ARGS.mbsize
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")
    clone_interval = ARGS.clone_interval
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)
    Lambda = ARGS.Lambda
    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = ARGS.buffer_size

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model
    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q, lr, momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()

    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda x: sl1(
        x.r + (1 - x.t.float()) * gamma * Qf(x.sp, past_theta[0]).max(1)[0].detach(),
        Qf(x.s, theta_q)[np.arange(len(x.a)), x.a.long()],
    )

    tdQL = lambda x: sl1(
        Qf(x.s, theta_q)[np.arange(len(x.a)), x.a.long()],
        x.lg)

    mc = lambda x: sl1(Qf(x.s, theta_q).max(1)[0], x.g)

    past_theta = [clone_theta_q()]

    replay_buffer = ReplayBufferV2(seed, buffer_size,
                                   lambda s: Qf(s, theta_q),
                                   lambda s: Qf(s, past_theta[0]).max(1)[0],
                                   Lambda, gamma)

    total_reward = 0
    last_end = 0
    num_fill = buffer_size // 2
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]

    measure = Measures(theta_q, {
        "td": td,
        "tdQL": tdQL,
        "mc": mc,
    }, replay_buffer, results["measure"], 32)

    obs = env.reset()
    for it in range(num_fill):
        action = rng.randint(0, num_act)
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done)
        obs = obsp
        if done:
            print(it)
            obs = env.reset()

    for it in range(num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60

        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it < num_exploration_steps:
            epsilon = 1 - (it / num_exploration_steps) * (1 - final_epsilon)
        else:
            epsilon = final_epsilon

        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()

        obsp, r, done, info = env.step(action)
        total_reward += r
        replay_buffer.add(obs, action, r, done)
        obs = obsp
        if done:
            obs = env.reset()
            results["episode"].append({
                "end": it,
                "start": last_end,
                "total_reward": total_reward
            })
            last_end = it
            last_rewards = [total_reward] + last_rewards[:10]
            total_reward = 0

        sample = replay_buffer.sample(mbsize)
        with torch.no_grad():
            v_before = Qf(sample.s, theta_q).detach()

        loss = tdQL(sample)

        if do_measure:
            tm0 = time.time()
            measure.pre(sample)
            tm1 = time.time()

        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        with torch.no_grad():
            v_after = Qf(sample.s, theta_q).detach()
        replay_buffer.compute_value_difference(sample, v_before, v_after)

        if do_measure:
            tm2 = time.time()
            measure.post()
            tm3 = time.time()
        t4 = time.time()

        if it and clone_interval and it % clone_interval == 0:
            past_theta = [clone_theta_q()]  #+ past_theta[:max_clones - 1]
            replay_buffer.recompute_lambda_returns()

        #exp_results["loss"].append(loss.item())
        ema_loss = 0.999 * ema_loss + 0.001 * loss.item()
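# Note: action selection above calls `tf(obs / 255.0)`, a tensor helper provided by the
# repo's nn/mm utilities. A minimal stand-in, assuming it only wraps a numpy array as a
# float32 tensor on the chosen device (the real helper presumably uses the device set via
# nn.set_device rather than an explicit argument):
import numpy as np
import torch


def tf(x, device=torch.device('cpu')):
    # numpy array (or nested list) -> float32 tensor on `device`
    return torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)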
        self.near_td_after = self._td(*self.near_samples)
        self.near_td_gain = (
            (self.near_td_before - self.near_td_after).cpu().data.numpy())
        self.near_td_gain_avg = self.near_td_gain.mean().item()

        self.sample_td_after = self._td(*e["sample"])
        self.sample_td_gain = (
            (self.sample_td_before - self.sample_td_after).cpu().data.numpy())
        self.sample_td_gain_avg = self.sample_td_gain.mean().item()

    def log(self, rs):
        e = inspect.currentframe().f_back.f_locals  # Don't do this at home
        rs.append({
            "td_error": self.sample_td_before.cpu().data.numpy(),
            "other_td_gain": self.other_td_gain,
            "other_td_gain_avg": self.other_td_gain_avg,
            "near_td_gain": self.near_td_gain,
            "near_td_gain_avg": self.near_td_gain_avg,
            "sample_td_gain": self.sample_td_gain,
            "sample_td_gain_avg": self.sample_td_gain_avg,
            "idx": e["idx"].cpu().data.numpy(),
            "step": e["it"],
        })


if __name__ == "__main__":
    ARGS = parser.parse_args()
    device = torch.device(ARGS.device)
    nn.set_device(device)
    main()
def main(args):
    device = torch.device(args.device)
    mm.set_device(device)
    results_conn = lmdb.open(f'{args.save_path}/run_{args.run}',
                             map_size=int(16 * 2 ** 30))
    params_conn = lmdb.open(f'{args.save_path}/run_{args.run}/params',
                            map_size=int(16 * 2 ** 30))
    with results_conn.begin(write=True) as txn:
        txn.put(b'args', packobj(args))
    print(args)

    seed = args.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(args.env_name)
    mbsize = args.mbsize
    nhid = args.nhid
    gamma = 0.99
    #env.r_gamma = gamma
    checkpoint_freq = args.checkpoint_freq
    test_freq = args.test_freq
    target_tau = args.target_tau
    target_clone_interval = args.target_clone_interval
    target_type = args.target_type
    num_iterations = args.num_iterations
    num_Q_outputs = env.num_actions

    # Model
    act = torch.nn.LeakyReLU()
    if args.body_type == 'normal':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, nhid, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid, nhid * 2, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid * 2, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(),
            torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16), act)
    elif args.body_type == 'tiny':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, nhid, 3, stride=2, padding=1), act,    # 42
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act,  # 21
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act,  # 11
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act,  # 6
            torch.nn.Conv2d(nhid, num_Q_outputs, 6),                   # 1
            torch.nn.Flatten())

    if args.head_type == 'normal':
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, num_Q_outputs))
    elif args.head_type == 'slim':
        # Slim end so we can do block diagonal zeta
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, nhid), act,
                                   torch.nn.Linear(nhid, nhid), act,
                                   torch.nn.Linear(nhid, num_Q_outputs))
    elif args.head_type == 'slim2':
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, nhid * 2), act,
                                   torch.nn.Linear(nhid * 2, num_Q_outputs))
    elif args.head_type == 'none':
        head = torch.nn.Sequential()

    Qf = torch.nn.Sequential(body, head)
    Qf.to(device)
    Qf.apply(init_weights)

    if args.target_type == 'none':
        Qf_target = Qf
    else:
        Qf_target = copy.deepcopy(Qf)

    opt = make_opt(args, Qf.parameters())
    do_set_predictions = args.opt == 'msgd_corr'

    # Replay Buffer
    replay_buffer = ReplayBufferV2(seed, args.buffer_size)
    test_replay_buffer = ReplayBufferV2(seed, 10000)

    # Get expert trajectories
    expert = load_expert(args.env_name, env)
    fill_buffer_with_expert(expert, env, replay_buffer)
    fill_buffer_with_expert(expert, env, test_replay_buffer)

    ar = lambda x: torch.arange(x.shape[0], device=x.device)
    losses = []
    num_iterations = 1 + num_iterations
    ignore_vprime = bool(args.opt_ignore_vprime)

    # Run policy evaluation
    for it in (tqdm(range(num_iterations), smoothing=0)
               if args.progress else range(num_iterations)):
        sample = replay_buffer.sample(mbsize)

        q = Qf(sample.s)
        v = q[ar(q), sample.a.long()]
        vp = Qf_target(sample.sp)[ar(q), sample.ap.long()]  # Sarsa update
        gvp = (1 - sample.t.float()) * gamma * vp
        loss = (v - (sample.r + gvp.detach())).pow(2)
        _loss = loss

        if do_set_predictions:
            opt.set_predictions(v.mean(), gvp.mean() if not ignore_vprime else None)

        loss = loss.mean()
        loss.backward(retain_graph=True)
        opt.step()
        opt.zero_grad()
        losses.append(loss.item())

        if target_type == 'frozen' and it % target_clone_interval == 0:
            Qf_target = copy.deepcopy(Qf)
        elif target_type == 'moving':
            for target_param, param in zip(Qf_target.parameters(), Qf.parameters()):
                target_param.data.mul_(1 - target_tau).add_(param, alpha=target_tau)

        if it % checkpoint_freq == 0 and args.save_parameters:
            with params_conn.begin(write=True) as txn:
                txn.put(f'parameters_{it}'.encode(), packobj(Qf.state_dict()))

        if it % test_freq == 0:
            expert_q_loss = 0
            expert_v_loss = 0
            mc_loss = 0
            n = 0
            with torch.no_grad():
                for sample in test_replay_buffer.iterate(512):
                    n += sample.s.shape[0]
                    q = Qf(sample.s)[ar(sample.a), sample.a.long()]
                    mc_loss += (q - sample.g).pow(2).sum().item()
                    print(q.shape, sample.g.shape)
            with results_conn.begin(write=True) as txn:
                txn.put(f'expert-loss_{it}'.encode(),
                        packobj((expert_q_loss / n, expert_v_loss / n, mc_loss / n)))
                if it > 0:
                    txn.put(f'train-loss_{it}'.encode(), packobj(losses))
            print(it, np.mean(losses), (expert_q_loss / n, expert_v_loss / n, mc_loss / n))
            losses = []

        if np.isnan(loss.item()):
            print("Learning has diverged, nan loss")
            with results_conn.begin(write=True) as txn:
                txn.put(b'diverged', b'True')
            break

    print("Done.")
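# Note: this script calls `make_opt(args, Qf.parameters())`, which is defined elsewhere in
# the repo. A plausible sketch assuming it mirrors the name-keyed factories used in the other
# scripts ('sgd', 'msgd', 'rmsprop', 'adam'); the attribute names (learning_rate, weight_decay)
# and the momentum default are assumptions, and the repo-specific 'msgd_corr' optimizer is not
# reproduced here.
import torch


def make_opt(args, params):
    params = list(params)
    if args.opt == 'sgd':
        return torch.optim.SGD(params, args.learning_rate, weight_decay=args.weight_decay)
    if args.opt == 'msgd':
        return torch.optim.SGD(params, args.learning_rate, momentum=0.99,
                               weight_decay=args.weight_decay)
    if args.opt == 'rmsprop':
        return torch.optim.RMSprop(params, args.learning_rate, weight_decay=args.weight_decay)
    if args.opt == 'adam':
        return torch.optim.Adam(params, args.learning_rate, weight_decay=args.weight_decay)
    raise ValueError(args.opt)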
def main(argv):
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }
    device = torch.device(ARGS.device)
    nn.set_device(device)
    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
        "mbsize": ARGS.mbsize,
    }

    start_step = 0
    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = hps.get("mbsize", 32)
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")
    clone_interval = hps.get("clone_interval", 10_000)
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    target = hps.get("target", "last")  # self, last, clones
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)
    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = hps.get("buffer_size", ARGS.buffer_size)
    to_test_prob = ARGS.to_test_prob

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model
    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q, lr, momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()

    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda s, a, r, sp, t, w, tw=theta_q: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
        Qf(s, w)[np.arange(len(a)), a.long()],
    )

    obs = env.reset()
    replay_buffer = PrioritizedExperienceReplay(seed, buffer_size,
                                                near_strategy=sample_near)
    test_set = {}
    last_lock_refresh = time.time()
    total_reward = 0
    last_end = 0
    num_fill = min(200000, replay_buffer.size)
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]
    measure = Measures()

    print("Filling buffer")
    if start_step < num_exploration_steps:
        epsilon = 1 - (start_step / num_exploration_steps) * (1 - final_epsilon)
    else:
        epsilon = final_epsilon

    for it in range(num_fill):
        if start_step == 0:
            action = rng.randint(0, num_act)
        else:
            if rng.uniform(0, 1) < epsilon:
                action = rng.randint(0, num_act)
            else:
                action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()

        obsram = env.getRAM().tobytes()
        if obsram not in test_set:
            test_set[obsram] = float(rng.uniform(0, 1) < to_test_prob)

        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        replay_buffer.set_last_priority(1 - test_set[obsram])
        obs = obsp
        if done:
            replay_buffer.set_last_priority(1 - test_set[obsram])
            obs = env.reset()

    past_theta = [clone_theta_q()]

    for it in range(start_step, num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60

        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it % 10_000 == 0:
            print(
                it,
                f"{(t1 - t0)*1000:.2f}, {(t2 - t1)*1000:.2f}, {(t3 - t2)*1000:.2f}, {(t4 - t3)*1000:.2f},",
                f"{(tm1 - tm0)*1000:.2f}, {(tm3 - tm2)*1000:.2f},",
                f"{int(eta//60):2d}h{int(eta%60):02d}m left",
                f":: {ema_loss:.5f}, last 10 rewards: {np.mean(last_rewards):.2f}",
                #" " * 20,
                #end="\r",
            )
def main():
    gamma = 0.99
    hps = pickle.load(open(ARGS.checkpoint, 'rb'))['hps']
    env_name = hps["env_name"]
    if 'Lambda' in hps:
        Lambda = hps['Lambda']
    else:
        Lambda = 0

    device = torch.device(ARGS.device)
    nn.set_device(device)

    replay_buffer = ReplayBuffer(ARGS.run, ARGS.buffer_size)
    Qf, theta_q = fill_buffer_with_expert(replay_buffer, env_name)
    for p in theta_q:
        p.requires_grad = True

    if Lambda > 0:
        replay_buffer.compute_episode_boundaries()
        replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

    td__ = lambda s, a, r, sp, t, idx, w, tw: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
        Qf(s, w)[np.arange(len(a)), a.long()],
    )
    td = lambda s, a, r, sp, t, idx, w, tw: Qf(s, w).max(1)[0]
    tdL = lambda s, a, r, sp, t, idx, w, tw: sl1(Qf(s, w)[:, 0], replay_buffer.LG[idx])

    loss_func = {
        'td': td,
        'tdL': tdL}[ARGS.loss_func]

    opt = torch.optim.SGD(theta_q, 1)

    def grad_sim(inp, grad):
        dot = sum([(p.grad * gp).sum() for p, gp in zip(inp, grad)])
        nA = torch.sqrt(sum([(p.grad**2).sum() for p, gp in zip(inp, grad)]))
        nB = torch.sqrt(sum([(gp**2).sum() for p, gp in zip(inp, grad)]))
        return (dot / (nA * nB)).item()

    relevant_features = np.int32(
        sorted(list(atari_dict[env_name.replace("_", "")].values())))

    sims = []
    ram_sims = []
    for i in range(2000):
        sim = []
        *sample, idx = replay_buffer.sample(1)
        loss = loss_func(*sample, idx, theta_q, theta_q).mean()
        loss.backward()
        g0 = [p.grad + 0 for p in theta_q]
        for j in range(-30, 31):
            opt.zero_grad()
            loss = loss_func(*replay_buffer.get(idx + j), theta_q, theta_q).mean()
            loss.backward()
            sim.append(grad_sim(theta_q, g0))
        sims.append(np.float32(sim))

        for j in range(200):
            opt.zero_grad()
            *sample_j, idx_j = replay_buffer.sample(1)
            loss = loss_func(*sample_j, idx_j, theta_q, theta_q).mean()
            loss.backward()
            ram_sims.append(
                (grad_sim(theta_q, g0),
                 abs(replay_buffer.ram[idx[0]][relevant_features].float() -
                     replay_buffer.ram[idx_j[0]][relevant_features].float()).mean()))
        opt.zero_grad()

    ram_sims = np.float32(ram_sims)  #np.histogram(np.float32(ram_sim), 100, (-1, 1))

    # Compute "True" gradient
    grads = [i.detach() * 0 for i in theta_q]
    N = 0
    for samples in replay_buffer.in_order_iterate(ARGS.mbsize * 8):
        loss = loss_func(*samples, theta_q, theta_q).mean()
        loss.backward()
        N += samples[0].shape[0]
        for p, gp in zip(theta_q, grads):
            gp.data.add_(p.grad)
        opt.zero_grad()

    dots = []
    i = 0
    for sample in replay_buffer.in_order_iterate(1):
        loss = loss_func(*sample, theta_q, theta_q).mean()
        loss.backward()
        dots.append(grad_sim(theta_q, grads))
        opt.zero_grad()
        i += 1

    histo = np.histogram(dots, 100, (-1, 1))

    results = {
        "grads": [i.cpu().data.numpy() for i in grads],
        "sims": np.float32(sims),
        "histo": histo,
        "ram_sims": ram_sims,
    }
    path = f'results/grads_{ARGS.checkpoint}.pkl'
    with open(path, "wb") as f:
        pickle.dump(results, f)
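# Note: `grad_sim` above is a cosine similarity between the current `.grad` of theta_q and a
# saved gradient list. The equivalent computation with flattened tensors, shown here as a
# self-contained check (hypothetical tensors, not the repo's parameters):
import torch
import torch.nn.functional as F


def cosine_between(grads_a, grads_b):
    # Flatten both gradient lists and take the usual cosine similarity.
    a = torch.cat([g.flatten() for g in grads_a])
    b = torch.cat([g.flatten() for g in grads_b])
    return F.cosine_similarity(a, b, dim=0).item()

# Example: cosine_between([torch.randn(3, 3), torch.randn(5)],
#                         [torch.randn(3, 3), torch.randn(5)]) lies in [-1, 1].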
def main():
    device = torch.device(ARGS.device)
    mm.set_device(device)
    results = {
        "measure": [],
        "parameters": [],
    }

    seed = ARGS.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(ARGS.env_name)
    mbsize = ARGS.mbsize
    Lambda = ARGS.Lambda
    nhid = 32
    num_measure = 1000
    gamma = 0.99
    clone_interval = ARGS.clone_interval
    num_iterations = ARGS.num_iterations
    num_Q_outputs = 1

    # Model
    _Qarch, theta_q, Qf, _Qsemi = mm.build(
        mm.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        mm.conv2d(nhid, nhid * 2, 4, stride=2),
        mm.conv2d(nhid * 2, nhid * 2, 3),
        mm.flatten(),
        mm.hidden(nhid * 2 * 12 * 12, nhid * 16),
        mm.linear(nhid * 16, num_Q_outputs),
    )
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]
    theta_target = clone_theta_q()

    opt = make_opt(ARGS.opt, theta_q, ARGS.learning_rate, ARGS.weight_decay)

    # Replay Buffer
    replay_buffer = ReplayBuffer(seed, ARGS.buffer_size)

    # Losses
    td = lambda s, a, r, sp, t, idx, w, tw: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw)[:, 0].detach(),
        Qf(s, w)[:, 0],
    )
    tdL = lambda s, a, r, sp, t, idx, w, tw: sl1(Qf(s, w)[:, 0], replay_buffer.LG[idx])
    mc = lambda s, a, r, sp, t, idx, w, tw: sl1(Qf(s, w)[:, 0], replay_buffer.g[idx])

    # Define metrics
    measure = Measures(
        theta_q, {
            "td": lambda x, w: td(*x, w, theta_target),
            "tdL": lambda x, w: tdL(*x, w, theta_target),
            "mc": lambda x, w: mc(*x, w, theta_target),
        }, replay_buffer, results["measure"], 32)

    # Get expert trajectories
    rand_classes = fill_buffer_with_expert(env, replay_buffer)
    # Compute initial values
    replay_buffer.compute_values(lambda s: Qf(s, theta_q), num_Q_outputs)
    replay_buffer.compute_returns(gamma)
    replay_buffer.compute_reward_distances()
    replay_buffer.compute_episode_boundaries()
    replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

    # Run policy evaluation
    for it in range(num_iterations):
        do_measure = not it % num_measure
        sample = replay_buffer.sample(mbsize)

        if do_measure:
            measure.pre(sample)

        replay_buffer.compute_value_difference(sample, Qf(sample[0], theta_q))

        loss = tdL(*sample, theta_q, theta_target)
        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        replay_buffer.update_values(sample, Qf(sample[0], theta_q))

        if do_measure:
            measure.post()

        if it and clone_interval and it % clone_interval == 0:
            theta_target = clone_theta_q()
            replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

        if it and it % clone_interval == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

    with open(f'results/td_lambda_{ARGS.run}.pkl', 'wb') as f:
        pickle.dump(results, f)
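# Note: `replay_buffer.compute_lambda_returns(...)` above fills `replay_buffer.LG`, which the
# `tdL` loss regresses onto. A minimal sketch of the standard backward lambda-return recursion
# for a single episode, assuming `rewards`, `next_values` (bootstrap values V(s_{t+1})) and
# `dones` are aligned arrays and the episode ends in a terminal state; the buffer-level
# implementation in the repo may differ.
import numpy as np


def lambda_returns(rewards, next_values, dones, gamma=0.99, lam=0.9):
    T = len(rewards)
    G = np.zeros(T, dtype=np.float32)
    running = 0.0
    for t in reversed(range(T)):
        if dones[t]:
            running = rewards[t]  # terminal transition: no bootstrap
        else:
            # G^lambda_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G^lambda_{t+1})
            running = rewards[t] + gamma * ((1 - lam) * next_values[t] + lam * running)
        G[t] = running
    return G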
def main(args):
    device = torch.device(args.device)
    mm.set_device(device)
    results_conn = lmdb.open(f'{args.save_path}/run_{args.run}',
                             map_size=int(16 * 2**30))
    params_conn = lmdb.open(f'{args.save_path}/run_{args.run}/params',
                            map_size=int(16 * 2**30))
    with results_conn.begin(write=True) as txn:
        txn.put(b'args', packobj(args))
    print(args)

    seed = args.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(args.env_name)
    mbsize = args.mbsize
    nhid = args.nhid
    gamma = 0.99
    env.r_gamma = gamma
    checkpoint_freq = args.checkpoint_freq
    test_freq = args.test_freq
    target_tau = args.target_tau
    target_clone_interval = args.target_clone_interval
    target_type = args.target_type
    num_iterations = args.num_iterations
    num_Q_outputs = env.num_actions
    td_steps = args.td_n_step
    num_env_steps = args.num_env_steps
    measure_drift = args.measure_drift

    # Model
    act = {
        'lrelu': torch.nn.LeakyReLU(),
        'tanh': torch.nn.Tanh(),
        'elu': torch.nn.ELU(),
    }[args.act]

    # Body
    if args.body_type == 'normal':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, nhid, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid, nhid * 2, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid * 2, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(),
            torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16), act)
    elif args.body_type == 'slim_bot2':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, nhid // 2, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid // 2, nhid, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(),
            torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16), act)
    elif args.body_type == 'added_bot3':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, 4, 3, padding=1), act,
            torch.nn.Conv2d(4, 4, 3, padding=1), act,
            torch.nn.Conv2d(4, 4, 3, padding=1), act,
            torch.nn.Conv2d(4, nhid, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid, nhid * 2, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid * 2, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(),
            torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16), act)
    elif args.body_type == 'added_bot3A':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, 8, 3, padding=1), act,
            torch.nn.Conv2d(8, 8, 3, padding=1), act,
            torch.nn.Conv2d(8, 8, 3, padding=1), act,
            torch.nn.Conv2d(8, nhid, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid, nhid * 2, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid * 2, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(),
            torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16), act)
    elif args.body_type == 'tiny':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, nhid, 3, stride=2, padding=1), act,    # 42
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act,  # 21
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act,  # 11
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act,  # 6
            torch.nn.Conv2d(nhid, num_Q_outputs, 6),                   # 1
            torch.nn.Flatten())

    # Head
    if args.head_type == 'normal':
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, num_Q_outputs))
    elif args.head_type == 'slim':
        # Slim end so we can do block diagonal zeta
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, nhid), act,
                                   torch.nn.Linear(nhid, nhid), act,
                                   torch.nn.Linear(nhid, num_Q_outputs))
    elif args.head_type == 'slim2':
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, nhid * 2), act,
                                   torch.nn.Linear(nhid * 2, num_Q_outputs))
    elif args.head_type == 'none':
        head = torch.nn.Sequential()

    Qf = torch.nn.Sequential(body, head)
    Qf.to(device)
    Qf.apply(init_weights)

    if args.target_type == 'none':
        Qf_target = Qf
    else:
        Qf_target = copy.deepcopy(Qf)

    opt = make_opt(args, Qf.parameters())
    opt.epsilon = 1e-2
    do_specific_backward = args.opt == 'msgd_corr'

    # Replay Buffer
    replay_buffer = ReplayBufferV2(seed, args.buffer_size)

    ar = lambda x: torch.arange(x.shape[0], device=x.device)
    losses = []
    num_iterations = 1 + num_iterations
    ignore_vprime = bool(args.opt_ignore_vprime)

    total_reward = 0
    last_end = 0
    last_rewards = []
    num_exploration_steps = 50_000
    final_epsilon = 0.05
    recent_states = []
    recent_values = []
    obs = env.reset()
    drift = 0

    # Run policy evaluation
    _prof = (tqdm(range(num_iterations), smoothing=0.001)
             if args.progress else range(num_iterations))
    for it in _prof:
        for eit in range(num_env_steps):
            if it < num_exploration_steps:
                epsilon = 1 - (it / num_exploration_steps) * (1 - final_epsilon)
            else:
                epsilon = final_epsilon

            if rng.uniform(0, 1) < epsilon:
                action = rng.randint(0, num_Q_outputs)
            else:
                with torch.no_grad():
                    action = Qf(
                        torch.tensor(obs / 255.0,
                                     device=device).unsqueeze(0).float()).argmax().item()

            obsp, r, done, info = env.step(action)
            total_reward += r
            replay_buffer.add(obs, action, r, done, env.enumber % 2)
            obs = obsp
            if done:
                obs = env.reset()
                with results_conn.begin(write=True) as txn:
                    txn.put(
                        f'episode_{env.enumber-1}'.encode(),
                        packobj({
                            "end": it,
                            "start": last_end,
                            "total_reward": total_reward
                        }))
                last_end = it
                last_rewards = [total_reward] + last_rewards[:10]
                if args.progress:
                    _prof.set_description_str(
                        f'reward {int(100*total_reward)}, '
                        f'{int(100*np.mean(last_rewards))}, '
                        f'{drift:.5f}')
                total_reward = 0

        if replay_buffer.current_size < 5000:
            continue

        sample = replay_buffer.sample(mbsize, n_step=td_steps)
        if Qf_target is Qf:
            q = Qf(torch.cat([sample.s, sample.sp], 0))
            v = q[ar(sample.s), sample.a.long()]
            vp = q[ar(sample.sp) + sample.s.shape[0], sample.ap.long()]
        else:
            q = Qf(sample.s)
            v = q[ar(q), sample.a.long()]
            vp = Qf_target(sample.sp)[ar(q), sample.ap.long()]

        gamma_mask = (1 - sample.t.float()) * (gamma**td_steps)
        target = sample.r + gamma_mask * vp
        loss = (v - target.detach()).pow(2)

        if do_specific_backward:
            opt.backward_and_step(v, vp, v - target, gamma_mask)
            #opt.set_predictions(v.mean(), gvp.mean() if not ignore_vprime else None)
        else:
            loss = loss.mean()
            loss.backward()
            opt.step()
            opt.zero_grad()
        losses.append(loss.item())

        if measure_drift:
            recent_states.append((sample.sp, sample.ap.long()))
            recent_values.append(vp.detach())
            if len(recent_states) >= 32:
                rs = torch.cat([i[0] for i in recent_states])
                ra = torch.cat([i[1] for i in recent_states])
                rvp = torch.cat(recent_values)
                with torch.no_grad():
                    nvp = Qf_target(rs)[ar(ra), ra]
                drift = abs(rvp - nvp).mean().item()
                with results_conn.begin(write=True) as txn:
                    txn.put(f'value_drift_{it}'.encode(), packobj(drift))
                recent_states = []
                recent_values = []

        if target_type == 'frozen' and it % target_clone_interval == 0:
            Qf_target = copy.deepcopy(Qf)
        elif target_type == 'moving':
            for target_param, param in zip(Qf_target.parameters(), Qf.parameters()):
                target_param.data.mul_(1 - target_tau).add_(param, alpha=target_tau)

        if it % checkpoint_freq == 0 and args.save_parameters:
            with params_conn.begin(write=True) as txn:
                txn.put('parameters_last'.encode(), packobj(Qf.state_dict()))

        if it % test_freq == 0:
            mc_loss = 0
            n = 0
            with torch.no_grad():
                #print('|W|^2 =', sum([i.pow(2).sum() for i in Qf.parameters()]))
                #for sample in replay_buffer.iterate(512):
                while True:
                    sample = replay_buffer.sample(512)
                    n += sample.s.shape[0]
                    q = Qf(sample.s).max(1).values
                    mc_loss += (q - sample.g).pow(2).sum().item()
                    if n > 10000:
                        break
            with results_conn.begin(write=True) as txn:
                txn.put(f'mc-loss_{it}'.encode(), packobj((mc_loss / n, )))
                if it > 0:
                    txn.put(f'train-loss_{it}'.encode(), packobj(losses))
            losses = []

        if np.isnan(loss.item()):
            print("Learning has diverged, nan loss")
            with results_conn.begin(write=True) as txn:
                txn.put(b'diverged', b'True')
            break

    print("Done.")
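# Note: results and parameters are written to lmdb as serialized blobs via `packobj`, which is
# defined elsewhere in the repo. A minimal stand-in, assuming plain pickling (the real helper
# may use a different or compressed encoding):
import pickle


def packobj(obj):
    # Serialize an arbitrary Python object to bytes for lmdb txn.put.
    return pickle.dumps(obj)


def unpackobj(blob):
    return pickle.loads(blob)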