def debug():
    """Sanity-check a continuous-state taxi policy on a short roll-out."""
    env = TaxiEnvironment(discrete_state=False)
    pi = load_taxi_policy_continuous("taxi_data/saved_policies/pi19.npy", env)
    tau_list = env.generate_roll_out(pi=pi, num_tau=1, tau_len=10)
    s, a, s_prime, r = tau_list[0]

    print("policy probs")
    print(pi.policy_probs)
    print("")

    print("policy probs on batch")
    policy_probs = pi(s)
    print(policy_probs)
    a_probs = torch.gather(policy_probs, 1, a.view(-1, 1)).view(-1)
    print("probs of selected actions:", a_probs)
    print("mean prob:", a_probs.mean())

def debug():
    """Train a q model on discrete taxi and report the squared error of its
    policy value estimate against an on-policy oracle."""
    # set up environment and policies
    env = TaxiEnvironment()
    gamma = 0.98
    alpha = 0.6
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_other = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_other, pi_1_weight=alpha)

    # set up logger
    init_state_dist_path = "taxi_data/init_state_dist.npy"
    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    logger = DiscreteQLogger(env=env, pi_e=pi_e, gamma=gamma, tensorboard=True,
                             save_model=True, init_state_sampler=init_state_sampler)

    # generate train and val data
    train_data = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=200000,
                                       burn_in=100000)
    val_data = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=200000,
                                     burn_in=100000)

    q = train_q_taxi(env, train_data, val_data, pi_e, pi_b,
                     init_state_sampler, logger, gamma)

    # calculate final performance
    policy_val_est = q_estimator(pi_e=pi_e, gamma=gamma, q=q,
                                 init_state_sampler=init_state_sampler)
    policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
                                           num_tau=1, tau_len=1000000)
    squared_error = (policy_val_est - policy_val_oracle) ** 2
    print("Test policy val squared error:", squared_error)

def debug():
    """Fit the w (density ratio) network on discrete taxi with the ERM + LBFGS trainer."""
    env = TaxiEnvironment()
    gamma = 0.98
    alpha = 0.6
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)

    # set up logger
    oracle_tau_len = 1000000
    logger = DiscreteWLogger(env=env, pi_e=pi_e, pi_b=pi_b, gamma=gamma,
                             tensorboard=False, save_model=False,
                             oracle_tau_len=oracle_tau_len)

    # generate train and val data
    tau_len = 200000
    burn_in = 100000
    train_data = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=tau_len,
                                       burn_in=burn_in)
    val_data = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=tau_len,
                                     burn_in=burn_in)

    # define networks and optimizers
    # (w_optimizer is defined here but not passed to the LBFGS trainer below)
    w = StateEmbeddingModel(num_s=env.num_s, num_out=1)
    w_lr = 1e-3
    w_optimizer = torch.optim.Adam(w.parameters(), lr=w_lr, betas=(0.5, 0.9))

    train_w_network_erm_lbfgs(train_data=train_data, pi_e=pi_e, pi_b=pi_b,
                              w=w, gamma=gamma, val_data=val_data, logger=logger)

def debug():
    """Inspect a tiny roll-out and the q-game objective on untrained q and f models."""
    env = TaxiEnvironment()
    gamma = 0.98
    alpha = 0.6
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_other = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_other, pi_1_weight=alpha)
    q = StateEmbeddingModel(num_s=env.num_s, num_out=env.num_a)
    f = StateEmbeddingModel(num_s=env.num_s, num_out=env.num_a)

    # generate a tiny roll-out for inspection
    tau_list = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=4)
    s, a, s_prime, r = tau_list[0]
    print("s:", s)
    print("a:", a)
    print("s_prime:", s_prime)
    print("r:", r)
    print("")

    q_obj, f_obj = q_game_objective(q, f, s, a, s_prime, r, pi_e, gamma)
    print("q_obj:", q_obj)
    print("f_obj:", f_obj)

def load_eval():
    """Load saved w and q checkpoints and evaluate the doubly robust (DRL) estimator."""
    env = TaxiEnvironment(discrete_state=True)
    gamma = 0.98
    alpha = 0.6
    # tau_lens = [50000, 100000, 200000, 400000]  # // 10000
    tau_len = 200000  # args.tau_len
    burn_in = 0  # 100000  # // 10000
    init_state_dist_path = "taxi_data/init_state_dist.npy"
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)
    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    policy_val_oracle = -0.74118131399
    print('policy_val_oracle', policy_val_oracle)

    print('args.seed', args.seed)
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    random.seed(args.seed)

    train_data = env.generate_roll_out(
        pi=pi_b, num_tau=1, tau_len=tau_len, burn_in=burn_in)
    train_data_loader = train_data.get_data_loader(1024)
    print('train_data generated')

    w_path = 'logs/2020-09-13T18:33:35.909858_w_2_200000/290_w_.pt'
    q_path = 'logs/2020-09-11T11:04:53.345691_q_2_200000/790_q_.pt'
    q_table = torch.zeros((env.num_s, env.num_a))
    q = QTableModel(q_table)
    q.model.load_state_dict(torch.load(q_path))
    w = StateEmbeddingModel(num_s=env.num_s, num_out=1)
    w.load_state_dict(torch.load(w_path))
    print('[ours] w_oracle and q loaded!')

    double_est = double_estimator(
        train_data_loader, pi_e, pi_b, w, q, gamma, init_state_sampler)
    print('[ours] drl_est', double_est)

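# For reference, the doubly robust estimate returned by double_estimator combines the
# q- and w-based pieces roughly as
#     v_DR ~ E_{s0~init, a0~pi_e}[ q(s0, a0) ]
#            + E_data[ w(s) * (pi_e(a|s) / pi_b(a|s))
#                      * (r + gamma * E_{a'~pi_e(.|s')}[ q(s', a') ] - q(s, a)) ],
# i.e. a direct q-based estimate plus a w-weighted temporal-difference correction.
# This is a paraphrase of the standard DRL form, not the exact implementation, and
# may differ from double_estimator in normalization constants.
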
def debug():
    env = TaxiEnvironment(discrete_state=True)
    gamma = 0.98
    alpha = 0.6
    # tau_lens = [50000, 100000, 200000, 400000]  # // 10000
    tau_len = args.tau_len
    burn_in = 0  # 100000  # // 10000
    init_state_dist_path = "taxi_data/init_state_dist.npy"
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)
    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    policy_val_oracle = -0.74118131399
    print('policy_val_oracle', policy_val_oracle)

    # for j in tau_lens:
    #     tau_len = j
    #     preds = []
    #     for i in range(100):
    #         print(j, i)
    for i in range(3):
        args.seed = i
        print('args.seed', args.seed)
        np.random.seed(args.seed)
        torch.random.manual_seed(args.seed)
        random.seed(args.seed)

        train_data = env.generate_roll_out(
            pi=pi_b, num_tau=1, tau_len=tau_len, burn_in=burn_in)
        train_data_loader = train_data.get_data_loader(1024)
        val_data = None
        print('train_data generated')

        now = datetime.datetime.now()
        if args.q:
            q_path = os.path.join('logs', '_'.join(
                [str(now.isoformat()), 'q', str(args.seed), str(args.tau_len)]))
            q_logger = SimplestQLogger(
                env, pi_e, gamma, init_state_sampler, True, True, q_path,
                policy_val_oracle)
            q = train_q_taxi(env, train_data, val_data, pi_e, pi_b,
                             init_state_sampler, q_logger, gamma,
                             ERM_epoch=args.q_ERM_epoch, epoch=args.q_epoch)
            q_est = q_estimator(pi_e, gamma, q, init_state_sampler)
            # SASR = train_data.restore_strcture(discrete=False)
            # dataloader doesn't preserve trajectory structure, but SASR does.
            # drl1 = double_importance_sampling_estimator(
            #     SASR, pi_b, pi_e, args.gamma, q, split_shape=[4, 1, 4, 1])
            print('[ours] q_est', q_est)
            with open(os.path.join(q_path, 'results.txt'), 'w') as f:
                f.write(str(q_est))
            print('q_est written in', q_path)

        if args.w:
            w_path = os.path.join('logs', '_'.join(
                [str(now.isoformat()), 'w', str(args.seed), str(args.tau_len)]))
            w_logger = SimplestWLogger(
                env, pi_e, pi_b, gamma, True, True, w_path, policy_val_oracle)
            w = train_w_taxi(env, train_data, val_data, pi_e, pi_b,
                             init_state_sampler, w_logger, gamma,
                             ERM_epoch=args.w_ERM_epoch, epoch=args.w_epoch)
            w_est = w_estimator(train_data_loader, pi_e, pi_b, w)
            print('[ours] w_est', w_est)
            with open(os.path.join(w_path, 'results.txt'), 'w') as f:
                f.write(str(w_est))
            print('w_est written in', w_path)

        if args.w and args.q:
            double_est = double_estimator(
                train_data_loader, pi_e, pi_b, w, q, gamma, init_state_sampler)
            print('[ours] drl_est', double_est)
            with open(os.path.join(q_path, 'results.txt'), 'a') as f:
                f.write(str(double_est))
            print('double_est written in', q_path)

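# NOTE: the debug() and load_eval() entry points above read hyperparameters from a
# module-level `args` object that is not defined in this section. The sketch below
# shows one way such a parser could look; the flag names mirror the attributes
# accessed above (tau_len, seed, gamma, q, w, q_ERM_epoch, q_epoch, w_ERM_epoch,
# w_epoch), but the defaults are illustrative assumptions, not values from the repo.
import argparse


def parse_args_sketch():
    parser = argparse.ArgumentParser(description="Taxi OPE debug runs (sketch)")
    parser.add_argument("--tau_len", type=int, default=200000)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--gamma", type=float, default=0.98)
    parser.add_argument("--q", action="store_true", help="train the q network")
    parser.add_argument("--w", action="store_true", help="train the w network")
    parser.add_argument("--q_ERM_epoch", type=int, default=50)
    parser.add_argument("--q_epoch", type=int, default=50)
    parser.add_argument("--w_ERM_epoch", type=int, default=50)
    parser.add_argument("--w_epoch", type=int, default=50)
    return parser.parse_args()

# Hypothetical usage:
#     args = parse_args_sketch()
#     debug()
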
def debug():
    """Run the tabular baselines (IS, stepwise IS, Masa's w/q and DRL combinations)
    over multiple data sizes and 100 seeds, and save predictions and MSEs."""
    env = TaxiEnvironment(discrete_state=True)
    gamma = 0.98
    alpha = 0.6
    tau_lens = [50000, 100000, 200000, 400000]  # // 10000
    burn_in = 0  # 100000  # // 10000
    init_state_dist_path = "taxi_data/init_state_dist.npy"
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)
    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    policy_val_oracle = -0.74118131399
    print('policy_val_oracle', policy_val_oracle)

    for j in tau_lens:
        tau_len = j
        preds = []
        for i in range(100):
            print(j, i)
            np.random.seed(i)
            torch.random.manual_seed(i)
            random.seed(i)
            train_data = env.generate_roll_out(pi=pi_b, num_tau=1,
                                               tau_len=tau_len, burn_in=burn_in)
            train_data_loader = train_data.get_data_loader(1024)
            SASR = train_data.restore_strcture(discrete=True)
            # print('finished data generation.')
            # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
            #                                        num_tau=1, tau_len=10000000)

            is_est = importance_sampling_estimator(SASR, pi_b, pi_e, gamma)
            is_ess = importance_sampling_estimator_stepwise(
                SASR, pi_b, pi_e, gamma)
            # print('is_est', is_est, 'is_ess', is_ess)

            # Masa's w method, only for discrete
            den_discrete = Density_Ratio_discounted(env.n_state, gamma)
            _, w_table = train_density_ratio(SASR, pi_b, pi_e, den_discrete, gamma)
            w_table = w_table.reshape(-1)
            w = StateEmbeddingModel(num_s=env.num_s, num_out=1)
            w.set_weights(torch.from_numpy(w_table))
            # print('[Masa] fitted w')

            # Masa's q method is the same as fit_q_tabular
            q = fit_q_tabular(train_data, pi_e, gamma)
            # print('[Masa] fitted q')

            q_est = q_estimator(pi_e, gamma, q, init_state_sampler)
            drl_est = double_estimator(train_data_loader, pi_e, pi_b, w, q,
                                       gamma, init_state_sampler)
            w_est = w_estimator(train_data_loader, pi_e, pi_b, w)
            # print('[Masa] q_est', q_est, '[Masa] w_est',
            #       w_est, '[Masa] drl_est', drl_est)
            preds.append([is_est, is_ess, q_est, drl_est, w_est])

        preds = np.array(preds)
        errors = (preds - policy_val_oracle) ** 2
        mse = np.mean(errors, axis=0)
        print('[is_est, is_ess, q_est, drl_est, w_est] \n', preds)
        print('MSE for [is_est, is_ess, q_est, drl_est, w_est] \n', mse)
        np.save('estimators/masa_preds_' + str(j), preds)
        np.save('estimators/masa_mse_' + str(j), mse)

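# The loop above saves, for each behaviour-data size j, a (100, 5) array of estimator
# predictions and the corresponding per-estimator MSE. A minimal sketch of reading
# those files back and summarising them (the helper name and print format are
# illustrative, not part of the repository):
def summarize_masa_results(tau_len):
    preds = np.load('estimators/masa_preds_' + str(tau_len) + '.npy')
    mse = np.load('estimators/masa_mse_' + str(tau_len) + '.npy')
    names = ['is_est', 'is_ess', 'q_est', 'drl_est', 'w_est']
    for name, err in zip(names, mse):
        print(tau_len, name, 'MSE:', err)
    return preds, mse
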
def debug():
    """Train the w network on discrete taxi and evaluate the w estimator on held-out data."""
    # set up environment and policies
    env = TaxiEnvironment(discrete_state=True)
    # env = CartpoleEnvironment()
    gamma = 0.98
    alpha = 0.6
    # temp, hidden_dim, and state_dim are only used by the commented-out
    # continuous/cartpole setup below
    temp = 2.0
    hidden_dim = 50
    state_dim = 4
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)
    # pi_e = load_taxi_policy_continuous(
    #     "taxi_data/saved_policies/pi19.npy", env)
    # pi_s = load_taxi_policy_continuous("taxi_data/saved_policies/pi3.npy", env)
    # pi_b = GenericMixturePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)

    # set up logger
    oracle_tau_len = 100000  # // 10000
    init_state_dist_path = "taxi_data/init_state_dist.npy"
    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    # init_state_sampler = DecodingDiscreteInitStateSampler(init_state_dist,
    #                                                       env.decode_state)
    # init_state_sampler = CartpoleInitStateSampler(env)
    # logger = SimpleDiscretePrintWLogger(env=env, pi_e=pi_e, pi_b=pi_b,
    #                                     gamma=gamma,
    #                                     oracle_tau_len=oracle_tau_len)
    logger = DiscreteWLogger(env=env, pi_e=pi_e, pi_b=pi_b, gamma=gamma,
                             tensorboard=True, save_model=True,
                             oracle_tau_len=oracle_tau_len)

    # generate train, val, and test data
    tau_len = 200000  # // 10000
    burn_in = 100000  # // 10000
    train_data = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=tau_len,
                                       burn_in=burn_in)
    print('finish train')
    val_data = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=tau_len,
                                     burn_in=burn_in)
    print('finish val')
    test_data = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=tau_len,
                                      burn_in=burn_in)
    print('finish test')

    w = train_w_taxi(env, train_data, val_data, pi_e, pi_b,
                     init_state_sampler, logger, gamma)

    # calculate final performance
    test_data_loader = test_data.get_data_loader(1024)
    policy_val_est = w_estimator(tau_list_data_loader=test_data_loader,
                                 pi_e=pi_e, pi_b=pi_b, w=w)
    policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
                                           num_tau=1, tau_len=1000000)
    squared_error = (policy_val_est - policy_val_oracle) ** 2
    print("Test policy val squared error:", squared_error)