def sample(self, n_step=1):
    # Uniform sampling over valid time steps. The balanced reward/no-reward branch
    # below is kept for reference but is unreachable because of the early return.
    if 1:
        return self.i2y(
            tint([np.random.randint(self.tslice - 1, self.length - 1)]),
            n_step=n_step)
    r = np.random.randint(0, 2)
    if r == 0:
        return self.i2y(tint([np.random.choice(self.ridx)]))
    return self.i2y(tint([np.random.choice(self.noridx)]))
def i2y(self, idx, n_step=1):
    # Indices of the tslice stacked frames ending at idx (the observation s)
    if self.tslice > 1:
        sidx = idx.reshape((idx.shape[0], 1)) + self.tslice_range
    else:
        sidx = idx
    # Indices n_step ahead (for s'), clamped so we never index past the episode
    sidxp = torch.min(tint(self.length), sidx + n_step)
    idxp = torch.min(tint(self.length), idx + n_step)
    return (
        self.s[sidx].float() / self.snorm,
        self.a[idx],
        self.r[idx] if n_step == 1 else self.greturn(self.r[idx:idxp]),
        self.s[sidxp].float() / self.snorm,
        self.t[idxp - 1],
        self.g[idx],
        self.a[idxp],
        idx,
    )
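# `greturn` is used above but not defined in this section. Below is a minimal sketch of
# what it presumably computes, assuming it is the gamma-discounted sum of the n-step
# reward slice for one sampled index; the body is an assumption, not the original code.
def greturn(self, rewards):
    # rewards: tensor of the n rewards r_t, ..., r_{t+n-1} for one sampled index
    g = rewards.new_zeros(1)
    for k in reversed(range(rewards.shape[0])):
        # Accumulate back-to-front: g <- r_k + gamma * g
        g = rewards[k] + self.gamma * g
    return g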
def slice_near(self, i, dist, exclude_0=True):
    ar = np.arange(-dist, dist + 1)
    if exclude_0:
        ar = ar[ar != 0]
    sidx = i + tint(ar)
    pmask = (sidx >= 3).float() * (sidx <= self.length - 2).float()
    sidx = sidx.clamp(3, self.length - 2)
    return (*self.i2y(sidx), pmask)
def compute_values(self, V):
    s = self.i2y(tint(np.arange(3, self.length - 1)))[0]
    with torch.no_grad():
        v = V(s).detach()
    if not hasattr(self, 'v'):
        self.v = torch.zeros([self.length, v.shape[1]],
                             dtype=torch.float32, device=self.device)
    self.v[3:self.length - 1] = v
def fill_buffer_with_expert(replay_buffer, env_name, epsilon=0.01):
    mbsize = ARGS.mbsize
    envs = [AtariEnv(env_name) for i in range(mbsize)]
    num_act = envs[0].num_actions
    nhid = 32
    # Same DQN architecture as the learner; only Qf and the expert parameters are used here
    _, theta_q, Qf, _ = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )
    theta_q_trained = load_parameters_from_checkpoint()
    if ARGS.expert_is_self:
        theta_expert = theta_q_trained
    else:
        expert_id = {'ms_pacman': 457, 'asterix': 403, 'seaquest': 428}[env_name]
        with open(f'checkpoints/dqn_model_{expert_id}.pkl', "rb") as f:
            theta_expert = pickle.load(f)
        theta_expert = [tf(i) for i in theta_expert]
    obs = [i.reset() for i in envs]
    trajs = [list() for i in range(mbsize)]
    enumbers = list(range(mbsize))
    replay_buffer.ram = torch.zeros([replay_buffer.size, 128],
                                    dtype=torch.uint8, device=replay_buffer.device)
    while True:
        mbobs = tf(obs) / 255
        greedy_actions = Qf(mbobs, theta_expert).argmax(1)
        random_actions = np.random.randint(0, num_act, mbsize)
        # Epsilon-greedy: take the random action with probability epsilon
        actions = [
            j if np.random.random() < epsilon else i
            for i, j in zip(greedy_actions, random_actions)
        ]
        for i, (e, a) in enumerate(zip(envs, actions)):
            obsp, r, done, _ = e.step(a)
            trajs[i].append([obs[i], int(a), float(r), int(done), e.getRAM() + 0])
            obs[i] = obsp
            if done:
                if replay_buffer.idx + len(trajs[i]) + 4 >= replay_buffer.size:
                    # We're done!
                    return Qf, theta_q_trained
                # Flush the finished trajectory into the replay buffer
                replay_buffer.new_episode(trajs[i][0][0], enumbers[i] % 2)
                for s, a, r, d, ram in trajs[i]:
                    replay_buffer.ram[replay_buffer.idx] = tint(ram)
                    replay_buffer.add(s, a, r, d, enumbers[i] % 2)
                trajs[i] = []
                obs[i] = envs[i].reset()
                enumbers[i] = max(enumbers) + 1
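# A minimal usage sketch (assumed, not part of the original file): the function fills the
# given replay buffer with epsilon-greedy expert transitions until the buffer is nearly
# full, then returns the Q-function and the trained parameters for later evaluation. The
# ReplayBuffer constructor arguments follow the call in main() below.
#
#   replay_buffer = ReplayBuffer(seed, buffer_size)
#   Qf, theta_q_trained = fill_buffer_with_expert(replay_buffer, 'ms_pacman', epsilon=0.05)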
def slice_near(self, idx, dist=10, exclude_0=True):
    ar = np.arange(-dist, dist + 1)
    if exclude_0:
        ar = ar[ar != 0]
    sidx = idx.reshape((-1, 1)) + tint(ar)
    p = self.p[idx]
    ps = self.p[sidx]
    pmask = (p[:, None] == ps).reshape((-1, )).float()
    sidx = sidx.reshape((-1, ))  # clamp??
    return self._idx2xy(sidx), pmask
def __init__(self, gamma, tslice=4, snorm=255):
    self.s = []
    self.a = []
    self.r = []
    self.t = []
    self.ram = []
    self.gamma = gamma
    self.is_over = False
    self.device = get_device()
    self.tslice = tslice
    # Offsets of the stacked frames relative to the current index,
    # e.g. [-3, -2, -1, 0] for tslice=4
    self.tslice_range = tint(np.arange(tslice) - tslice + 1)
    self.snorm = snorm
def in_order_iterate(self, mbsize, until=None):
    if until is None:
        until = self.size
    valid_indices = np.arange(self.size)[self._sumtree.levels[-1] > 0]
    it = 0
    end = 0
    while end < valid_indices.shape[0]:
        end = min(it + mbsize, valid_indices.shape[0])
        if end > until:
            break
        yield self.get(tint(valid_indices[it:end]))
        it += mbsize
def _idx2xy(self, idx, sidx=None):
    if sidx is None:
        d = tint((-3, -2, -1, 0, 1))  # 4 state history slice + 1 for s'
        sidx = (idx.reshape((idx.shape[0], 1)) + d) % self.maxidx
    return (
        self.s[sidx[:, :4]].float() / 255,
        self.a[idx],
        self.r[idx],
        self.s[sidx[:, 1:]].float() / 255,
        self.t[idx],
        idx,
    )
def compute_lambda_returns(self, fun, Lambda, gamma):
    if not hasattr(self, 'LR'):
        self.LR = LambdaReturn(Lambda, gamma)
        self.LG = torch.zeros([self.size], dtype=torch.float32, device=self.device)
    i = 0
    for start, end in self.episodes:
        s = self._idx2xy(tint(np.arange(start + 1, end)))[0]
        with torch.no_grad():
            vp = fun(s)[:, 0].detach()
        vp = torch.cat([vp, torch.zeros((1, ), device=self.device)])
        self.LG[start:end] = self.LR(self.r[start:end], vp)
        i += 1
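# `LambdaReturn` is instantiated above but defined elsewhere. Below is a minimal,
# unvectorized sketch of the usual backward lambda-return recursion it presumably
# implements, G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}),
# with the class name and call signature matching the usage above; the body is an
# assumption, not the original implementation.
class LambdaReturn:
    def __init__(self, Lambda, gamma):
        self.Lambda = Lambda
        self.gamma = gamma

    def __call__(self, rewards, vp):
        # rewards: (T,) rewards of one episode; vp: (T,) next-state values, with
        # vp[-1] = 0 for the terminal step (as arranged by the caller above).
        G = torch.zeros_like(rewards, dtype=torch.float32)
        g = vp[-1]
        for t in reversed(range(rewards.shape[0])):
            g = rewards[t] + self.gamma * ((1 - self.Lambda) * vp[t] + self.Lambda * g)
            G[t] = g
        return G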
def compute_values(self, fun, num_act, mbsize=128, nbins=256):
    if not hasattr(self, "last_v"):
        self.last_v = torch.zeros([self.size, num_act],
                                  dtype=torch.float32, device=self.device)
        self.vdiff_acc = np.zeros(nbins)
        self.vdiff_cnt = np.zeros(nbins)
        self.vdiff_bins = np.linspace(-1, 1, nbins - 1)
    d = tint((-3, -2, -1, 0))  # 4 state history slice
    idx_0 = tint(np.arange(mbsize)) + 3
    idx_s = idx_0.reshape((-1, 1)) + d
    # Evaluate fun over the whole buffer in minibatches of size mbsize
    for i in range(int(np.ceil((self.maxidx - 4) / mbsize))):
        islice = idx_s + i * mbsize
        iar = idx_0 + i * mbsize
        # Truncate the last, possibly partial, minibatch
        if (i + 1) * mbsize >= self.maxidx - 2:
            islice = islice[:self.maxidx - i * mbsize - 2]
            iar = iar[:self.maxidx - i * mbsize - 2]
        s = self.s[islice].float().div_(255)
        with torch.no_grad():
            self.last_v[iar] = fun(s)
        if not i % 100:
            gc.collect()
def i2y(self, idx):
    d = tint((-3, -2, -1, 0, 1))  # 4 state history slice + 1 for s'
    sidx = idx.reshape((idx.shape[0], 1)) + d
    return (
        self.s[sidx[:, :4]].float() / 255,
        self.a[idx],
        self.r[idx],
        self.s[sidx[:, 1:]].float() / 255,
        self.t[idx],
        self.g[idx],
        self.lg[idx],
        self.a[idx + 1],
        idx,
    )
def iterate(self, mbsize):
    eidx = 0
    t = 0
    z = torch.zeros(mbsize, device=self.device)
    while True:
        mb = []
        while len(mb) < mbsize:
            mb.append(self.episodes[eidx].i2y(tint([t])))
            t += 1
            if t >= self.episodes[eidx].length:
                t = 0
                eidx += 1
                if eidx >= len(self.episodes):
                    break
        if len(mb):
            yield Minibatch(
                *[torch.cat([d[i] for d in mb]) for i in range(len(mb[0]))],
                z[:len(mb)])
        if eidx >= len(self.episodes) or len(mb) < mbsize:
            break
def rand_acc(s, a, r, sp, t, idx, w, tw):
    return (Qf(s, w).argmax(1) != tint(rand_classes[idx])).float()
def rand_nll(s, a, r, sp, t, idx, w, tw):
    return F.cross_entropy(Qf(s, w), tint(rand_classes[idx]), reduction="none")
def sample(self):
    return self.i2y(tint([np.random.randint(3, self.length - 1)]))
def stratified_sample(self, n):
    # As per Schaul et al. (2015)
    return tint([
        self.sample((i + q) / n)
        for i, q in enumerate(self.rng.uniform(0, 1, n))
    ])
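# The single-quantile `sample(q)` used above is not shown in this section. Below is a
# minimal sketch of the standard proportional descent through a binary sum tree,
# assuming `self.levels` is a list of arrays with `levels[0]` holding the total priority
# and `levels[-1]` the leaf priorities (matching the `levels` attribute used in
# in_order_iterate). The details are assumptions, not the original implementation.
def sample(self, q):
    # q in [0, 1): map the quantile to a cumulative-priority target, then descend.
    target = q * self.levels[0][0]
    idx = 0
    for level in self.levels[1:]:
        left = level[2 * idx]
        # Go left if the target mass falls in the left subtree, otherwise go right
        if target < left:
            idx = 2 * idx
        else:
            target -= left
            idx = 2 * idx + 1
    return idx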
def compute_lambda_return(self, V, lr):
    s = self.i2y(tint(np.arange(4, self.length)))[0]
    with torch.no_grad():
        vp = V(s).detach()
    self.lg[3:self.length - 1] = lr(self.r[3:self.length - 1], vp)
def main():
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }
    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }
    start_step = ARGS.start_step

    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = hps.get("mbsize", 32)
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")
    clone_interval = hps.get("clone_interval", 10_000)
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)
    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = hps.get("buffer_size", 250_000)

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model
    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q, lr, momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()

    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss: elementwise smooth-L1, then the 1-step TD error on Q
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda s, a, r, sp, t, w, tw=theta_q: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
        Qf(s, w)[np.arange(len(a)), a.long()],
    )

    obs = env.reset()

    if replay_type == "normal":
        replay_buffer = ReplayBuffer(seed, buffer_size, near_strategy=sample_near)
    elif replay_type == "prioritized":
        replay_buffer = PrioritizedExperienceReplay(seed, buffer_size,
                                                    near_strategy=sample_near)

    total_reward = 0
    last_end = 0
    num_fill = 200000
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]
    measure = Measures()

    # Pre-fill the replay buffer with epsilon-greedy experience before training
    print("Filling buffer")
    if start_step < num_exploration_steps:
        epsilon = 1 - (start_step / num_exploration_steps) * (1 - final_epsilon)
    else:
        epsilon = final_epsilon
    for it in range(num_fill):
        if start_step == 0:
            action = rng.randint(0, num_act)
        else:
            if rng.uniform(0, 1) < epsilon:
                action = rng.randint(0, num_act)
            else:
                action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        if replay_type == "prioritized":
            replay_buffer.set_last_priority(
                td(
                    tf(obs / 255.0).unsqueeze(0),
                    tint([action]),
                    r,
                    tf(obsp / 255.0).unsqueeze(0),
                    tf([done]),
                    theta_q,
                    theta_q,
                ))
        obs = obsp
        if done:
            obs = env.reset()

    past_theta = [clone_theta_q()]

    for it in range(start_step, num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60
        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it % 10_000 == 0:
            print(
                it,
                f"{(t1 - t0)*1000:.2f}, {(t2 - t1)*1000:.2f}, {(t3 - t2)*1000:.2f}, {(t4 - t3)*1000:.2f},",
                f"{(tm1 - tm0)*1000:.2f}, {(tm3 - tm2)*1000:.2f},",
                f"{int(eta//60):2d}h{int(eta%60):02d}m left",
                f":: {ema_loss:.5f}, last 10 rewards: {np.mean(last_rewards):.2f}",
            )
        # Anneal epsilon linearly down to final_epsilon over the exploration phase
        if it < num_exploration_steps:
            epsilon = 1 - (it / num_exploration_steps) * (1 - final_epsilon)
        else:
            epsilon = final_epsilon

        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()

        t1 = time.time()
        obsp, r, done, info = env.step(action)
        total_reward += r
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        if replay_type == "prioritized":
            replay_buffer.set_last_priority(
                td(
                    tf(obs / 255.0).unsqueeze(0),
                    tint([action]),
                    r,
                    tf(obsp / 255.0).unsqueeze(0),
                    tf([done]),
                    theta_q,
                    theta_q,
                ))
        obs = obsp
        if done:
            obs = env.reset()
            results["episode"].append({
                "end": it,
                "start": last_end,
                "total_reward": total_reward
            })