Example #1
 def sample(self, n_step=1):
     # NOTE: the early return below always fires, so the reward-balanced
     # sampling branch that follows it is effectively disabled.
     if 1:
         return self.i2y(tint(
             [np.random.randint(self.tslice - 1, self.length - 1)]),
                         n_step=n_step)
     r = np.random.randint(0, 2)
     if r == 0:
         return self.i2y(tint([np.random.choice(self.ridx)]))
     return self.i2y(tint([np.random.choice(self.noridx)]))
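A hedged usage sketch for the sampler above: the tuple it returns (via i2y) can feed an n-step TD target. Here `episode`, `Qf`, `theta_q`, and `gamma` are assumed to exist as in the other examples; none of this is from the original source.

# Illustrative only: consume one sampled transition for an n-step TD target.
n_step = 3
s, a, r, sp, t, g, ap, idx = episode.sample(n_step=n_step)
with torch.no_grad():
    # For n_step > 1, r is already the discounted n-step return (greturn).
    target = r + (1 - t.float()) * gamma ** n_step * Qf(sp, theta_q).max(1)[0]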
Example #2
 def i2y(self, idx, n_step=1):
     if self.tslice > 1:
         sidx = (idx.reshape((idx.shape[0], 1)) + self.tslice_range)
     else:
         sidx = idx
     sidxp = torch.min(tint(self.length), sidx + n_step)
     idxp = torch.min(tint(self.length), idx + n_step)
     return (
         self.s[sidx].float() / self.snorm,
         self.a[idx],
         self.r[idx] if n_step == 1 else self.greturn(self.r[idx:idxp]),
         self.s[sidxp].float() / self.snorm,
         self.t[idxp - 1],
         self.g[idx],
         self.a[idxp],
         idx,
     )
 def slice_near(self, i, dist, exclude_0=True):
     ar = np.arange(-dist, dist + 1)
     if exclude_0:
         ar = ar[ar != 0]
     sidx = i + tint(ar)
     pmask = (sidx >= 3).float() * (sidx <= self.length - 2).float()
     sidx = sidx.clamp(3, self.length - 2)
     return (*self.i2y(sidx), pmask)
 def compute_values(self, V):
     s = self.i2y(tint(np.arange(3, self.length - 1)))[0]
     with torch.no_grad():
         v = V(s).detach()
     if not hasattr(self, 'v'):
         self.v = torch.zeros([self.length, v.shape[1]],
                              dtype=torch.float32,
                              device=self.device)
     self.v[3:self.length - 1] = v
def fill_buffer_with_expert(replay_buffer, env_name, epsilon=0.01):
  mbsize = ARGS.mbsize
  envs = [AtariEnv(env_name) for i in range(mbsize)]
  num_act = envs[0].num_actions

  nhid = 32
  _, theta_q, Qf, _ = nn.build(
      nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
      nn.conv2d(nhid, nhid * 2, 4, stride=2),
      nn.conv2d(nhid * 2, nhid * 2, 3),
      nn.flatten(),
      nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
      nn.linear(nhid * 16, num_act),
  )

  theta_q_trained = load_parameters_from_checkpoint()
  if ARGS.expert_is_self:
    theta_expert = theta_q_trained
  else:
    expert_id = {
        'ms_pacman':457, 'asterix':403, 'seaquest':428}[env_name]
    with open(f'checkpoints/dqn_model_{expert_id}.pkl',
              "rb") as f:
      theta_expert = pickle.load(f)
    theta_expert = [tf(i) for i in theta_expert]

  obs = [i.reset() for i in envs]
  trajs = [list() for i in range(mbsize)]
  enumbers = list(range(mbsize))
  replay_buffer.ram = torch.zeros([replay_buffer.size, 128],
                                  dtype=torch.uint8,
                                  device=replay_buffer.device)

  while True:
    mbobs = tf(obs) / 255
    greedy_actions = Qf(mbobs, theta_expert).argmax(1)
    random_actions = np.random.randint(0, num_act, mbsize)
    actions = [
        j if np.random.random() < epsilon else i
        for i, j in zip(greedy_actions, random_actions)
    ]
    for i, (e, a) in enumerate(zip(envs, actions)):
      obsp, r, done, _ = e.step(a)
      trajs[i].append([obs[i], int(a), float(r), int(done), e.getRAM() + 0])
      obs[i] = obsp
      if done:
        if replay_buffer.idx + len(trajs[i]) + 4 >= replay_buffer.size:
          # We're done!
          return Qf, theta_q_trained
        replay_buffer.new_episode(trajs[i][0][0], enumbers[i] % 2)
        for s, a, r, d, ram in trajs[i]:
          replay_buffer.ram[replay_buffer.idx] = tint(ram)
          replay_buffer.add(s, a, r, d, enumbers[i] % 2)

        trajs[i] = []
        obs[i] = envs[i].reset()
        enumbers[i] = max(enumbers) + 1
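A minimal usage sketch for fill_buffer_with_expert, assuming a ReplayBuffer constructed as in the main() example further below; the constructor arguments and game name are illustrative.

# Illustrative only: pre-fill a buffer with epsilon-greedy expert transitions.
replay_buffer = ReplayBuffer(0, 250_000, near_strategy="both")
Qf, theta_expert_q = fill_buffer_with_expert(replay_buffer, "ms_pacman",
                                             epsilon=0.01)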
 def slice_near(self, idx, dist=10, exclude_0=True):
     ar = np.arange(-dist, dist + 1)
     if exclude_0:
         ar = ar[ar != 0]
     sidx = (idx.reshape((-1, 1)) + tint(ar))
     p = self.p[idx]
     ps = self.p[sidx]
     pmask = (p[:, None] == ps).reshape((-1, )).float()
     sidx = sidx.reshape((-1, ))  # clamp??
     return self._idx2xy(sidx), pmask
Example #7
 def __init__(self, gamma, tslice=4, snorm=255):
     self.s = []
     self.a = []
     self.r = []
     self.t = []
     self.ram = []
     self.gamma = gamma
     self.is_over = False
     self.device = get_device()
     self.tslice = tslice
     self.tslice_range = tint(np.arange(tslice) - tslice + 1)
     self.snorm = snorm
 def in_order_iterate(self, mbsize, until=None):
     if until is None:
         until = self.size
     valid_indices = np.arange(self.size)[self._sumtree.levels[-1] > 0]
     it = 0
     end = 0
     while end < valid_indices.shape[0]:
         end = min(it + mbsize, valid_indices.shape[0])
         if end > until:
             break
         yield self.get(tint(valid_indices[it:end]))
         it += mbsize
 def _idx2xy(self, idx, sidx=None):
     if sidx is None:
         d = tint((-3, -2, -1, 0, 1))  # 4 state history slice + 1 for s'
         sidx = (idx.reshape((idx.shape[0], 1)) + d) % self.maxidx
     return (
         self.s[sidx[:, :4]].float() / 255,
         self.a[idx],
         self.r[idx],
         self.s[sidx[:, 1:]].float() / 255,
         self.t[idx],
         idx,
     )
 def compute_lambda_returns(self, fun, Lambda, gamma):
     if not hasattr(self, 'LR'):
         self.LR = LambdaReturn(Lambda, gamma)
         self.LG = torch.zeros([self.size],
                               dtype=torch.float32,
                               device=self.device)
     i = 0
     for start, end in self.episodes:
         s = self._idx2xy(tint(np.arange(start + 1, end)))[0]
         with torch.no_grad():
             vp = fun(s)[:, 0].detach()
             vp = torch.cat([vp, torch.zeros((1, ), device=self.device)])
         self.LG[start:end] = self.LR(self.r[start:end], vp)
         i += 1
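compute_lambda_returns relies on a LambdaReturn callable that is not shown in these excerpts. A minimal sketch consistent with how it is called here (rewards r[start:end] and bootstrap values vp of the same length, with vp zero-padded at the episode end) is the usual backward recursion G_t = r_t + gamma * ((1 - lambda) * v_{t+1} + lambda * G_{t+1}); this is an assumption, not the original class.

import torch

class LambdaReturn:
    # Hypothetical stand-in for the LambdaReturn used above.
    def __init__(self, lam, gamma):
        self.lam = lam
        self.gamma = gamma

    def __call__(self, r, vp):
        # r: (T,) rewards; vp: (T,) bootstrap values V(s_{t+1}),
        # with vp[-1] == 0 at the episode boundary.
        g = torch.zeros_like(r, dtype=torch.float32)
        last = torch.zeros((), dtype=torch.float32, device=r.device)
        for t in reversed(range(r.shape[0])):
            last = r[t] + self.gamma * ((1 - self.lam) * vp[t] + self.lam * last)
            g[t] = last
        return g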
    def compute_values(self, fun, num_act, mbsize=128, nbins=256):
        if not hasattr(self, "last_v"):
            self.last_v = torch.zeros([self.size, num_act],
                                      dtype=torch.float32,
                                      device=self.device)
            self.vdiff_acc = np.zeros(nbins)
            self.vdiff_cnt = np.zeros(nbins)
            self.vdiff_bins = np.linspace(-1, 1, nbins - 1)

        d = tint((-3, -2, -1, 0))  # 4 state history slice
        idx_0 = tint(np.arange(mbsize)) + 3
        idx_s = idx_0.reshape((-1, 1)) + d
        for i in range(int(np.ceil((self.maxidx - 4) / mbsize))):
            islice = idx_s + i * mbsize
            iar = idx_0 + i * mbsize
            if (i + 1) * mbsize >= self.maxidx - 2:
                islice = islice[:self.maxidx - i * mbsize - 2]
                iar = iar[:self.maxidx - i * mbsize - 2]
            s = self.s[islice].float().div_(255)
            with torch.no_grad():
                self.last_v[iar] = fun(s)
            if not i % 100:
                gc.collect()
 def i2y(self, idx):
     d = tint((-3, -2, -1, 0, 1))  # 4 state history slice + 1 for s'
     sidx = (idx.reshape((idx.shape[0], 1)) + d)
     return (
         self.s[sidx[:, :4]].float() / 255,
         self.a[idx],
         self.r[idx],
         self.s[sidx[:, 1:]].float() / 255,
         self.t[idx],
         self.g[idx],
         self.lg[idx],
         self.a[idx + 1],
         idx,
     )
Example #13
 def iterate(self, mbsize):
     eidx = 0
     t = 0
     z = torch.zeros(mbsize, device=self.device)
     while True:
         mb = []
         while len(mb) < mbsize:
             mb.append(self.episodes[eidx].i2y(tint([t])))
             t += 1
             if t >= self.episodes[eidx].length:
                 t = 0
                 eidx += 1
                 if eidx >= len(self.episodes):
                     break
         if len(mb):
             yield Minibatch(
                 *[
                     torch.cat([d[i] for d in mb])
                     for i in range(len(mb[0]))
                 ], z[:len(mb)])
         if eidx >= len(self.episodes) or len(mb) < mbsize:
             break
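The Minibatch container used by iterate is not defined in these excerpts. Assuming the nine-field i2y variant shown above (the one returning s, a, r, sp, t, g, lg, a', idx) plus the trailing weight vector z, a plausible stand-in is the following; the field names are hypothetical.

from collections import namedtuple

# Hypothetical definition: nine i2y fields plus a per-sample weight vector.
Minibatch = namedtuple(
    "Minibatch", ["s", "a", "r", "sp", "t", "g", "lg", "ap", "idx", "w"])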
 def rand_acc(s, a, r, sp, t, idx, w, tw):
   return (Qf(s, w).argmax(1) != tint(rand_classes[idx])).float()
 def rand_nll(s, a, r, sp, t, idx, w, tw):
    return F.cross_entropy(Qf(s, w), tint(rand_classes[idx]), reduction="none")
 def sample(self):
     return self.i2y(tint([np.random.randint(3, self.length - 1)]))
 def stratified_sample(self, n):
     # As per Schaul et al. (2015)
     return tint([
         self.sample((i + q) / n)
         for i, q in enumerate(self.rng.uniform(0, 1, n))
     ])
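stratified_sample follows the stratified scheme of Schaul et al. (2015): the unit interval is split into n equal strata and self.sample(u) maps each uniform draw to a transition index with probability proportional to its priority, typically by descending a sum-tree. The sum-tree itself is not shown in these excerpts; a minimal sketch (assuming a power-of-two capacity; names are hypothetical) is:

import numpy as np

class SumTree:
    # Hypothetical stand-in for the priority structure behind self.sample(u).
    def __init__(self, capacity):
        self.capacity = capacity            # assumed to be a power of two
        self.tree = np.zeros(2 * capacity)  # index 0 unused, leaves in [capacity, 2*capacity)

    def set(self, idx, priority):
        i = idx + self.capacity
        self.tree[i] = priority
        i //= 2
        while i >= 1:                       # propagate sums up to the root
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]
            i //= 2

    def sample(self, u):
        mass = u * self.tree[1]             # total priority mass
        i = 1
        while i < self.capacity:            # descend left/right by mass
            if mass <= self.tree[2 * i]:
                i = 2 * i
            else:
                mass -= self.tree[2 * i]
                i = 2 * i + 1
        return i - self.capacity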
 def compute_lambda_return(self, V, lr):
     s = self.i2y(tint(np.arange(4, self.length)))[0]
     with torch.no_grad():
         vp = V(s).detach()
         self.lg[3:self.length - 1] = lr(self.r[3:self.length - 1], vp)
Example #19
def main():
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }

    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }
    start_step = ARGS.start_step
    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = hps.get("mbsize", 32)
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")

    clone_interval = hps.get("clone_interval", 10_000)
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)

    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = hps.get("buffer_size", 250_000)

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model

    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q,
                                   lr,
                                   momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        # `frozen_theta_q` is presumably the target-network parameter list;
        # it is not defined in this excerpt.
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda s, a, r, sp, t, w, tw=theta_q: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
        Qf(s, w)[np.arange(len(a)), a.long()],
    )

    obs = env.reset()

    if replay_type == "normal":
        replay_buffer = ReplayBuffer(seed,
                                     buffer_size,
                                     near_strategy=sample_near)
    elif replay_type == "prioritized":
        replay_buffer = PrioritizedExperienceReplay(seed,
                                                    buffer_size,
                                                    near_strategy=sample_near)

    total_reward = 0
    last_end = 0
    num_fill = 200000
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]

    measure = Measures()
    print("Filling buffer")

    if start_step < num_exploration_steps:
        epsilon = 1 - (start_step / num_exploration_steps) * (1 - final_epsilon)
    else:
        epsilon = final_epsilon

    for it in range(num_fill):
        if start_step == 0:
            action = rng.randint(0, num_act)
        else:
            if rng.uniform(0, 1) < epsilon:
                action = rng.randint(0, num_act)
            else:
                action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        if replay_type == "prioritized":
            replay_buffer.set_last_priority(
                td(
                    tf(obs / 255.0).unsqueeze(0),
                    tint([action]),
                    r,
                    tf(obsp / 255.0).unsqueeze(0),
                    tf([done]),
                    theta_q,
                    theta_q,
                ))

        obs = obsp
        if done:
            obs = env.reset()

    past_theta = [clone_theta_q()]

    for it in range(start_step, num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60
        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it % 10_000 == 0:
            print(
                it,
                f"{(t1 - t0)*1000:.2f}, {(t2 - t1)*1000:.2f}, {(t3 - t2)*1000:.2f}, {(t4 - t3)*1000:.2f},",
                f"{(tm1 - tm0)*1000:.2f}, {(tm3 - tm2)*1000:.2f},",
                f"{int(eta//60):2d}h{int(eta%60):02d}m left",
                f":: {ema_loss:.5f}, last 10 rewards: {np.mean(last_rewards):.2f}",
            )
Example #20
        else:
            epsilon = final_epsilon

        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()
        t1 = time.time()
        obsp, r, done, info = env.step(action)
        total_reward += r
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        if replay_type == "prioritized":
            replay_buffer.set_last_priority(
                td(
                    tf(obs / 255.0).unsqueeze(0),
                    tint([action]),
                    r,
                    tf(obsp / 255.0).unsqueeze(0),
                    tf([done]),
                    theta_q,
                    theta_q,
                ))

        obs = obsp
        if done:
            obs = env.reset()
            results["episode"].append({
                "end": it,
                "start": last_end,
                "total_reward": total_reward
            })