def debug():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_other_name = 'cartpole_180_-99900.0'
    load = False
    est_policy = True

    env = CartpoleEnvironment()
    pi_e = load_cartpole_policy("cartpole_weights/cartpole_best.pt", temp,
                                env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy("cartpole_weights/" + pi_other_name + ".pt",
                                    temp, env.state_dim, hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)

    # load_continuous_w_oracle / calculate_continuous_w_oracle are assumed to
    # be imported at module level.
    if load:
        w = load_continuous_w_oracle(
            env, hidden_dim, './continuous_w_oracle.pt')
    else:
        w = calculate_continuous_w_oracle(
            env, pi_b, pi_e, gamma, hidden_dim,
            prefix=pi_other_name + "_train_", load=True, epoch=10)

    if est_policy:
        from dataset.tau_list_dataset import TauListDataset
        # w_estimator is assumed to live alongside oracle_w_estimator
        from estimators.infinite_horizon_estimators import (
            w_estimator, oracle_w_estimator)
        from estimators.benchmark_estimators import on_policy_estimate

        tau_e_path = 'tau_e_cartpole/'
        tau_b_path = 'tau_b_cartpole/'
        tau_len = 200000
        burn_in = 100000
        test_data = TauListDataset.load(tau_b_path, pi_other_name + '_test_')
        test_data_pi_e = TauListDataset.load(tau_e_path, prefix="gamma_")
        test_data_loader = test_data.get_data_loader(1024)
        # a generic sketch of this kind of w-based estimate follows this function
        policy_val_estimate = w_estimator(
            tau_list_data_loader=test_data_loader, pi_e=pi_e, pi_b=pi_b, w=w)
        # policy_val_estimate = oracle_w_estimator(
        #     tau_list_data_loader=test_data_loader, pi_e=pi_e,
        #     pi_b=pi_b, w_oracle=w)
        # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
        #                                        num_tau=1, tau_len=1000000,
        #                                        load_path=tau_e_path)
        policy_val_oracle = float(test_data_pi_e.r.mean().detach())
        squared_error = (policy_val_estimate - policy_val_oracle) ** 2
        print('W oracle estimate & true policy value ',
              policy_val_estimate, policy_val_oracle)
        print("Test policy val squared error:", squared_error)
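
# The w_estimator call above reduces, for a single behavior dataset, to a
# self-normalized importance-weighted average of rewards. The sketch below is
# illustrative only and is NOT the repository's w_estimator: it assumes plain
# tensors of states/actions/rewards and callables returning action
# probabilities, rather than the TauListDataset / policy interfaces used above.
import torch


def w_estimator_sketch(s, a, r, w_fn, pi_e_prob, pi_b_prob):
    """Minimal sketch of a stationary-ratio OPE estimate.

    s: (n, state_dim) states, a: (n,) actions, r: (n,) rewards,
    w_fn(s) -> (n,) stationary density ratios d_pi_e(s) / d_pi_b(s),
    pi_e_prob(s, a), pi_b_prob(s, a) -> (n,) action probabilities.
    """
    with torch.no_grad():
        # per-sample correction: state ratio times one-step policy ratio
        weights = w_fn(s) * pi_e_prob(s, a) / pi_b_prob(s, a)
        # self-normalize so the weights average to one before weighting rewards
        return float((weights * r).sum() / weights.sum())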
def debug():
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    env = CartpoleEnvironment(reward_reshape=args.reward_reshape)
    now = datetime.datetime.now()
    timestamp = now.isoformat()
    log_path = '{}/{}/'.format(
        args.save_folder,
        '_'.join([timestamp] + [i.replace("--", "") for i in sys.argv[1:]]
                 + [args.script]))
    print(log_path)

    # set up environment and policies
    pi_e = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_e_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_other_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, args.alpha)
    init_state_sampler = CartpoleInitStateSampler(env)
    now = datetime.datetime.now()

    if not args.rollout_dataset:
        print('Loading datasets for training...')
        train_data = TauListDataset.load(
            args.tau_b_path, prefix=args.pi_other_name + '_train_')
        val_data = TauListDataset.load(
            args.tau_b_path, prefix=args.pi_other_name + '_val_')
        pi_e_data_discounted = TauListDataset.load(
            args.tau_e_path,
            prefix=args.pi_e_name + '_gamma' + str(args.gamma) + '_')
        if args.oracle_val is None:
            args.oracle_val = float(pi_e_data_discounted.r.mean())
    else:
        # Generate train, val, and test data. This is very slow, so loading
        # existing data is preferred. pi_b data is for training, so no gamma.
        print('Not loading pi_b data, so generating')
        train_data = env.generate_roll_out(pi=pi_b, num_tau=args.num_tau,
                                           tau_len=args.tau_len,
                                           burn_in=args.burn_in)
        print('Finished generating train data of', args.pi_other_name)
        val_data = env.generate_roll_out(pi=pi_b, num_tau=args.num_tau,
                                         tau_len=args.tau_len,
                                         burn_in=args.burn_in)
        print('Finished generating val data of', args.pi_other_name)
        train_data.save(args.tau_b_path, prefix=args.pi_other_name + '_train_')
        val_data.save(args.tau_b_path, prefix=args.pi_other_name + '_val_')

        # pi_e data is only used for the oracle policy value, so it carries gamma.
        print('Not loading pi_e data, so generating')
        pi_e_data_discounted = env.generate_roll_out(
            pi=pi_e, num_tau=args.num_tau, tau_len=args.oracle_tau_len,
            burn_in=args.burn_in, gamma=args.gamma)
        print('Finished generating data of pi_e with gamma')
        pi_e_data_discounted.save(
            args.tau_e_path,
            prefix=args.pi_e_name + '_gamma' + str(args.gamma) + '_')
        args.oracle_val = float(pi_e_data_discounted.r.mean())
    print('Oracle policy val', args.oracle_val)

    q_oracle = QOracleModel.load_continuous_q_oracle(
        env, args.hidden_dim, env.num_a, args.oracle_path)
    logger = ContinuousQLogger(env=env, pi_e=pi_e, gamma=args.gamma,
                               tensorboard=args.no_tensorboard,
                               save_model=args.no_save_model,
                               init_state_sampler=init_state_sampler,
                               log_path=log_path,
                               policy_val_oracle=args.oracle_val,
                               q_oracle=q_oracle,
                               pi_e_data_discounted=pi_e_data_discounted)
    # logger = SimplestQLogger(
    #     env, pi_e, args.gamma, init_state_sampler, True, True, log_path,
    #     args.oracle_val)
    with open(os.path.join(log_path, 'meta.txt'), 'w') as f:
        print(args, file=f)

    q = train_q_network_cartpole(args, env, train_data, val_data, pi_e, pi_b,
                                 init_state_sampler, logger)

    # calculate final performance, optional
    # (a generic direct-method sketch follows this function)
    policy_val_est = q_estimator(pi_e=pi_e, gamma=args.gamma, q=q,
                                 init_state_sampler=init_state_sampler)
    squared_error = (policy_val_est - logger.policy_val_oracle) ** 2
    print('Policy_val_oracle', logger.policy_val_oracle)
    print("Test policy val squared error:", squared_error)
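
# q_estimator above is a direct-method estimate: evaluate the learned
# Q-function at sampled initial states and average it under pi_e. The sketch
# below shows that idea only; the actual q_estimator / init_state_sampler
# interfaces, and whether the repository applies a (1 - gamma) normalization,
# are assumptions rather than confirmed behavior.
import torch


def direct_method_sketch(q_fn, pi_e_probs, s0, normalize=True, gamma=0.98):
    """q_fn(s0) -> (n, num_a) Q-values, pi_e_probs(s0) -> (n, num_a) probs,
    s0: (n, state_dim) sampled initial states."""
    with torch.no_grad():
        # E_{s0 ~ d0, a ~ pi_e}[Q(s0, a)]
        v0 = (pi_e_probs(s0) * q_fn(s0)).sum(dim=1).mean()
        # optional (1 - gamma) factor if the target is the normalized
        # discounted value rather than the raw discounted return
        return float((1.0 - gamma) * v0) if normalize else float(v0)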
def debug_oracle():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_e_name = '456_-99501_cartpole_best'
    pi_other_name = '400_-99761_cartpole'
    load = False
    est_policy = True

    env = CartpoleEnvironment(reward_reshape=False)
    init_state_sampler = CartpoleInitStateSampler(env)
    pi_e = load_cartpole_policy("cartpole_weights_1/" + pi_e_name + ".pt", temp,
                                env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        "cartpole_weights_1/" + pi_other_name + ".pt", temp, env.state_dim,
        hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)

    if load:
        w = load_continuous_w_oracle(env, hidden_dim,
                                     './continuous_w_oracle.pt')
    else:
        w = calculate_continuous_w_oracle(
            env, pi_b, pi_e, gamma, hidden_dim,
            prefix=[pi_e_name + "_gamma" + str(gamma) + '_',
                    pi_other_name + "_train_"],
            epoch=300)
        # load=False, num_tau=20, tau_len=100000, burn_in=1000,

    if est_policy:
        # w_estimator is assumed to live alongside oracle_w_estimator
        from estimators.infinite_horizon_estimators import (
            w_estimator, oracle_w_estimator)
        from estimators.benchmark_estimators import on_policy_estimate

        tau_e_path = 'tau_e_cartpole/'
        tau_b_path = 'tau_b_cartpole/'
        tau_len = 200000
        burn_in = 100000
        test_data = TauListDataset.load(tau_b_path, pi_other_name + '_train_')
        test_data_pi_e = TauListDataset.load(
            tau_e_path, prefix=pi_e_name + "_gamma" + str(gamma) + '_')
        test_data_loader = test_data.get_data_loader(1024)
        policy_val_estimate = w_estimator(
            tau_list_data_loader=test_data_loader, pi_e=pi_e, pi_b=pi_b, w=w)
        # policy_val_estimate = oracle_w_estimator(
        #     tau_list_data_loader=test_data_loader, pi_e=pi_e,
        #     pi_b=pi_b, w_oracle=w)
        # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
        #                                        num_tau=1, tau_len=1000000,
        #                                        load_path=tau_e_path)
        policy_val_oracle = float(test_data_pi_e.r.mean().detach())
        squared_error = (policy_val_estimate - policy_val_oracle) ** 2
        print('W oracle estimate & true policy value ',
              policy_val_estimate, policy_val_oracle)
        print("Test policy val squared error:", squared_error)
def debug_RKHS_method():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_e_name = '456_-99501_cartpole_best'
    pi_other_name = '400_-99761_cartpole'
    load = True
    est_policy = False
    load_path = 'tau_b_cartpole/'
    prefix = pi_other_name + '_train_'
    # NOTE: rollout lengths for the generate_roll_out branch below; the values
    # mirror the other debug helpers and are an assumption.
    tau_len = 200000
    burn_in = 100000

    env = CartpoleEnvironment(reward_reshape=False)
    init_state_sampler = CartpoleInitStateSampler(env)
    pi_e = load_cartpole_policy("cartpole_weights_1/" + pi_e_name + ".pt", temp,
                                env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        "cartpole_weights_1/" + pi_other_name + ".pt", temp, env.state_dim,
        hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)

    if load:
        # s_b = torch.load(open(os.path.join(load_path, prefix+'s.pt'), 'rb'))
        # a_b = torch.load(open(os.path.join(load_path, prefix+'a.pt'), 'rb'))
        # s_prime_b = torch.load(
        #     open(os.path.join(load_path, prefix+'s_prime.pt'), 'rb'))
        tau_list_b = TauListDataset.load(load_path, prefix)
    else:
        tau_list_b = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=tau_len,
                                           burn_in=burn_in)
        # TODO: figure out current file and path naming rule
        tau_list_b.save(load_path)
    # s_b = tau_list_b[:][0]
    # a_b = tau_list_b[:][1]
    # s_prime_b = tau_list_b[:][2]

    tau_e_path = 'tau_e_cartpole/'
    test_data_pi_e = TauListDataset.load(
        tau_e_path, prefix=pi_e_name + "_gamma" + str(gamma) + '_')
    policy_val_oracle = float(test_data_pi_e.r.mean().detach())
    # logger = ContinuousWLogger(env, pi_e, pi_b, gamma,
    #                            False, False, None, policy_val_oracle)
    logger = SimplestWLogger(env, pi_e, pi_b, gamma, False, False, None,
                             policy_val_oracle)

    # if load:
    #     w = load_continuous_w_oracle(
    #         env, hidden_dim, './continuous_w_oracle.pt')
    # else:
    #     w = calculate_continuous_w_oracle(
    #         env, pi_b, pi_e, gamma, hidden_dim,
    #         prefix=pi_other_name+"_train_", load=True, epoch=10)
    # (a generic RBF kernel sketch follows this function)
    w = RKHS_method(env, pi_b, pi_e, tau_list_b, gamma, hidden_dim,
                    init_state_sampler, gaussian_kernel, logger=logger)
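
# RKHS_method above takes a kernel as an argument; a Gaussian (RBF) kernel of
# the form k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)) is the usual choice for
# such methods. The helper below is a generic sketch with an assumed bandwidth
# argument; it is not the repository's gaussian_kernel.
import torch


def gaussian_kernel_sketch(x, y, sigma=1.0):
    """x: (n, d), y: (m, d) -> (n, m) Gram matrix of pairwise RBF values."""
    sq_dists = torch.cdist(x, y, p=2) ** 2  # pairwise squared Euclidean distances
    return torch.exp(-sq_dists / (2.0 * sigma ** 2))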
def debug():
    env = CartpoleEnvironment(reward_reshape=args.reward_reshape)
    # [50000, 100000, 200000, 400000] // 10000
    tau_lens = [400000]
    burn_in = 100000  # // 10000
    pi_e = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_e_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_other_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, args.alpha)
    init_state_sampler = CartpoleInitStateSampler(env)
    print('policy_val_oracle', args.oracle_val)
    # combinations = ['is_est', 'is_ess', 'q_ERM', 'w_RKHS', 'drl_q_ERM_w_RKHS',
    #                 'q_GMM', 'w_GMM', 'drl_q_GMM_w_RKHS', 'drl_q_ERM_w_GMM',
    #                 'drl_q_GMM_w_GMM']

    # NOTE: `end` (the exclusive upper bound of the seed range) is assumed to
    # be defined elsewhere, e.g. at module level or via the CLI arguments.
    for j in tau_lens:
        tau_len = j
        preds = []
        for i in range(10, end):
            print(j, i)
            np.random.seed(i)
            torch.random.manual_seed(i)
            random.seed(i)

            train_data = env.generate_roll_out(pi=pi_b, num_tau=1,
                                               tau_len=tau_len,
                                               burn_in=args.burn_in)
            # train_data = TauListDataset.load(
            #     args.tau_b_path, prefix=args.pi_other_name+'_train_')
            train_data_loader = train_data.get_data_loader(args.batch_size)
            SASR = train_data.restore_strcture(discrete=False)
            print('finished data generation.')
            # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e,
            #                                        gamma=gamma, num_tau=1,
            #                                        tau_len=10000000)

            # ordinary and step-wise importance sampling baselines
            is_est = importance_sampling_estimator(SASR, pi_b, pi_e, args.gamma,
                                                   split_shape=[4, 1, 4, 1])
            is_ess = importance_sampling_estimator_stepwise(
                SASR, pi_b, pi_e, args.gamma, split_shape=[4, 1, 4, 1])
            print('is_est', is_est, 'is_ess', is_ess)

            # RKHS for w
            now = datetime.datetime.now()
            timestamp = now.isoformat()
            log_path = '{}/{}/'.format(
                args.save_folder,
                '_'.join([timestamp]
                         + [i.replace("--", "") for i in sys.argv[1:]]
                         + [args.script + '_W_' + str(j) + str(i)]))
            w_logger = SimplestWLogger(env, pi_e, pi_b, args.gamma, True, True,
                                       log_path, args.oracle_val)
            # w_logger = None
            w_rkhs = RKHS_method(
                env, pi_b, pi_e, train_data, args.gamma, args.hidden_dim,
                init_state_sampler, gaussian_kernel, logger=w_logger
            )  # , batch_size=args.batch_size, epoch=args.w_RKHS_epoch, lr=args.w_rkhs_lr)
            w_rkhs_est = w_estimator(train_data_loader, pi_e, pi_b, w_rkhs)
            print('w_rkhs_est', w_rkhs_est)

            # ERM for q
            now = datetime.datetime.now()
            timestamp = now.isoformat()
            log_path = '{}/{}/'.format(
                args.save_folder,
                '_'.join([timestamp]
                         + [i.replace("--", "") for i in sys.argv[1:]]
                         + [args.script + '_Q_' + str(j) + str(i)]))
            q_logger = SimplestQLogger(env, pi_e, args.gamma,
                                       init_state_sampler, True, True,
                                       log_path, args.oracle_val)
            # q_logger = None
            q_erm = QNetworkModelSimple(env.state_dim, args.hidden_dim,
                                        out_dim=env.num_a, neg_output=True)
            q_erm_optimizer = OAdam(q_erm.parameters(), lr=args.q_lr,
                                    betas=(0.5, 0.9))
            train_q_network_erm(train_data, pi_e,
                                args.q_GMM_epoch + args.q_ERM_epoch,
                                args.batch_size, q_erm, q_erm_optimizer,
                                args.gamma, val_data=None, val_freq=10,
                                q_scheduler=None, logger=q_logger)
            q_erm_est = q_estimator(pi_e, args.gamma, q_erm, init_state_sampler)
            print('q_erm_est', q_erm_est)

            # q GMM
            q_GMM = train_q_network_cartpole(args, env, train_data, None, pi_e,
                                             pi_b, init_state_sampler, q_logger)
            q_GMM_est = q_estimator(pi_e, args.gamma, q_GMM, init_state_sampler)
            print('q_GMM_est', q_GMM_est)

            # w GMM
            w_GMM = train_w_network_cartpole(args, env, train_data, None, pi_e,
                                             pi_b, init_state_sampler, w_logger)
            w_GMM_est = w_estimator(train_data_loader, pi_e, pi_b, w_GMM)
            print('w_GMM_est', w_GMM_est)

            # doubly-robust combinations of the q- and w-estimates
            # (a generic doubly-robust sketch follows this function)
            drl_q_ERM_w_RKHS = double_estimator(train_data_loader, pi_e, pi_b,
                                                w_rkhs, q_erm, args.gamma,
                                                init_state_sampler)
            drl_q_ERM_w_GMM = double_estimator(train_data_loader, pi_e, pi_b,
                                               w_GMM, q_erm, args.gamma,
                                               init_state_sampler)
            drl_q_GMM_w_RKHS = double_estimator(train_data_loader, pi_e, pi_b,
                                                w_rkhs, q_GMM, args.gamma,
                                                init_state_sampler)
            drl_q_GMM_w_GMM = double_estimator(train_data_loader, pi_e, pi_b,
                                               w_GMM, q_GMM, args.gamma,
                                               init_state_sampler)
            print('drl_q_ERM_w_RKHS', drl_q_ERM_w_RKHS,
                  'drl_q_ERM_w_GMM', drl_q_ERM_w_GMM,
                  'drl_q_GMM_w_RKHS', drl_q_GMM_w_RKHS,
                  'drl_q_GMM_w_GMM', drl_q_GMM_w_GMM)

            preds.append([is_est, is_ess, w_rkhs_est, q_erm_est, q_GMM_est,
                          w_GMM_est, drl_q_ERM_w_RKHS, drl_q_ERM_w_GMM,
                          drl_q_GMM_w_RKHS, drl_q_GMM_w_GMM])

        print('[is_est, is_ess, w_rkhs_est, q_erm_est, q_GMM_est, w_GMM_est, '
              'drl_q_ERM_w_RKHS, drl_q_ERM_w_GMM, drl_q_GMM_w_RKHS, '
              'drl_q_GMM_w_GMM] \n', preds)
        preds = np.array(preds)
        errors = (preds - args.oracle_val) ** 2
        mse = np.mean(errors, axis=0)
        print('MSE for [is_est, is_ess, w_rkhs_est, q_erm_est, q_GMM_est, '
              'w_GMM_est, drl_q_ERM_w_RKHS, drl_q_ERM_w_GMM, drl_q_GMM_w_RKHS, '
              'drl_q_GMM_w_GMM] \n', mse)
        np.save('estimators/benchmarks_GMM_cont_' + str(j) + str(end), preds)
        # the '_mse' suffix is added here so the MSE array does not overwrite
        # the raw predictions saved above (the suffix itself is an assumption)
        np.save('estimators/benchmarks_GMM_cont_' + str(j) + str(end) + '_mse',
                mse)
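
# double_estimator above combines the learned w and Q models. A standard
# doubly-robust form adds an importance-weighted correction of the Q-function's
# Bellman residual to the direct-method estimate. The sketch below is generic
# and hedged: the tensor inputs, callables, and the (1 - gamma) normalization
# are assumptions, not the repository's double_estimator signature.
import torch


def doubly_robust_sketch(s, a, r, s_next, w_fn, q_fn, pi_e_probs, pi_b_probs,
                         s0, gamma=0.98):
    """s, s_next: (n, state_dim); a: (n,) long; r: (n,); s0: (m, state_dim).
    q_fn(s) -> (n, num_a); pi_e_probs / pi_b_probs(s) -> (n, num_a)."""
    with torch.no_grad():
        # direct-method term from initial states (normalized discounted value)
        dm = (1.0 - gamma) * (pi_e_probs(s0) * q_fn(s0)).sum(dim=1).mean()
        # correction term: weighted Bellman residual on behavior data
        idx = torch.arange(len(a))
        rho = pi_e_probs(s)[idx, a] / pi_b_probs(s)[idx, a]
        v_next = (pi_e_probs(s_next) * q_fn(s_next)).sum(dim=1)
        q_sa = q_fn(s)[idx, a]
        correction = (w_fn(s) * rho * (r + gamma * v_next - q_sa)).mean()
        return float(dm + correction)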