def main():
  device = torch.device(ARGS.device)
  mm.set_device(device)
  results = {
      "measure": [],
      "parameters": [],
      "args": ARGS,
  }
  print(ARGS)
  seed = ARGS.run + 1_642_559  # A large prime number
  torch.manual_seed(seed)
  np.random.seed(seed)
  rng = np.random.RandomState(seed)
  env = AtariEnv(ARGS.env_name)
  mbsize = ARGS.mbsize
  nhid = 32
  num_measure = 1000
  gamma = 0.99
  tau = 0.01
  clone_interval = ARGS.clone_interval
  num_iterations = ARGS.num_iterations

  num_Q_outputs = env.num_actions if ARGS.loss_func != 'rand' else ARGS.num_rand_classes
  # Model
  act = torch.nn.LeakyReLU()
  Qf = torch.nn.Sequential(torch.nn.Conv2d(4, nhid, 8, stride=4, padding=4), act,
                           torch.nn.Conv2d(nhid, nhid*2, 4, stride=2, padding=2), act,
                           torch.nn.Conv2d(nhid*2, nhid*2, 3, padding=1), act,
                           torch.nn.Flatten(),
                           torch.nn.Linear(nhid*2*12*12, nhid*16), act,
                           torch.nn.Linear(nhid*16, num_Q_outputs))
  Qf.to(device)
  Qf.apply(init_weights)
  if ARGS.loss_func == 'nfdqn':
    Qf_target = Qf
  else:
    Qf_target = copy.deepcopy(Qf)
  Qf = extend(Qf)

  opt = make_opt(ARGS.opt, Qf.parameters(), ARGS.learning_rate, ARGS.weight_decay)

  # Replay Buffer
  replay_buffer = ReplayBufferV2(seed, ARGS.buffer_size,
                                 value_callback=lambda s: Qf(s),
                                 Lambda=0)

  td = lambda x: sl1(
      x.r + (1 - x.t.float()) * gamma * Qf_target(x.sp).max(1)[0].detach(),
      Qf(x.s)[np.arange(len(x.a)), x.a.long()],
  )
  sarsa = lambda x: sl1(
      x.r + ((1 - x.t.float())
             * gamma
             * Qf_target(x.sp)[np.arange(len(x.ap)), x.ap.long()].detach()),
      Qf(x.s)[np.arange(len(x.a)), x.a.long()],
  )


  mc = lambda x: sl1(
      Qf(x.s).max(1)[0], x.g)


  if ARGS.loss_func == 'rand':
    raise ValueError('fixme Qf')

    def rand_nll(s, a, r, sp, t, idx, w, tw):
      return F.cross_entropy(Qf(s, w), tint(rand_classes[idx]), reduce=False)
    def rand_acc(s, a, r, sp, t, idx, w, tw):
      return (Qf(s, w).argmax(1) != tint(rand_classes[idx])).float()

    # Define metrics
    measure = Measures(
        theta_q, {
            "rand_nll": lambda x, w: rand_nll(*x, w, theta_target),
            "rand_acc": lambda x, w: rand_acc(*x, w, theta_target),
        }, replay_buffer, results["measure"], 32)

    loss_func = rand_nll
  else:
    # Define metrics
    measure = Measures(
        list(Qf.parameters()), {
            "td": td,
            "func": lambda x: Qf(x.s).max(1).values,
            #"sarsa": sarsa,
            #"mc": mc,
        }, replay_buffer, results["measure"], 32,
        lambda x: Qf(x.s),
        Qf)

    loss_func = {
        "sarsa": sarsa,
        "qlearn": td,
        "mc": mc,
        'ddqn': td,
        'nfdqn': td,
    }[ARGS.loss_func]


  # Get expert trajectories
  fill_buffer_with_expert(env, replay_buffer)

  # Run policy evaluation
  for it in tqdm(range(num_iterations), smoothing=0):
    do_measure = not it % num_measure
    sample = replay_buffer.sample(mbsize)

    if do_measure:
      measure.pre(sample)
    #v_before = Qf(sample[0])

    opt.zero_grad()
    loss = loss_func(sample)
    loss = loss.mean()
    loss.backward()
    opt.step()

    #replay_buffer.update_values(sample, v_before, Qf(sample[0], theta_q))
    if do_measure:
      measure.post()

    if it and clone_interval and it % clone_interval == 0:
      if ARGS.loss_func in ['qlearn']:  # frozen-target update; 'ddqn' uses the soft update below
        Qf_target = copy.deepcopy(Qf)
    if ARGS.loss_func in ['ddqn']:
      for target_param, param in zip(Qf_target.parameters(), Qf.parameters()):
        target_param.data.copy_(tau * param + (1 - tau) * target_param)

    if it == num_iterations - 1:  # only save the final parameters (periodic saving is disabled)
      ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(Qf.parameters())}
      ps.update({"step": it})
      results["parameters"].append(ps)

  with open(f'results/pol_eval_{ARGS.run}.pkl', 'wb') as f:
    pickle.dump(results, f)
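
The td, sarsa and mc losses above call an sl1 helper that is not defined in this excerpt. A later example in this listing defines it as an element-wise smooth-L1; it is reproduced here as a self-contained sketch so the loss terms above can be read on their own.

import torch

def sl1(a, b):
    # Element-wise smooth-L1 used by the td/sarsa/mc losses above:
    # quadratic where |a - b| <= 1, linear where |a - b| > 1.
    d = a - b
    u = abs(d)
    s = d ** 2
    m = (u < s).float()  # 1 where the error is larger than 1
    return u * m + s * (1 - m)

# e.g. sl1(torch.tensor([0.5, 3.0]), torch.zeros(2)) -> tensor([0.25, 3.00])
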
def main():
    device = torch.device(ARGS.device)
    nn.set_device(device)
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }

    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
    }

    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = ARGS.mbsize
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")

    clone_interval = ARGS.clone_interval
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)
    Lambda = ARGS.Lambda

    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = ARGS.buffer_size

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model

    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q,
                                   lr,
                                   momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda x: sl1(
        x.r +
        (1 - x.t.float()) * gamma * Qf(x.sp, past_theta[0]).max(1)[0].detach(),
        Qf(x.s, theta_q)[np.arange(len(x.a)), x.a.long()],
    )

    tdQL = lambda x: sl1(
        Qf(x.s, theta_q)[np.arange(len(x.a)), x.a.long()], x.lg)

    mc = lambda x: sl1(Qf(x.s, theta_q).max(1)[0], x.g)

    past_theta = [clone_theta_q()]

    replay_buffer = ReplayBufferV2(seed, buffer_size, lambda s: Qf(s, theta_q),
                                   lambda s: Qf(s, past_theta[0]).max(1)[0],
                                   Lambda, gamma)

    total_reward = 0
    last_end = 0
    num_fill = buffer_size // 2
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]

    measure = Measures(theta_q, {
        "td": td,
        "tdQL": tdQL,
        "mc": mc,
    }, replay_buffer, results["measure"], 32)

    obs = env.reset()
    for it in range(num_fill):
        action = rng.randint(0, num_act)
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done)

        obs = obsp
        if done:
            print(it)
            obs = env.reset()

    for it in range(num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60
        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it < num_exploration_steps:
            epsilon = 1 - (it / num_exploration_steps) * (1 - final_epsilon)
        else:
            epsilon = final_epsilon

        if rng.uniform(0, 1) < epsilon:
            action = rng.randint(0, num_act)
        else:
            action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()

        obsp, r, done, info = env.step(action)
        total_reward += r
        replay_buffer.add(obs, action, r, done)

        obs = obsp
        if done:
            obs = env.reset()
            results["episode"].append({
                "end": it,
                "start": last_end,
                "total_reward": total_reward
            })
            last_end = it
            last_rewards = [total_reward] + last_rewards[:10]
            total_reward = 0

        sample = replay_buffer.sample(mbsize)
        with torch.no_grad():
            v_before = Qf(sample.s, theta_q).detach()

        loss = tdQL(sample)

        if do_measure:
            tm0 = time.time()
            measure.pre(sample)
            tm1 = time.time()
        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        with torch.no_grad():
            v_after = Qf(sample.s, theta_q).detach()
        replay_buffer.compute_value_difference(sample, v_before, v_after)

        if do_measure:
            tm2 = time.time()
            measure.post()
            tm3 = time.time()
        t4 = time.time()
        if it and clone_interval and it % clone_interval == 0:
            past_theta = [clone_theta_q()]  #+ past_theta[:max_clones - 1]
            replay_buffer.recompute_lambda_returns()

        #exp_results["loss"].append(loss.item())
        ema_loss = 0.999 * ema_loss + 0.001 * loss.item()
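
The exploration in the loop above anneals epsilon linearly from 1 down to final_epsilon over num_exploration_steps and then holds it constant. A standalone sketch of that schedule, using the defaults from this example (epsilon_schedule is a hypothetical name):

def epsilon_schedule(step, num_exploration_steps=500_000, final_epsilon=0.05):
    # Linear annealing from 1.0 down to final_epsilon, then constant.
    if step < num_exploration_steps:
        return 1 - (step / num_exploration_steps) * (1 - final_epsilon)
    return final_epsilon

# epsilon_schedule(0) -> 1.0, epsilon_schedule(250_000) -> 0.525, epsilon_schedule(1_000_000) -> 0.05
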
Example #3
        self.near_td_after = self._td(*self.near_samples)
        self.near_td_gain = ((self.near_td_before -
                              self.near_td_after).cpu().data.numpy())
        self.near_td_gain_avg = self.near_td_gain.mean().item()

        self.sample_td_after = self._td(*e["sample"])
        self.sample_td_gain = ((self.sample_td_before -
                                self.sample_td_after).cpu().data.numpy())
        self.sample_td_gain_avg = self.sample_td_gain.mean().item()

    def log(self, rs):
        e = inspect.currentframe().f_back.f_locals  # Don't do this at home
        rs.append({
            "td_error": self.sample_td_before.cpu().data.numpy(),
            "other_td_gain": self.other_td_gain,
            "other_td_gain_avg": self.other_td_gain_avg,
            "near_td_gain": self.near_td_gain,
            "near_td_gain_avg": self.near_td_gain_avg,
            "sample_td_gain": self.sample_td_gain,
            "sample_td_gain_avg": self.sample_td_gain_avg,
            "idx": e["idx"].cpu().data.numpy(),
            "step": e["it"],
        })


if __name__ == "__main__":
    ARGS = parser.parse_args()
    device = torch.device(ARGS.device)
    nn.set_device(device)
    main()
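
The Measures fragment above follows a pre/post pattern: record a per-sample error before an optimizer step, recompute it after the step, and log the difference as a gain. A toy, self-contained illustration of that pattern (the linear model and data below are hypothetical, not part of the original code):

import torch

theta = torch.nn.Parameter(torch.zeros(3))
x, y = torch.randn(8, 3), torch.randn(8)
per_sample_error = lambda: (x @ theta - y).pow(2)

opt = torch.optim.SGD([theta], lr=0.1)
before = per_sample_error().detach()       # analogous to measure.pre(sample)
per_sample_error().mean().backward()
opt.step(); opt.zero_grad()
after = per_sample_error().detach()        # analogous to measure.post()
gain = (before - after).numpy()            # positive where the step reduced that sample's error
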
def main(args):
  device = torch.device(args.device)
  mm.set_device(device)
  results_conn = lmdb.open(f'{args.save_path}/run_{args.run}', map_size=int(16 * 2 ** 30))
  params_conn = lmdb.open(f'{args.save_path}/run_{args.run}/params', map_size=int(16 * 2 ** 30))
  with results_conn.begin(write=True) as txn:
    txn.put(b'args', packobj(args))

  print(args)
  seed = args.run + 1_642_559  # A large prime number
  torch.manual_seed(seed)
  np.random.seed(seed)
  rng = np.random.RandomState(seed)
  env = AtariEnv(args.env_name)
  mbsize = args.mbsize
  nhid = args.nhid
  gamma = 0.99
  #env.r_gamma = gamma
  checkpoint_freq = args.checkpoint_freq
  test_freq = args.test_freq
  target_tau = args.target_tau
  target_clone_interval = args.target_clone_interval
  target_type = args.target_type
  num_iterations = args.num_iterations
  num_Q_outputs = env.num_actions

  # Model
  act = torch.nn.LeakyReLU()
  if args.body_type == 'normal':
    body = torch.nn.Sequential(torch.nn.Conv2d(4, nhid, 8, stride=4, padding=4), act,
                               torch.nn.Conv2d(nhid, nhid*2, 4, stride=2, padding=2), act,
                               torch.nn.Conv2d(nhid*2, nhid*2, 3, padding=1), act,
                               torch.nn.Flatten(),
                               torch.nn.Linear(nhid*2*12*12, nhid*16), act)
  elif args.body_type == 'tiny':
    body = torch.nn.Sequential(torch.nn.Conv2d(4, nhid, 3, stride=2, padding=1), act, # 42
                               torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act, # 21
                               torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act, # 11
                               torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1), act, # 6
                               torch.nn.Conv2d(nhid, num_Q_outputs, 6), # 1
                               torch.nn.Flatten())

  if args.head_type == 'normal':
    head = torch.nn.Sequential(torch.nn.Linear(nhid*16, num_Q_outputs))
  elif args.head_type == 'slim': # Slim head so we can do a block-diagonal zeta
    head = torch.nn.Sequential(torch.nn.Linear(nhid*16, nhid), act,
                               torch.nn.Linear(nhid, nhid), act,
                               torch.nn.Linear(nhid, num_Q_outputs))
  elif args.head_type == 'slim2':
    head = torch.nn.Sequential(torch.nn.Linear(nhid*16, nhid * 2), act,
                               torch.nn.Linear(nhid * 2, num_Q_outputs))
  elif args.head_type == 'none':
    head = torch.nn.Sequential()

  Qf = torch.nn.Sequential(body, head)

  Qf.to(device)
  Qf.apply(init_weights)
  if args.target_type == 'none':
    Qf_target = Qf
  else:
    Qf_target = copy.deepcopy(Qf)

  opt = make_opt(args, Qf.parameters())
  do_set_predictions = args.opt == 'msgd_corr'

  # Replay Buffer
  replay_buffer = ReplayBufferV2(seed, args.buffer_size)
  test_replay_buffer = ReplayBufferV2(seed, 10000)

  # Get expert trajectories
  expert = load_expert(args.env_name, env)
  fill_buffer_with_expert(expert, env, replay_buffer)
  fill_buffer_with_expert(expert, env, test_replay_buffer)

  ar = lambda x: torch.arange(x.shape[0], device=x.device)

  losses = []
  num_iterations = 1 + num_iterations
  ignore_vprime = bool(args.opt_ignore_vprime)
  # Run policy evaluation
  for it in (tqdm(range(num_iterations), smoothing=0) if args.progress else range(num_iterations)):
    sample = replay_buffer.sample(mbsize)

    q = Qf(sample.s)
    v = q[ar(q), sample.a.long()]
    vp = Qf_target(sample.sp)[ar(q), sample.ap.long()]  # SARSA update
    gvp = (1 - sample.t.float()) * gamma * vp
    loss = (v - (sample.r + gvp.detach())).pow(2)
    _loss = loss
    if do_set_predictions:
      opt.set_predictions(v.mean(), gvp.mean() if not ignore_vprime else None)

    loss = loss.mean()
    loss.backward(retain_graph=True)
    opt.step()
    opt.zero_grad()

    losses.append(loss.item())

    if target_type == 'frozen' and it % target_clone_interval == 0:
      Qf_target = copy.deepcopy(Qf)
    elif target_type == 'moving':
      for target_param, param in zip(Qf_target.parameters(), Qf.parameters()):
        target_param.data.mul_(1-target_tau).add_(param, alpha=target_tau)

    if it % checkpoint_freq == 0 and args.save_parameters:
      with params_conn.begin(write=True) as txn:
        txn.put(f'parameters_{it}'.encode(), packobj(Qf.state_dict()))

    if it % test_freq == 0:
      expert_q_loss = 0
      expert_v_loss = 0
      mc_loss = 0
      n = 0
      with torch.no_grad():
        for sample in test_replay_buffer.iterate(512):
          n += sample.s.shape[0]
          q = Qf(sample.s)[ar(sample.a), sample.a.long()]
          mc_loss += (q - sample.g).pow(2).sum().item()
      print(q.shape, sample.g.shape)
      with results_conn.begin(write=True) as txn:
        txn.put(f'expert-loss_{it}'.encode(), packobj((expert_q_loss/n,
                                                       expert_v_loss/n,
                                                       mc_loss/n)))
        if it > 0:
          txn.put(f'train-loss_{it}'.encode(), packobj(losses))
        print(it, np.mean(losses), (expert_q_loss/n, expert_v_loss/n, mc_loss/n))
        losses = []


    if np.isnan(loss.item()):
      print("Learning has diverged, nan loss")
      with results_conn.begin(write=True) as txn:
        txn.put(b'diverged', b'True')
      break
  print("Done.")
Example #5
def main(argv):
    results = {
        "episode": [],
        "measure": [],
        "parameters": [],
    }
    device = torch.device(ARGS.device)
    nn.set_device(device)

    hps = {
        "opt": ARGS.opt,
        "env_name": ARGS.env_name,
        "lr": ARGS.learning_rate,
        "weight_decay": ARGS.weight_decay,
        "run": ARGS.run,
        "mbsize": ARGS.mbsize,
    }
    start_step = 0
    nhid = hps.get("nhid", 32)
    gamma = hps.get("gamma", 0.99)
    mbsize = hps.get("mbsize", 32)
    weight_decay = hps.get("weight_decay", 0)
    sample_near = hps.get("sample_near", "both")
    slice_size = hps.get("slice_size", 0)
    env_name = hps.get("env_name", "ms_pacman")

    clone_interval = hps.get("clone_interval", 10_000)
    reset_on_clone = hps.get("reset_on_clone", False)
    reset_opt_on_clone = hps.get("reset_opt_on_clone", False)
    max_clones = hps.get("max_clones", 2)
    target = hps.get("target", "last")  # self, last, clones
    replay_type = hps.get("replay_type", "normal")  # normal, prioritized
    final_epsilon = hps.get("final_epsilon", 0.05)
    num_exploration_steps = hps.get("num_exploration_steps", 500_000)

    lr = hps.get("lr", 1e-4)
    num_iterations = hps.get("num_iterations", 10_000_000)
    buffer_size = hps.get("buffer_size", ARGS.buffer_size)
    to_test_prob = ARGS.to_test_prob

    seed = hps.get("run", 0) + 1_642_559  # A large prime number
    hps["_seed"] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    env = AtariEnv(env_name)
    num_act = env.num_actions

    # Define model

    _Qarch, theta_q, Qf, _Qsemi = nn.build(
        nn.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        nn.conv2d(nhid, nhid * 2, 4, stride=2),
        nn.conv2d(nhid * 2, nhid * 2, 3),
        nn.flatten(),
        nn.hidden(nhid * 2 * 12 * 12, nhid * 16),
        nn.linear(nhid * 16, num_act),
    )

    def make_opt():
        if hps.get("opt", "sgd") == "sgd":
            return torch.optim.SGD(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "msgd":
            return torch.optim.SGD(theta_q,
                                   lr,
                                   momentum=hps.get("beta", 0.99),
                                   weight_decay=weight_decay)
        elif hps["opt"] == "rmsprop":
            return torch.optim.RMSprop(theta_q, lr, weight_decay=weight_decay)
        elif hps["opt"] == "adam":
            return torch.optim.Adam(theta_q, lr, weight_decay=weight_decay)
        else:
            raise ValueError(hps["opt"])

    opt = make_opt()
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]

    def copy_theta_q_to_target():
        for i in range(len(theta_q)):
            frozen_theta_q[i] = theta_q[i].detach().clone()

    # Define loss
    def sl1(a, b):
        d = a - b
        u = abs(d)
        s = d**2
        m = (u < s).float()
        return u * m + s * (1 - m)

    td = lambda s, a, r, sp, t, w, tw=theta_q: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
        Qf(s, w)[np.arange(len(a)), a.long()],
    )

    obs = env.reset()

    replay_buffer = PrioritizedExperienceReplay(seed,
                                                buffer_size,
                                                near_strategy=sample_near)
    test_set = {}

    last_lock_refresh = time.time()
    total_reward = 0
    last_end = 0
    num_fill = min(200000, replay_buffer.size)
    num_measure = 500
    _t0 = t0 = t1 = t2 = t3 = t4 = time.time()
    tm0 = tm1 = tm2 = tm3 = time.time()
    ema_loss = 0
    last_rewards = [0]

    measure = Measures()
    print("Filling buffer")

    if start_step < num_exploration_steps:
        epsilon = 1 - (start_step / num_exploration_steps) * (1 -
                                                              final_epsilon)
    else:
        epsilon = final_epsilon

    for it in range(num_fill):
        if start_step == 0:
            action = rng.randint(0, num_act)
        else:
            if rng.uniform(0, 1) < epsilon:
                action = rng.randint(0, num_act)
            else:
                action = Qf(tf(obs / 255.0).unsqueeze(0)).argmax().item()
        obsram = env.getRAM().tobytes()  # bytes key for the state (tostring() is deprecated)
        if obsram not in test_set:
            test_set[obsram] = float(rng.uniform(0, 1) < to_test_prob)
        obsp, r, done, info = env.step(action)
        replay_buffer.add(obs, action, r, done, env.enumber % 2)
        replay_buffer.set_last_priority(1 - test_set[obsram])

        obs = obsp
        if done:
            replay_buffer.set_last_priority(1 - test_set[obsram])
            obs = env.reset()

    past_theta = [clone_theta_q()]

    for it in range(start_step, num_iterations):
        do_measure = not it % num_measure
        eta = (time.time() - _t0) / (it + 1) * (num_iterations - it) / 60
        if it and it % 100_000 == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

        if it % 10_000 == 0:
            print(
                it,
                f"{(t1 - t0)*1000:.2f}, {(t2 - t1)*1000:.2f}, {(t3 - t2)*1000:.2f}, {(t4 - t3)*1000:.2f},",
                f"{(tm1 - tm0)*1000:.2f}, {(tm3 - tm2)*1000:.2f},",
                f"{int(eta//60):2d}h{int(eta%60):02d}m left",
                f":: {ema_loss:.5f}, last 10 rewards: {np.mean(last_rewards):.2f}",
                #" " * 20,
                #end="\r",
            )
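
The fill loop above keys each state by its RAM bytes, assigns it to a held-out test set with probability to_test_prob, and gives test states priority 0 so the prioritized buffer never samples them for training. A compact sketch of that bookkeeping (held_out and priority_for are hypothetical names):

held_out = {}  # RAM bytes -> True if the state belongs to the held-out test set

def priority_for(ram_bytes, rng, to_test_prob):
    # First time a state is seen, flip a coin; test states get priority 0 so the
    # prioritized buffer never trains on them.
    if ram_bytes not in held_out:
        held_out[ram_bytes] = bool(rng.uniform(0, 1) < to_test_prob)
    return 0.0 if held_out[ram_bytes] else 1.0
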
def main():
  gamma = 0.99
  hps = pickle.load(open(ARGS.checkpoint, 'rb'))['hps']
  env_name = hps["env_name"]
  if 'Lambda' in hps:
    Lambda = hps['Lambda']
  else:
    Lambda = 0

  device = torch.device(ARGS.device)
  nn.set_device(device)
  replay_buffer = ReplayBuffer(ARGS.run, ARGS.buffer_size)
  Qf, theta_q = fill_buffer_with_expert(replay_buffer, env_name)
  for p in theta_q:
    p.requires_grad = True
  if Lambda > 0:
    replay_buffer.compute_episode_boundaries()
    replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda, gamma)

  td__ = lambda s, a, r, sp, t, idx, w, tw: sl1(
      r + (1 - t.float()) * gamma * Qf(sp, tw).max(1)[0].detach(),
      Qf(s, w)[np.arange(len(a)), a.long()],
  )

  td = lambda s, a, r, sp, t, idx, w, tw: Qf(s, w).max(1)[0]

  tdL = lambda s, a, r, sp, t, idx, w, tw: sl1(
      Qf(s, w)[:, 0], replay_buffer.LG[idx])

  loss_func = {
      'td': td, 'tdL': tdL}[ARGS.loss_func]

  opt = torch.optim.SGD(theta_q, 1)

  def grad_sim(inp, grad):
    dot = sum([(p.grad * gp).sum() for p, gp in zip(inp, grad)])
    nA = torch.sqrt(sum([(p.grad**2).sum() for p, gp in zip(inp, grad)]))
    nB = torch.sqrt(sum([(gp**2).sum() for p, gp in zip(inp, grad)]))
    return (dot / (nA * nB)).item()

  relevant_features = np.int32(
      sorted(list(atari_dict[env_name.replace("_", "")].values())))
  sims = []
  ram_sims = []
  for i in range(2000):
    sim = []
    *sample, idx = replay_buffer.sample(1)
    loss = loss_func(*sample, idx, theta_q, theta_q).mean()
    loss.backward()
    g0 = [p.grad + 0 for p in theta_q]
    for j in range(-30, 31):
      opt.zero_grad()
      loss = loss_func(*replay_buffer.get(idx + j), theta_q, theta_q).mean()
      loss.backward()
      sim.append(grad_sim(theta_q, g0))
    sims.append(np.float32(sim))
    for j in range(200):
      opt.zero_grad()
      *sample_j, idx_j = replay_buffer.sample(1)
      loss = loss_func(*sample_j, idx_j, theta_q, theta_q).mean()
      loss.backward()
      ram_sims.append(
          (grad_sim(theta_q, g0),
           abs(replay_buffer.ram[idx[0]][relevant_features].float() -
               replay_buffer.ram[idx_j[0]][relevant_features].float()).mean()))
    opt.zero_grad()
  ram_sims = np.float32(
      ram_sims)  #np.histogram(np.float32(ram_sim), 100, (-1, 1))

  # Compute "True" gradient
  grads = [i.detach() * 0 for i in theta_q]
  N = 0
  for samples in replay_buffer.in_order_iterate(ARGS.mbsize * 8):
    loss = loss_func(*samples, theta_q, theta_q).mean()
    loss.backward()
    N += samples[0].shape[0]
    for p, gp in zip(theta_q, grads):
      gp.data.add_(p.grad)
    opt.zero_grad()

  dots = []
  i = 0
  for sample in replay_buffer.in_order_iterate(1):
    loss = loss_func(*sample, theta_q, theta_q).mean()
    loss.backward()
    dots.append(grad_sim(theta_q, grads))
    opt.zero_grad()
    i += 1
  histo = np.histogram(dots, 100, (-1, 1))

  results = {
      "grads": [i.cpu().data.numpy() for i in grads],
      "sims": np.float32(sims),
      "histo": histo,
      "ram_sims": ram_sims,
  }

  path = f'results/grads_{ARGS.checkpoint}.pkl'
  with open(path, "wb") as f:
    pickle.dump(results, f)
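
grad_sim above is a cosine similarity between the parameters' current .grad and a stored reference gradient (the unused gp and p loop variables in its norm terms are leftovers). An equivalent standalone form that takes two explicit gradient lists might look like this (grad_cosine is a hypothetical name):

import torch

def grad_cosine(grads_a, grads_b):
    # Cosine similarity between two gradients, each given as a list of tensors.
    dot = sum((ga * gb).sum() for ga, gb in zip(grads_a, grads_b))
    na = torch.sqrt(sum((ga ** 2).sum() for ga in grads_a))
    nb = torch.sqrt(sum((gb ** 2).sum() for gb in grads_b))
    return (dot / (na * nb)).item()
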
Example #7
def main():
    device = torch.device(ARGS.device)
    mm.set_device(device)
    results = {
        "measure": [],
        "parameters": [],
    }

    seed = ARGS.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)
    env = AtariEnv(ARGS.env_name)
    mbsize = ARGS.mbsize
    Lambda = ARGS.Lambda
    nhid = 32
    num_measure = 1000
    gamma = 0.99
    clone_interval = ARGS.clone_interval
    num_iterations = ARGS.num_iterations

    num_Q_outputs = 1
    # Model
    _Qarch, theta_q, Qf, _Qsemi = mm.build(
        mm.conv2d(4, nhid, 8, stride=4),  # Input is 84x84
        mm.conv2d(nhid, nhid * 2, 4, stride=2),
        mm.conv2d(nhid * 2, nhid * 2, 3),
        mm.flatten(),
        mm.hidden(nhid * 2 * 12 * 12, nhid * 16),
        mm.linear(nhid * 16, num_Q_outputs),
    )
    clone_theta_q = lambda: [i.detach().clone() for i in theta_q]
    theta_target = clone_theta_q()
    opt = make_opt(ARGS.opt, theta_q, ARGS.learning_rate, ARGS.weight_decay)

    # Replay Buffer
    replay_buffer = ReplayBuffer(seed, ARGS.buffer_size)

    # Losses
    td = lambda s, a, r, sp, t, idx, w, tw: sl1(
        r + (1 - t.float()) * gamma * Qf(sp, tw)[:, 0].detach(),
        Qf(s, w)[:, 0],
    )

    tdL = lambda s, a, r, sp, t, idx, w, tw: sl1(
        Qf(s, w)[:, 0], replay_buffer.LG[idx])

    mc = lambda s, a, r, sp, t, idx, w, tw: sl1(
        Qf(s, w)[:, 0], replay_buffer.g[idx])

    # Define metrics
    measure = Measures(
        theta_q, {
            "td": lambda x, w: td(*x, w, theta_target),
            "tdL": lambda x, w: tdL(*x, w, theta_target),
            "mc": lambda x, w: mc(*x, w, theta_target),
        }, replay_buffer, results["measure"], 32)

    # Get expert trajectories
    rand_classes = fill_buffer_with_expert(env, replay_buffer)
    # Compute initial values
    replay_buffer.compute_values(lambda s: Qf(s, theta_q), num_Q_outputs)
    replay_buffer.compute_returns(gamma)
    replay_buffer.compute_reward_distances()
    replay_buffer.compute_episode_boundaries()
    replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q), Lambda,
                                         gamma)

    # Run policy evaluation
    for it in range(num_iterations):
        do_measure = not it % num_measure
        sample = replay_buffer.sample(mbsize)

        if do_measure:
            measure.pre(sample)
        replay_buffer.compute_value_difference(sample, Qf(sample[0], theta_q))

        loss = tdL(*sample, theta_q, theta_target)
        loss = loss.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        replay_buffer.update_values(sample, Qf(sample[0], theta_q))
        if do_measure:
            measure.post()

        if it and clone_interval and it % clone_interval == 0:
            theta_target = clone_theta_q()
            replay_buffer.compute_lambda_returns(lambda s: Qf(s, theta_q),
                                                 Lambda, gamma)

        if it and it % clone_interval == 0 or it == num_iterations - 1:
            ps = {str(i): p.data.cpu().numpy() for i, p in enumerate(theta_q)}
            ps.update({"step": it})
            results["parameters"].append(ps)

    with open(f'results/td_lambda_{ARGS.run}.pkl', 'wb') as f:
        pickle.dump(results, f)
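
replay_buffer.compute_lambda_returns above produces the LG targets that tdL regresses onto; its implementation is not shown in this excerpt. The standard backward recursion it presumably implements is G_t = r_t + gamma * ((1 - Lambda) * V(s_{t+1}) + Lambda * G_{t+1}), with the bootstrap dropped on terminal transitions. A sketch over a single stored trajectory (lambda_returns is a hypothetical name; episodes are assumed to end with done=True):

import numpy as np

def lambda_returns(rewards, next_values, dones, Lambda, gamma):
    # Backward recursion for TD(lambda) targets over one trajectory.
    # next_values[t] is the current estimate of V(s_{t+1}).
    G = np.zeros(len(rewards), dtype=np.float32)
    g_next = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            G[t] = rewards[t]  # terminal transition: no bootstrap
        else:
            G[t] = rewards[t] + gamma * ((1 - Lambda) * next_values[t] + Lambda * g_next)
        g_next = G[t]
    return G
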
Example #8
def main(args):
    device = torch.device(args.device)
    mm.set_device(device)
    results_conn = lmdb.open(f'{args.save_path}/run_{args.run}',
                             map_size=int(16 * 2**30))
    params_conn = lmdb.open(f'{args.save_path}/run_{args.run}/params',
                            map_size=int(16 * 2**30))
    with results_conn.begin(write=True) as txn:
        txn.put(b'args', packobj(args))

    print(args)
    seed = args.run + 1_642_559  # A large prime number
    torch.manual_seed(seed)
    np.random.seed(seed)
    rng = np.random.RandomState(seed)
    env = AtariEnv(args.env_name)
    mbsize = args.mbsize
    nhid = args.nhid
    gamma = 0.99
    env.r_gamma = gamma
    checkpoint_freq = args.checkpoint_freq
    test_freq = args.test_freq
    target_tau = args.target_tau
    target_clone_interval = args.target_clone_interval
    target_type = args.target_type
    num_iterations = args.num_iterations
    num_Q_outputs = env.num_actions
    td_steps = args.td_n_step
    num_env_steps = args.num_env_steps
    measure_drift = args.measure_drift
    # Model

    act = {
        'lrelu': torch.nn.LeakyReLU(),
        'tanh': torch.nn.Tanh(),
        'elu': torch.nn.ELU(),
    }[args.act]
    # Body
    if args.body_type == 'normal':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, nhid, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid, nhid * 2, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid * 2, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(), torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16),
            act)
    elif args.body_type == 'slim_bot2':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, nhid // 2, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid // 2, nhid, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(), torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16),
            act)
    elif args.body_type == 'added_bot3':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, 4, 3, padding=1), act,
            torch.nn.Conv2d(4, 4, 3, padding=1), act,
            torch.nn.Conv2d(4, 4, 3, padding=1), act,
            torch.nn.Conv2d(4, nhid, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid, nhid * 2, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid * 2, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(), torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16),
            act)
    elif args.body_type == 'added_bot3A':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, 8, 3, padding=1), act,
            torch.nn.Conv2d(8, 8, 3, padding=1), act,
            torch.nn.Conv2d(8, 8, 3, padding=1), act,
            torch.nn.Conv2d(8, nhid, 8, stride=4, padding=4), act,
            torch.nn.Conv2d(nhid, nhid * 2, 4, stride=2, padding=2), act,
            torch.nn.Conv2d(nhid * 2, nhid * 2, 3, padding=1), act,
            torch.nn.Flatten(), torch.nn.Linear(nhid * 2 * 12 * 12, nhid * 16),
            act)
    elif args.body_type == 'tiny':
        body = torch.nn.Sequential(
            torch.nn.Conv2d(4, nhid, 3, stride=2, padding=1),
            act,  # 42
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1),
            act,  # 21
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1),
            act,  # 11
            torch.nn.Conv2d(nhid, nhid, 3, stride=2, padding=1),
            act,  # 6
            torch.nn.Conv2d(nhid, num_Q_outputs, 6),  # 1
            torch.nn.Flatten())
    # Head
    if args.head_type == 'normal':
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, num_Q_outputs))
    elif args.head_type == 'slim':  # Slim head so we can do a block-diagonal zeta
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, nhid), act,
                                   torch.nn.Linear(nhid, nhid), act,
                                   torch.nn.Linear(nhid, num_Q_outputs))
    elif args.head_type == 'slim2':
        head = torch.nn.Sequential(torch.nn.Linear(nhid * 16, nhid * 2), act,
                                   torch.nn.Linear(nhid * 2, num_Q_outputs))
    elif args.head_type == 'none':
        head = torch.nn.Sequential()

    Qf = torch.nn.Sequential(body, head)

    Qf.to(device)
    Qf.apply(init_weights)
    if args.target_type == 'none':
        Qf_target = Qf
    else:
        Qf_target = copy.deepcopy(Qf)

    opt = make_opt(args, Qf.parameters())
    opt.epsilon = 1e-2
    do_specific_backward = args.opt == 'msgd_corr'

    # Replay Buffer
    replay_buffer = ReplayBufferV2(seed, args.buffer_size)

    ar = lambda x: torch.arange(x.shape[0], device=x.device)

    losses = []
    num_iterations = 1 + num_iterations
    ignore_vprime = bool(args.opt_ignore_vprime)
    total_reward = 0
    last_end = 0
    last_rewards = []
    num_exploration_steps = 50_000
    final_epsilon = 0.05
    recent_states = []
    recent_values = []

    obs = env.reset()
    drift = 0

    # Run policy evaluation
    _prof = (tqdm(range(num_iterations), smoothing=0.001)
             if args.progress else range(num_iterations))
    for it in _prof:

        for eit in range(num_env_steps):
            if it < num_exploration_steps:
                epsilon = 1 - (it / num_exploration_steps) * (1 -
                                                              final_epsilon)
            else:
                epsilon = final_epsilon

            if rng.uniform(0, 1) < epsilon:
                action = rng.randint(0, num_Q_outputs)
            else:
                with torch.no_grad():
                    action = Qf(
                        torch.tensor(obs / 255.0, device=device).unsqueeze(
                            0).float()).argmax().item()

            obsp, r, done, info = env.step(action)
            total_reward += r
            replay_buffer.add(obs, action, r, done, env.enumber % 2)

            obs = obsp
            if done:
                obs = env.reset()
                with results_conn.begin(write=True) as txn:
                    txn.put(
                        f'episode_{env.enumber-1}'.encode(),
                        packobj({
                            "end": it,
                            "start": last_end,
                            "total_reward": total_reward
                        }))
                last_end = it
                last_rewards = [total_reward] + last_rewards[:10]
                if args.progress:
                    _prof.set_description_str(
                        f'reward {int(100*total_reward)}, '
                        f'{int(100*np.mean(last_rewards))}, '
                        f'{drift:.5f}')
                total_reward = 0

        if replay_buffer.current_size < 5000:
            continue

        sample = replay_buffer.sample(mbsize, n_step=td_steps)

        if Qf_target is Qf:
            q = Qf(torch.cat([sample.s, sample.sp], 0))
            v = q[ar(sample.s), sample.a.long()]
            vp = q[ar(sample.sp) + sample.s.shape[0], sample.ap.long()]
        else:
            q = Qf(sample.s)
            v = q[ar(q), sample.a.long()]
            vp = Qf_target(sample.sp)[ar(q), sample.ap.long()]

        gamma_mask = (1 - sample.t.float()) * (gamma**td_steps)
        target = sample.r + gamma_mask * vp
        loss = (v - target.detach()).pow(2)
        if do_specific_backward:
            opt.backward_and_step(v, vp, v - target, gamma_mask)
            #opt.set_predictions(v.mean(), gvp.mean() if not ignore_vprime else None)
        else:
            loss = loss.mean()
            loss.backward()
            opt.step()
            opt.zero_grad()

        losses.append(loss.item())

        if measure_drift:
            recent_states.append((sample.sp, sample.ap.long()))
            recent_values.append(vp.detach())

        if len(recent_states) >= 32:
            rs = torch.cat([i[0] for i in recent_states])
            ra = torch.cat([i[1] for i in recent_states])
            rvp = torch.cat(recent_values)
            with torch.no_grad():
                nvp = Qf_target(rs)[ar(ra), ra]

            drift = abs(rvp - nvp).mean().item()
            with results_conn.begin(write=True) as txn:
                txn.put(f'value_drift_{it}'.encode(), packobj(drift))
            recent_states = []
            recent_values = []

        if target_type == 'frozen' and it % target_clone_interval == 0:
            Qf_target = copy.deepcopy(Qf)
        elif target_type == 'moving':
            for target_param, param in zip(Qf_target.parameters(),
                                           Qf.parameters()):
                target_param.data.mul_(1 - target_tau).add_(param,
                                                            alpha=target_tau)

        if it % checkpoint_freq == 0 and args.save_parameters:
            with params_conn.begin(write=True) as txn:
                txn.put(b'parameters_last', packobj(Qf.state_dict()))

        if it % test_freq == 0:
            mc_loss = 0
            n = 0
            with torch.no_grad():
                #print('|W|^2 =', sum([i.pow(2).sum() for i in Qf.parameters()]))
                #for sample in replay_buffer.iterate(512):
                while True:
                    sample = replay_buffer.sample(512)
                    n += sample.s.shape[0]
                    q = Qf(sample.s).max(1).values
                    mc_loss += (q - sample.g).pow(2).sum().item()
                    if n > 10000:
                        break
            with results_conn.begin(write=True) as txn:
                txn.put(f'mc-loss_{it}'.encode(), packobj((mc_loss / n, )))
                if it > 0:
                    txn.put(f'train-loss_{it}'.encode(), packobj(losses))
                losses = []

        if np.isnan(loss.item()):
            print("Learning has diverged, nan loss")
            with results_conn.begin(write=True) as txn:
                txn.put(b'diverged', b'True')
            break
    print("Done.")