def __init__(self, Lambda, gamma, max_len=4096):
    self.Lambda = Lambda
    self.gamma = gamma
    self.max_len = max_len
    self.Lt = tf(np.zeros((max_len, max_len), dtype="float32"))
    self.gt = tf(np.zeros((max_len, max_len), dtype="float32"))
    for i in range(max_len):
        self.gt[i, i:] = gamma**torch.arange(max_len - i)
        self.Lt[i, i:] = Lambda**torch.arange(max_len - i)
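# Note: gt and Lt are upper-triangular discount matrices with
# gt[i, j] = gamma**(j - i) and Lt[i, j] = Lambda**(j - i) for j >= i;
# __call__ further below reuses them to weight n-step returns when building
# TD(lambda) targets.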
def fill_buffer_with_expert(replay_buffer, env_name, epsilon=0.01):
  mbsize = ARGS.mbsize
  envs = [AtariEnv(env_name) for i in range(mbsize)]
  num_act = envs[0].num_actions

  nhid = 32
  _, theta_q, Qf, _ = nn.build(
      nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
      nn.conv2d(nhid, nhid * 2, 4, stride=2),
      nn.conv2d(nhid * 2, nhid * 2, 3),
      nn.flatten(),
      nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
      nn.linear(nhid * 16, num_act),
  )

  theta_q_trained = load_parameters_from_checkpoint()
  if ARGS.expert_is_self:
    theta_expert = theta_q_trained
  else:
    expert_id = {
        'ms_pacman':457, 'asterix':403, 'seaquest':428}[env_name]
    with open(f'checkpoints/dqn_model_{expert_id}.pkl',
              "rb") as f:
      theta_expert = pickle.load(f)
    theta_expert = [tf(i) for i in theta_expert]

  obs = [i.reset() for i in envs]
  trajs = [list() for i in range(mbsize)]
  enumbers = list(range(mbsize))
  replay_buffer.ram = torch.zeros([replay_buffer.size, 128],
                                  dtype=torch.uint8,
                                  device=replay_buffer.device)

  while True:
    mbobs = tf(obs) / 255
    greedy_actions = Qf(mbobs, theta_expert).argmax(1)
    random_actions = np.random.randint(0, num_act, mbsize)
    actions = [
        j if np.random.random() < epsilon else i
        for i, j in zip(greedy_actions, random_actions)
    ]
    for i, (e, a) in enumerate(zip(envs, actions)):
      obsp, r, done, _ = e.step(a)
      trajs[i].append([obs[i], int(a), float(r), int(done), e.getRAM() + 0])
      obs[i] = obsp
      if done:
        if replay_buffer.idx + len(trajs[i]) + 4 >= replay_buffer.size:
          # We're done!
          return Qf, theta_q_trained
        replay_buffer.new_episode(trajs[i][0][0], enumbers[i] % 2)
        for s, a, r, d, ram in trajs[i]:
          replay_buffer.ram[replay_buffer.idx] = tint(ram)
          replay_buffer.add(s, a, r, d, enumbers[i] % 2)

        trajs[i] = []
        obs[i] = envs[i].reset()
        enumbers[i] = max(enumbers) + 1
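# A hypothetical usage sketch (the ReplayBuffer constructor arguments are
# copied from the training scripts further down, not from this function):
#
#   replay_buffer = ReplayBuffer(seed, ARGS.buffer_size, near_strategy="both")
#   Qf, theta_trained = fill_buffer_with_expert(replay_buffer, ARGS.env_name,
#                                               epsilon=0.01)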
def __call__(self, r, v_sp):
    T = r.shape[0]
    if T > self.max_len:  # This almost never occurs, tbf
        t = self.max_len
        return torch.cat([self(r[:t], v_sp[:t]), self(r[t:], v_sp[t:])])
    alive = np.ones(T)
    alive[-1] = 0
    alive = tf(alive)
    n_step_rewards = torch.cumsum(r[None, :] * self.gt[:T, :T], 1)
    n_step_returns = alive[None, :] * (
        n_step_rewards + self.gamma * v_sp[None, :] * self.gt[:T, :T])
    weighted_n_step_returns = (1 - self.Lambda) * (
        n_step_returns * self.Lt[:T, :T]).sum(1)
    weighted_mc_returns = self.Lt[:T, T] * self.Lambda * n_step_rewards[:, -1]
    lambda_target = weighted_n_step_returns + weighted_mc_returns
    return lambda_target
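# A minimal usage sketch for the two methods above, assuming they belong to a
# lambda-return helper class (the name `LambdaReturns` is a placeholder) and
# that `tf` wraps array-likes into float32 torch tensors:
#
#   calc = LambdaReturns(Lambda=0.9, gamma=0.99, max_len=512)
#   r = tf([0.0, 1.0, 0.0, 5.0])     # rewards for one trajectory slice
#   v_sp = tf([0.3, 0.2, 0.8, 0.0])  # V(s') estimates; the last step is terminal
#   targets = calc(r, v_sp)          # one lambda-return target per timestep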
Example No. 4
def main():
  results = {
      "results": [],
      "measure_reg": [],
      "measure_td": [],
      "measure_mc": [],
  }

  hps = {
      "opt": ARGS.opt,
      "env_name": ARGS.env_name,
      "lr": ARGS.learning_rate,
      "weight_decay": ARGS.weight_decay,
      "run": ARGS.run,
  }
  nhid = hps.get("nhid", 32)
  gamma = hps.get("gamma", 0.99)
  mbsize = hps.get("mbsize", 32)
  weight_decay = hps.get("weight_decay", 0)
  sample_near = hps.get("sample_near", "both")
  slice_size = hps.get("slice_size", 0)
  env_name = hps.get("env_name", "ms_pacman")

  clone_interval = hps.get("clone_interval", 10_000)
  reset_on_clone = hps.get("reset_on_clone", False)
  reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
  max_clones = hps.get("max_clones", 2)
  target = hps.get("target", "last")  # self, last, clones
  replay_type = hps.get("replay_type", "normal")  # normal, prioritized
  final_epsilon = hps.get("final_epsilon", 0.05)
  num_exploration_steps = hps.get("num_exploration_steps", 500_000)

  lr = hps.get("lr", 1e-4)
  num_iterations = hps.get("num_iterations", 10_000_000)
  buffer_size = hps.get("buffer_size", 250_000)

  seed = hps.get("run", 0) + 1_642_559  # A large prime number
  hps["_seed"] = seed
  torch.manual_seed(seed)
  np.random.seed(seed)
  rng = np.random.RandomState(seed)

  env = AtariEnv(env_name)
  num_act = env.num_actions

  def make_opt(theta):
    if hps.get("opt", "sgd") == "sgd":
      return torch.optim.SGD(theta, lr, weight_decay=weight_decay)
    elif hps["opt"] == "msgd":
      return torch.optim.SGD(
          theta, lr, momentum=hps.get("beta", 0.99), weight_decay=weight_decay)
    elif hps["opt"] == "rmsprop":
      return torch.optim.RMSprop(theta, lr, weight_decay=weight_decay)
    elif hps["opt"] == "adam":
      return torch.optim.Adam(theta, lr, weight_decay=weight_decay)
    else:
      raise ValueError(hps["opt"])

  # Define model

  _Qarch, theta_q, Qf, _Qsemi = nn.build(
      nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
      nn.conv2d(nhid, nhid * 2, 4, stride=2),
      nn.conv2d(nhid * 2, nhid * 2, 3),
      nn.flatten(),
      nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
      nn.linear(nhid * 16, num_act),
  )
  clone_theta_q = lambda: [i.detach().clone().requires_grad_() for i in theta_q]

  # Pretrained parameters
  theta_target = load_parameters_from_checkpoint()
  # (Same) Random parameters
  theta_regress = clone_theta_q()
  theta_qlearn = clone_theta_q()
  theta_mc = clone_theta_q()
  opt_regress = make_opt(theta_regress)
  opt_qlearn = make_opt(theta_qlearn)
  opt_mc = make_opt(theta_mc)

  # Define loss
  def sl1(a, b):
    d = a - b
    u = abs(d)
    s = d**2
    m = (u < s).float()
    return u * m + s * (1 - m)
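  # sl1 is a smooth-L1 (Huber-style) penalty: elementwise, it is quadratic
  # for |a - b| <= 1 and linear for |a - b| > 1.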

  td = lambda s, a, r, sp, t, w, tw=theta_q: sl1(
      r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
      Qf(s, w)[np.arange(len(a)), a.long()],
  )
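  # td is the one-step TD error: the sl1 distance between the bootstrapped
  # target r + gamma * max_a' Q(s', a'; tw) (detached, with tw playing the
  # role of target parameters) and Q(s, a; w) at the actions taken.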

  obs = env.reset()

  replay_buffer = ReplayBuffer(seed, buffer_size, near_strategy=sample_near)

  total_reward = 0
  last_end = 0
  num_fill = buffer_size
  num_measure = 500
  _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
  tm0 = tm1 = tm2 = tm3 = time.time()
  ema_loss = 0
  last_rewards = [0]

  print("Filling buffer")
  epsilon = final_epsilon

  replay_buffer.new_episode(obs, env.enumber % 2)
  while replay_buffer.idx < replay_buffer.size - 10:
    if rng.uniform(0, 1) < epsilon:
      action = rng.randint(0, num_act)
    else:
      action = Qf(tf(obs / 255.0).unsqueeze(0), theta_target).argmax().item()
    obsp, r, done, info = env.step(action)
    replay_buffer.add(obs, action, r, done, env.enumber % 2)

    obs = obsp
    if done:
      obs = env.reset()
      replay_buffer.new_episode(obs, env.enumber % 2)
  # Remove last episode from replay buffer, as it didn't end
  it = replay_buffer.idx
  curp = replay_buffer.p[it]
  while replay_buffer.p[it] == curp:
    replay_buffer._sumtree.set(it, 0)
    it -= 1
  print(f'went from {replay_buffer.idx} to {it} when deleting states')

  print("Computing returns")
  replay_buffer.compute_values(lambda s: Qf(s, theta_regress), num_act)
  replay_buffer.compute_returns(gamma)
  replay_buffer.compute_reward_distances()
  print("Training regressions")
  losses_reg, losses_td, losses_mc = [], [], []

  loss_reg_f = lambda x, w: sl1(Qf(x[0], w), Qf(x[0], theta_target))
  loss_td_f = lambda x, w: td(*x[:-1], w, theta_target)
  loss_mc_f = lambda x, w: sl1(
      Qf(x[0], w)[np.arange(len(x[1])), x[1].long()], replay_buffer.g[x[-1]])

  losses = {
      "reg": loss_reg_f,
      "td": loss_td_f,
      "mc": loss_mc_f,
  }
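  # Three objectives are trained side by side on the same minibatches:
  # "reg" regresses Q(., .; w) onto the pretrained target network's outputs,
  # "td" is the usual Q-learning bootstrap against theta_target, and
  # "mc" regresses Q(s, a; w) onto the stored returns replay_buffer.g.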

  measure_reg = Measures(theta_regress, losses, replay_buffer,
                         results["measure_reg"], mbsize)
  measure_mc = Measures(theta_mc, losses, replay_buffer,
                        results["measure_mc"], mbsize)
  measure_td = Measures(theta_qlearn, losses, replay_buffer,
                        results["measure_td"], mbsize)

  for i in range(100_000):
    sample = replay_buffer.sample(mbsize)
    replay_buffer.compute_value_difference(sample, Qf(sample[0], theta_regress))

    if i and not i % num_measure:
      measure_reg.pre(sample)
      measure_mc.pre(sample)
      measure_td.pre(sample)

    loss_reg = loss_reg_f(sample, theta_regress).mean()
    loss_reg.backward()
    losses_reg.append(loss_reg.item())
    opt_regress.step()
    opt_regress.zero_grad()

    loss_td = loss_td_f(sample, theta_qlearn).mean()
    loss_td.backward()
    losses_td.append(loss_td.item())
    opt_qlearn.step()
    opt_qlearn.zero_grad()

    loss_mc = loss_mc_f(sample, theta_mc).mean()
    loss_mc.backward()
    losses_mc.append(loss_mc.item())
    opt_mc.step()
    opt_mc.zero_grad()

    replay_buffer.update_values(sample, Qf(sample[0], theta_regress))
    if i and not i % num_measure:
      measure_reg.post()
      measure_td.post()
      measure_mc.post()

    if not i % 1000:
      print(i, loss_reg.item(), loss_td.item(), loss_mc.item())
Example No. 5
def load_parameters_from_checkpoint():
  data = pickle.load(open(ARGS.checkpoint, 'rb'))
  return [tf(data[str(i)]) for i in range(10)]
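# (load_parameters_from_checkpoint expects the checkpoint pickle to map the
# string keys "0".."9" to the ten parameter arrays of the Q-network used
# throughout these examples.)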
Example No. 6
def main():
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }

    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }
    start_step = ARGS.start_step
    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = hps.get("mbsize", 32)
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")

    clone_interval = hps.get("clone_interval", 10_000)
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)

    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = hps.get("buffer_size", 250_000)

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model

    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q,
                                   lr,
                                   momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda s, a, r, sp, t, w, tw=theta_q: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
        Qf(s, w)[np.arange(len(a)), a.long()],
    )

    obs = env.reset()

    if replay_type == "normal":
        replay_buffer = ReplayBuffer(seed,
                                     buffer_size,
                                     near_strategy=sample_near)
    elif replay_type == "prioritized":
        replay_buffer = PrioritizedExperienceReplay(seed,
                                                    buffer_size,
                                                    near_strategy=sample_near)

    total_reward = 0
    last_end = 0
    num_fill = 200000
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]

    measure = Measures()
    print("Filling buffer")

    if start_step < num_exploration_steps:
        epsilon = 1 - (start_step / num_exploration_steps) * (1 -
                                                              final_epsilon)
    else:
        epsilon = final_epsilon

    for it in range(num_fill):
        if start_step == 0:
            action = rng.randint(0, num_act)
        else:
            if rng.uniform(0, 1) < epsilon:
                action = rng.randint(0, num_act)
            else:
                action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
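        # With prioritized replay, each new transition's priority is seeded
        # with its one-step TD error under the current parameters.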
        if replay_type == "prioritized":
            replay_buffer.set_last_priority(
                td(
                    tf(obs / 255.0).unsqueeze(0),
                    tint([action]),
                    r,
                    tf(obsp / 255.0).unsqueeze(0),
                    tf([done]),
                    theta_q,
                    theta_q,
                ))

        obs = obsp
        if done:
            obs = env.reset()

    past_theta = [clone_theta_q()]

    for it in range(start_step, num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60
        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it % 10_000 == 0:
            print(
                it,
                f"{(t1 - t0)*1000:.2f}, {(t2 - t1)*1000:.2f}, {(t3 - t2)*1000:.2f}, {(t4 - t3)*1000:.2f},",
                f"{(tm1 - tm0)*1000:.2f}, {(tm3 - tm2)*1000:.2f},",
                f"{int(eta//60):2d}h{int(eta%60):02d}m left",
                f":: {ema_loss:.5f}, last 10 rewards: {np.mean(last_rewards):.2f}",
            )
            )
Example No. 7

        t0 = time.time()
        if it < num_exploration_steps:
            epsilon = 1 - (it / num_exploration_steps) * (1 - final_epsilon)
        else:
            epsilon = final_epsilon

        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()
        t1 = time.time()
        obsp, r, done, info = env.step(action)
        total_reward += r
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        if replay_type == "prioritized":
            replay_buffer.set_last_priority(
                td(
                    tf(obs / 255.0).unsqueeze(0),
                    tint([action]),
                    r,
                    tf(obsp / 255.0).unsqueeze(0),
                    tf([done]),
                    theta_q,
                    theta_q,
                ))

Example No. 8

def main():
    device = torch.device(ARGS.device)
    nn.set_device(device)
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }

    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }

    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = ARGS.mbsize
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")

    clone_interval = ARGS.clone_interval
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)
    Lambda = ARGS.Lambda

    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = ARGS.buffer_size

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model

    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q,
                                   lr,
                                   momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda x: sl1(
        x.r +
        (1 - x.t.float()) * gamma * Qf(x.sp, past_theta[0]).max(1)[0].detach(),
        Qf(x.s, theta_q)[np.arange(len(x.a)), x.a.long()],
    )

    tdQL = lambda x: sl1(
        Qf(x.s, theta_q)[np.arange(len(x.a)), x.a.long()], x.lg)

    mc = lambda x: sl1(Qf(x.s, theta_q).max(1)[0], x.g)
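    # td bootstraps from the frozen clone past_theta[0]; tdQL regresses
    # Q(s, a) onto the lambda-return targets x.lg stored by the buffer;
    # mc regresses the greedy value max_a Q(s, a) onto the Monte-Carlo
    # return x.g.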

    past_theta = [clone_theta_q()]

    replay_buffer = ReplayBufferV2(seed, buffer_size, lambda s: Qf(s, theta_q),
                                   lambda s: Qf(s, past_theta[0]).max(1)[0],
                                   Lambda, gamma)
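    # ReplayBufferV2 receives both the online Q-function and the frozen
    # target's max-Q so it can compute (and later recompute) lambda-returns
    # for the stored trajectories; see recompute_lambda_returns further down.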

    total_reward = 0
    last_end = 0
    num_fill = buffer_size // 2
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]

    measure = Measures(theta_q, {
        "td": td,
        "tdQL": tdQL,
        "mc": mc,
    }, replay_buffer, results["measure"], 32)

    obs = env.reset()
    for it in range(num_fill):
        action = rng.randint(0, num_act)
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done)

        obs = obsp
        if done:
            print(it)
            obs = env.reset()

    for it in range(num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60
        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it < num_exploration_steps:
            epsilon = 1 - (it / num_exploration_steps) * (1 - final_epsilon)
        else:
            epsilon = final_epsilon

        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()

        obsp, r, done, info = env.step(action)
        total_reward += r
        replay_buffer.add(obs, action, r, done)

        obs = obsp
        if done:
            obs = env.reset()
            results["episode"].append({
                "end": it,
                "start": last_end,
                "total_reward": total_reward
            })
            last_end = it
            last_rewards = [total_reward] + last_rewards[:10]
            total_reward = 0

        sample = replay_buffer.sample(mbsize)
        with torch.no_grad():
            v_before = Qf(sample.s, theta_q).detach()

        loss = tdQL(sample)

        if do_measure:
            tm0 = time.time()
            measure.pre(sample)
            tm1 = time.time()
        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        with torch.no_grad():
            v_after = Qf(sample.s, theta_q).detach()
        replay_buffer.compute_value_difference(sample, v_before, v_after)

        if do_measure:
            tm2 = time.time()
            measure.post()
            tm3 = time.time()
        t4 = time.time()
        if it and clone_interval and it % clone_interval == 0:
            past_theta = [clone_theta_q()]  #+ past_theta[:max_clones - 1]
            replay_buffer.recompute_lambda_returns()

        #exp_results["loss"].append(loss.item())
        ema_loss = 0.999 * ema_loss + 0.001 * loss.item()
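        # ema_loss is an exponential moving average of the minibatch loss
        # (decay 0.999), used only for the periodic progress printout.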
Example No. 9
                f":: {ema_loss:.5f}, last 10 rewards: {np.mean(last_rewards):.2f}",
                #" " * 20,
                #end="\r",
            )
            #sys.stdout.flush()

        t0 = time.time()
        if it < num_exploration_steps:
            epsilon = 1 - (it / num_exploration_steps) * (1 - final_epsilon)
        else:
            epsilon = final_epsilon

        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()
        t1 = time.time()
        obsram = env.getRAM().tobytes()
        if obsram not in test_set:
            test_set[obsram] = float(rng.uniform(0, 1) < to_test_prob)
        obsp, r, done, info = env.step(action)
        total_reward += r
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        replay_buffer.set_last_priority(1 - test_set[obsram])

        obs = obsp
        if done:
            replay_buffer.add(obs, 0, 0, done, env.enumber % 2)
            replay_buffer.set_last_priority(0)
            obs = env.reset()
            results["episode"].append({