# NOTE: both examples rely on module-level imports (pickle, numpy as np,
# torch) and on project helpers defined elsewhere (ARGS, ReplayBuffer,
# fill_buffer_with_expert, sl1, atari_dict, nn, mm, Measures, AtariEnv,
# make_opt).
def main():
  gamma = 0.99
  hps = pickle.load(open(ARGS.checkpoint, 'rb'))['hps']
  env_name = hps["env_name"]
  # Older checkpoints may not store Lambda; default to 0 in that case.
  Lambda = hps.get('Lambda', 0)

  device = torch.device(ARGS.device)
  nn.set_device(device)
  replay_buffer = ReplayBuffer(ARGS.run, ARGS.buffer_size)
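  # Fill the buffer with expert transitions; the helper also returns the
  # Q-function Qf and its parameter list theta_q, whose gradients are
  # enabled below.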
  Qf, theta_q = fill_buffer_with_expert(replay_buffer, env_name)
  for p in theta_q:
    p.requires_grad = True
  if Lambda > 0:
    replay_buffer.compute_episode_boundaries()
    replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

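  # Loss lambdas share the signature (s, a, r, sp, t, idx, w, tw): state,
  # action, reward, next state, terminal flag, buffer index, online weights,
  # target weights. td__ is the usual 1-step TD error but is not selectable
  # through ARGS.loss_func; 'td' below differentiates the greedy Q-value and
  # 'tdL' regresses Q(s)[:, 0] onto the precomputed lambda-returns in
  # replay_buffer.LG.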
  td__ = lambda s, a, r, sp, t, idx, w, tw: sl1(
      r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
      Qf(s, w)[np.arange(len(a)), a.long()],
  )

  td = lambda s, a, r, sp, t, idx, w, tw: Qf(s, w).max(1)[0]

  tdL = lambda s, a, r, sp, t, idx, w, tw: sl1(
      Qf(s, w)[:, 0], replay_buffer.LG[idx])

  loss_func = {
      'td': td, 'tdL': tdL}[ARGS.loss_func]

  opt = torch.optim.SGD(theta_q, 1)
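  # The optimizer is only used for opt.zero_grad(); opt.step() is never
  # called, so theta_q stays fixed throughout the analysis.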

  def grad_sim(inp, grad):
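    # Cosine similarity between the gradients currently stored in the .grad
    # fields of `inp` and the reference gradient list `grad`.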
    dot = sum([(p.grad * gp).sum() for p, gp in zip(inp, grad)])
    nA = torch.sqrt(sum([(p.grad**2).sum() for p, gp in zip(inp, grad)]))
    nB = torch.sqrt(sum([(gp**2).sum() for p, gp in zip(inp, grad)]))
    return (dot / (nA * nB)).item()

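  # RAM byte indices annotated as task-relevant for this game (atari_dict is
  # presumably the AtariARI annotation table).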
  relevant_features = np.int32(
      sorted(list(atari_dict[env_name.replace("_", "")].values())))
  sims = []
  ram_sims = []
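  # For 2000 anchor transitions: take the anchor's gradient g0, then compare
  # it (via cosine similarity) against the gradients of transitions at
  # offsets -30..30 around the anchor and of 200 random transitions,
  # recording for the latter the mean absolute difference between the
  # annotated RAM features of the two states.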
  for i in range(2000):
    sim = []
    *sample, idx = replay_buffer.sample(1)
    loss = loss_func(*sample, idx, theta_q, theta_q).mean()
    loss.backward()
    g0 = [p.grad + 0 for p in theta_q]  # copy of the anchor's gradient
    for j in range(-30, 31):
      opt.zero_grad()
      loss = loss_func(*replay_buffer.get(idx + j), theta_q, theta_q).mean()
      loss.backward()
      sim.append(grad_sim(theta_q, g0))
    sims.append(np.float32(sim))
    for j in range(200):
      opt.zero_grad()
      *sample_j, idx_j = replay_buffer.sample(1)
      loss = loss_func(*sample_j, idx_j, theta_q, theta_q).mean()
      loss.backward()
      ram_sims.append(
          (grad_sim(theta_q, g0),
           abs(replay_buffer.ram[idx[0]][relevant_features].float() -
               replay_buffer.ram[idx_j[0]][relevant_features].float()).mean()))
    opt.zero_grad()
  ram_sims = np.float32(ram_sims)

  # Compute "True" gradient
  grads = [i.detach() * 0 for i in theta_q]
  N = 0
  for samples in replay_buffer.in_order_iterate(ARGS.mbsize * 8):
    loss = loss_func(*samples, theta_q, theta_q).mean()
    loss.backward()
    N += samples[0].shape[0]
    for p, gp in zip(theta_q, grads):
      gp.data.add_(p.grad)
    opt.zero_grad()

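  # Cosine similarity between each single-transition gradient and the
  # full-buffer gradient, summarised as a 100-bin histogram over [-1, 1].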
  dots = []
  i = 0
  for sample in replay_buffer.in_order_iterate(1):
    loss = loss_func(*sample, theta_q, theta_q).mean()
    loss.backward()
    dots.append(grad_sim(theta_q, grads))
    opt.zero_grad()
    i += 1
  histo = np.histogram(dots, 100, (-1, 1))

  results = {
      "grads": [i.cpu().data.numpy() for i in grads],
      "sims": np.float32(sims),
      "histo": histo,
      "ram_sims": ram_sims,
  }

  path = f'results/grads_{ARGS.checkpoint}.pkl'
  with open(path, "wb") as f:
    pickle.dump(results, f)
Example #2
def main():
    device = torch.device(ARGS.device)
    mm.set_device(device)
    results = {
        "measure": [],
        "parameters": [],
    }

    seed = ARGS.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)
    env = AtariEnv(ARGS.env_name)
    mbsize = ARGS.mbsize
    Lambda = ARGS.Lambda
    nhid = 32
    num_measure = 1000
    gamma = 0.99
    clone_interval = ARGS.clone_interval
    num_iterations = ARGS.num_iterations

    num_Q_outputs = 1  # single output head: policy evaluation, not control
    # Model
    _Qarch, theta_q, Qf, _Qsemi = mm.build(
        mm.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        mm.conv2d(nhid, nhid * 2, 4, stride=2),
        mm.conv2d(nhid * 2, nhid * 2, 3),
        mm.flatten(),
        mm.hidden(nhid * 2 * 12 * 12, nhid * 16),
        mm.linear(nhid * 16, num_Q_outputs),
    )
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]
    theta_target = clone_theta_q()
    opt = make_opt(ARGS.opt, theta_q, ARGS.learning_rate, ARGS.weight_decay)

    # Replay Buffer
    replay_buffer = ReplayBuffer(seed, ARGS.buffer_size)

    # Losses
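    # Each loss maps (s, a, r, sp, t, idx, w, tw) to an sl1 error between a
    # prediction and a target: 'td' bootstraps from the target network,
    # 'tdL' regresses onto the precomputed lambda-returns in
    # replay_buffer.LG, and 'mc' onto the Monte Carlo returns in
    # replay_buffer.g.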
    td = lambda s, a, r, sp, t, idx, w, tw: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw)[:, 0].detach(),
        Qf(s, w)[:, 0],
    )

    tdL = lambda s, a, r, sp, t, idx, w, tw: sl1(
        Qf(s, w)[:, 0], replay_buffer.LG[idx])

    mc = lambda s, a, r, sp, t, idx, w, tw: sl1(
        Qf(s, w)[:, 0], replay_buffer.g[idx])

    # Define metrics
    measure = Measures(
        theta_q, {
            "td": lambda x, w: td(*x, w, theta_target),
            "tdL": lambda x, w: tdL(*x, w, theta_target),
            "mc": lambda x, w: mc(*x, w, theta_target),
        }, replay_buffer, results["measure"], 32)

    # Get expert trajectories
    rand_classes = fill_buffer_with_expert(env, replay_buffer)
    # Compute initial values
    replay_buffer.compute_values(lambda s: Qf(s, theta_q), num_Q_outputs)
    replay_buffer.compute_returns(gamma)
    replay_buffer.compute_reward_distances()
    replay_buffer.compute_episode_boundaries()
    replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda,
                                         gamma)

    # Run policy evaluation
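    # Each iteration samples a minibatch, optionally records the configured
    # measures before and after the update, takes one optimizer step on the
    # lambda-return loss, and refreshes the buffer's stored value estimates.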
    for it in range(num_iterations):
        do_measure = not it % num_measure
        sample = replay_buffer.sample(mbsize)

        if do_measure:
            measure.pre(sample)
        replay_buffer.compute_value_difference(sample, Qf(sample[0], theta_q))

        loss = tdL(*sample, theta_q, theta_target)
        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        replay_buffer.update_values(sample, Qf(sample[0], theta_q))
        if do_measure:
            measure.post()

        if it and clone_interval and it % clone_interval == 0:
            theta_target = clone_theta_q()
            replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q),
                                                 Lambda, gamma)

        if ((it and clone_interval and it % clone_interval == 0)
                or it == num_iterations - 1):
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

    with open(f'results/td_lambda_{ARGS.run}.pkl', 'wb') as f:
        pickle.dump(results, f)
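Both scripts read their configuration from a module-level ARGS object. A minimal argparse sketch covering the flags referenced above (the flag names mirror the attribute accesses; the defaults are illustrative assumptions, not the original CLI):

import argparse

parser = argparse.ArgumentParser()
# Example #1
parser.add_argument("--checkpoint", type=str, default="")
parser.add_argument("--loss_func", type=str, default="tdL", choices=["td", "tdL"])
# Example #2
parser.add_argument("--env_name", type=str, default="ms_pacman")
parser.add_argument("--Lambda", type=float, default=0.9)
parser.add_argument("--clone_interval", type=int, default=10_000)
parser.add_argument("--num_iterations", type=int, default=100_000)
parser.add_argument("--opt", type=str, default="sgd")
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0.0)
# Shared
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--run", type=int, default=0)
parser.add_argument("--buffer_size", type=int, default=100_000)
parser.add_argument("--mbsize", type=int, default=32)
ARGS = parser.parse_args()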