def debug():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_other_name = 'cartpole_180_-99900.0'
    load = False
    est_policy = True

    env = CartpoleEnvironment()
    pi_e = load_cartpole_policy("cartpole_weights/cartpole_best.pt", temp,
                                env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy("cartpole_weights/" + pi_other_name + ".pt",
                                    temp, env.state_dim, hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)

    if load:
        w = load_continuous_w_oracle(
            env, hidden_dim, './continuous_w_oracle.pt')
    else:
        # calculate_continuous_w_oracle indexes prefix[0] (pi_e data) and
        # prefix[1] (pi_b data), so pass the pair rather than a bare string.
        w = calculate_continuous_w_oracle(
            env, pi_b, pi_e, gamma, hidden_dim,
            prefix=["gamma_", pi_other_name + "_train_"],
            load=True, epoch=10)

    if est_policy:
        from dataset.tau_list_dataset import TauListDataset
        from estimators.infinite_horizon_estimators import w_estimator, oracle_w_estimator
        from estimators.benchmark_estimators import on_policy_estimate

        tau_e_path = 'tau_e_cartpole/'
        tau_b_path = 'tau_b_cartpole/'
        tau_len = 200000
        burn_in = 100000
        test_data = TauListDataset.load(tau_b_path, pi_other_name + '_test_')
        test_data_pi_e = TauListDataset.load(tau_e_path, prefix="gamma_")
        test_data_loader = test_data.get_data_loader(1024)
        policy_val_estimate = w_estimator(tau_list_data_loader=test_data_loader,
                                          pi_e=pi_e, pi_b=pi_b, w=w)
        # policy_val_estimate = oracle_w_estimator(
        #     tau_list_data_loader=test_data_loader, pi_e=pi_e,
        #     pi_b=pi_b, w_oracle=w)
        # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
        #                                        num_tau=1, tau_len=1000000,
        #                                        load_path=tau_e_path)
        policy_val_oracle = float(test_data_pi_e.r.mean().detach())
        squared_error = (policy_val_estimate - policy_val_oracle) ** 2
        print('W oracle estimate & true policy value ',
              policy_val_estimate, policy_val_oracle)
        print("Test policy val squared error:", squared_error)
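
# For reference, a minimal sketch of the quantity w_estimator is presumed to
# compute: the self-normalized average of w(s) * pi_e(a|s) / pi_b(a|s) * r
# over logged transitions. The batch layout and the convention that policies
# accept a batch of states and return [batch, num_a] action probabilities are
# assumptions here, not the repo's actual API.
def w_estimator_sketch(tau_list_data_loader, pi_e, pi_b, w):
    weighted_r_sum, weight_sum = 0.0, 0.0
    for s, a, s_prime, r in tau_list_data_loader:
        idx = torch.arange(len(a))
        # state density ratio times the action-probability ratio
        eta = w(s).view(-1) * pi_e(s)[idx, a] / pi_b(s)[idx, a]
        weighted_r_sum += float((eta * r).sum())
        weight_sum += float(eta.sum())
    return weighted_r_sum / weight_sum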
def debug():
    # Import paths below are assumed; they follow the project's module naming
    # pattern (cf. the cartpole imports above).
    from environments.taxi_environment import TaxiEnvironment
    from policies.discrete_policies import RandomPolicy
    from dataset.tau_list_dataset import TauListDataset

    env = TaxiEnvironment()
    pi = RandomPolicy(num_a=6, state_rank=0)
    dataset = env.generate_roll_out(pi=pi, num_tau=2, tau_len=20)
    print(dataset.s)
    print(dataset.lens)
    dataset.save("tmp")
    dataset_loaded = TauListDataset.load("tmp")
    print(dataset_loaded.s)
    print(dataset_loaded.lens)
def generate_roll_out(self, pi, num_tau, tau_len, burn_in=0, gamma=None):
    """
    Generates a roll out (a list of trajectories, each of which is a tuple
    (s, a, s', r) of pytorch arrays)

    :param pi: policy for rollout; semantics are that for any state s,
        pi(s) should return an array of action probabilities
    :param num_tau: number of trajectories to roll out
    :param tau_len: length of each roll out trajectory
    :param burn_in: number of actions to perform at the start of each
        trajectory before we begin logging
    :param gamma: (optional) if provided, at each time step with probability
        (1 - gamma) we reset to an initial state
    :return: TauListDataset wrapping the logged trajectories
    """
    self.tau_list = []
    for _ in range(num_tau):
        states = []
        actions = []
        rewards = []
        successor_states = []
        for i in range(tau_len + burn_in):
            if i == 0 or (gamma and np.random.rand() > gamma):
                s = self.reset()
            else:
                s = successor_states[-1]
            p = np.array(pi(s))
            p = p / p.sum()
            a = np.random.choice(list(range(self.num_a)), p=p)
            s_prime, r, done = self.step(a)
            states.append(s)
            actions.append(a)
            rewards.append(r)
            successor_states.append(s_prime)
        if self.is_s_discrete:
            s_tensor = torch.LongTensor(states[burn_in:])
            ss_tensor = torch.LongTensor(successor_states[burn_in:])
        else:
            s_tensor = torch.stack(states[burn_in:])
            ss_tensor = torch.stack(successor_states[burn_in:])
        a_tensor = torch.LongTensor(actions[burn_in:])
        r_tensor = torch.FloatTensor(rewards[burn_in:])
        self.tau_list.append((s_tensor, a_tensor, ss_tensor, r_tensor))
    return TauListDataset(self.tau_list)
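
# Example usage (a sketch; env and pi_e construction mirrors the debug
# functions in this file). With gamma set, resets are injected with
# probability (1 - gamma) per step, so the logged states approximate the
# discounted state occupancy of pi_e; mean(r) over such data is then the
# normalized discounted policy value, which is how policy_val_oracle is
# computed in the debug functions here.
def example_discounted_rollout(env, pi_e, gamma=0.98):
    data = env.generate_roll_out(pi=pi_e, num_tau=1, tau_len=200000,
                                 burn_in=100000, gamma=gamma)
    return float(data.r.mean())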
def debug():
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    env = CartpoleEnvironment(reward_reshape=args.reward_reshape)
    now = datetime.datetime.now()
    timestamp = now.isoformat()
    log_path = '{}/{}/'.format(
        args.save_folder,
        '_'.join([timestamp] + [i.replace("--", "") for i in sys.argv[1:]]
                 + [args.script]))
    print(log_path)

    # set up environment and policies
    pi_e = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_e_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_other_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, args.alpha)
    init_state_sampler = CartpoleInitStateSampler(env)

    if not args.rollout_dataset:
        print('Loading datasets for training...')
        train_data = TauListDataset.load(args.tau_b_path,
                                         prefix=args.pi_other_name + '_train_')
        val_data = TauListDataset.load(args.tau_b_path,
                                       prefix=args.pi_other_name + '_val_')
        pi_e_data_discounted = TauListDataset.load(
            args.tau_e_path,
            prefix=args.pi_e_name + '_gamma' + str(args.gamma) + '_')
        if args.oracle_val is None:
            args.oracle_val = float(pi_e_data_discounted.r.mean())
    else:
        # Generate train, val, and test data. This is very slow, so loading
        # existing data is preferred. pi_b data is for training, so no gamma.
        print('Not loading pi_b data, so generating')
        train_data = env.generate_roll_out(pi=pi_b, num_tau=args.num_tau,
                                           tau_len=args.tau_len,
                                           burn_in=args.burn_in)
        print('Finished generating train data of', args.pi_other_name)
        val_data = env.generate_roll_out(pi=pi_b, num_tau=args.num_tau,
                                         tau_len=args.tau_len,
                                         burn_in=args.burn_in)
        print('Finished generating val data of', args.pi_other_name)
        train_data.save(args.tau_b_path, prefix=args.pi_other_name + '_train_')
        val_data.save(args.tau_b_path, prefix=args.pi_other_name + '_val_')
        # pi_e data is only for calculating the oracle policy value, so it is
        # rolled out with the gamma.
        print('Not loading pi_e data, so generating')
        pi_e_data_discounted = env.generate_roll_out(
            pi=pi_e, num_tau=args.num_tau, tau_len=args.oracle_tau_len,
            burn_in=args.burn_in, gamma=args.gamma)
        print('Finished generating data of pi_e with gamma')
        pi_e_data_discounted.save(
            args.tau_e_path,
            prefix=args.pi_e_name + '_gamma' + str(args.gamma) + '_')
        args.oracle_val = float(pi_e_data_discounted.r.mean())
    print('Oracle policy val', args.oracle_val)

    q_oracle = QOracleModel.load_continuous_q_oracle(env, args.hidden_dim,
                                                     env.num_a,
                                                     args.oracle_path)
    logger = ContinuousQLogger(env=env, pi_e=pi_e, gamma=args.gamma,
                               tensorboard=args.no_tensorboard,
                               save_model=args.no_save_model,
                               init_state_sampler=init_state_sampler,
                               log_path=log_path,
                               policy_val_oracle=args.oracle_val,
                               q_oracle=q_oracle,
                               pi_e_data_discounted=pi_e_data_discounted)
    # logger = SimplestQLogger(
    #     env, pi_e, args.gamma, init_state_sampler, True, True, log_path,
    #     args.oracle_val)
    with open(os.path.join(log_path, 'meta.txt'), 'w') as f:
        print(args, file=f)

    q = train_q_network_cartpole(args, env, train_data, val_data, pi_e, pi_b,
                                 init_state_sampler, logger)

    # calculate final performance, optional
    policy_val_est = q_estimator(pi_e=pi_e, gamma=args.gamma, q=q,
                                 init_state_sampler=init_state_sampler)
    squared_error = (policy_val_est - logger.policy_val_oracle) ** 2
    print('Policy_val_oracle', logger.policy_val_oracle)
    print("Test policy val squared error:", squared_error)
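
# For reference, a minimal sketch of the quantity q_estimator is presumed to
# return: v(pi_e) = E_{s0}[ sum_a pi_e(a|s0) q(s0, a) ], averaged over initial
# states. The sampler call, batch shapes, and any (1 - gamma) normalization
# inside the real estimator are assumptions, not the repo's actual API.
def q_estimator_sketch(pi_e, gamma, q, init_state_sampler, num_s0=10000):
    s0 = init_state_sampler.sample(num_s0)  # hypothetical call: [num_s0, state_dim]
    q_values = q(s0)                        # assumed shape: [num_s0, num_a]
    pi_probs = pi_e(s0)                     # assumed shape: [num_s0, num_a]
    v_s0 = (pi_probs * q_values).sum(-1)    # per-state value under pi_e
    return float(v_s0.mean())               # gamma normalization omitted here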
def calculate_continuous_w_oracle(env, pi_b, pi_e, gamma, hidden_dim,
                                  cuda=False, lr=1e-3, tau_len=1000000,
                                  burn_in=100000, batch_size=1024, epoch=100,
                                  num_tau=1, load=True,
                                  load_path=('tau_e_cartpole/',
                                             'tau_b_cartpole/'),
                                  prefix=''):
    """
    :param env: environment (should be AbstractEnvironment)
    :param pi_b: behavior policy (should be from policies module)
    :param pi_e: evaluation policy (should be from policies module)
    :param gamma: discount factor
    :param hidden_dim: hidden dimension of the classifier network
    :param tau_len: length of trajectory to use for monte-carlo estimate
    :param burn_in: burn-in period for monte-carlo sampling
    :return w: network with the same architecture as the training model

    Observation:
        Type: Box(4)
        Num    Observation               Min     Max
        0      Cart Position             -4.8    4.8
        1      Cart Velocity             -Inf    Inf
        2      Pole Angle                -24°    24°
        3      Pole Velocity At Tip      -Inf    Inf

    Action:
        Type: Discrete(2)
        Num    Action
        0      Push cart to the left
        1      Push cart to the right
    """
    if load:
        # prefix is expected to be a pair [pi_e_prefix, pi_b_prefix] here
        # (see debug_oracle below). The stored s_e is much longer than s_b,
        # since s_e was used only for estimating policy value.
        tau_list_b = TauListDataset.load(load_path[1], prefix[1])
        s_b = tau_list_b.s
        tau_list_e = TauListDataset.load(load_path[0], prefix[0])
        s_e = tau_list_e.s[:s_b.size(0)]
    else:
        tau_list_e = env.generate_roll_out(pi=pi_e, num_tau=num_tau,
                                           tau_len=tau_len, gamma=gamma,
                                           burn_in=burn_in)
        tau_list_b = env.generate_roll_out(pi=pi_b, num_tau=num_tau,
                                           tau_len=tau_len, burn_in=burn_in)
        # tau_list_e.save(load_path[0])
        # tau_list_b.save(load_path[1])
        s_e = tau_list_e.s
        s_b = tau_list_b.s

    x = StateClassifierModel(env.state_dim, 128, out_dim=2)
    train_data_loader, val_data_loader, val_len = create_oracle_datasets(
        s_e, s_b, batch_size, train_ratio=0.8)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(x.model.parameters()), lr=lr)
    optimizer_lbfgs = optim.LBFGS(x.model.parameters())

    lowest_loss = np.inf
    for i in range(epoch):
        print(i)
        x.train()
        loss_sum_train = 0.0
        norm_train = 0.0
        num_correct_sum_train = 0
        for batch_idx, (data, labels) in enumerate(train_data_loader):
            if cuda:
                data, labels = data.cuda(), labels.cuda()
            logits = x(data)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            prob = logits.softmax(-1)
            num_correct = float((prob.argmax(-1) == labels).sum())
            num_correct_sum_train += num_correct
            loss_sum_train += (float(loss.detach()) * len(data))
            norm_train += len(data)
        train_acc = num_correct_sum_train / norm_train
        avg_loss = loss_sum_train / norm_train
        print('train accuracy ', train_acc)
        print('avg loss ', avg_loss)

        if i % 1 == 0:
            x.eval()
            num_correct_sum = 0.0
            loss_sum = 0.0
            norm = 0.0
            for batch_idx, (data, labels) in enumerate(val_data_loader):
                if cuda:
                    data, labels = data.cuda(), labels.cuda()
                logits = x(data)
                prob = logits.softmax(-1)
                loss = criterion(logits, labels)
                num_correct = float((prob.argmax(-1) == labels).sum())
                num_correct_sum += num_correct
                loss_sum += (float(loss.detach()) * len(data))
                norm += len(data)
                if batch_idx == 0:
                    print(prob[:50], labels[:50])
            test_acc = num_correct_sum / norm
            avg_loss = loss_sum / norm
            print('test accuracy ', test_acc)
            print('avg loss ', avg_loss)
            if avg_loss < lowest_loss:
                lowest_loss = avg_loss
                torch.save(x.state_dict(), './continuous_w_oracle.pt')
                print('best model saved with loss ', lowest_loss)

    # Final fine-tuning pass with L-BFGS over the full training set.
    train_data = train_data_loader.dataset[:]

    def closure_():
        optimizer_lbfgs.zero_grad()
        s_ = train_data[0]
        labels_ = train_data[1]
        logits_ = x(s_)
        loss_ = criterion(logits_, labels_)
        loss_.backward()
        return loss_

    optimizer_lbfgs.step(closure_)
    return WOracleModel(state_classifier=x, reg=1e-7,
                        train_data_loader=train_data_loader)
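
# For reference, a minimal sketch of the density-ratio trick this oracle is
# built on. WOracleModel's internals are not shown in this file, so this is an
# assumption about what it computes (up to regularization): a classifier
# trained to distinguish states from pi_e's discounted occupancy (label 1,
# say) from states from pi_b's occupancy (label 0), with balanced classes,
# yields w(s) = d_{pi_e}(s) / d_{pi_b}(s) = p(1|s) / p(0|s).
def w_from_classifier_sketch(state_classifier, s, reg=1e-7):
    prob = state_classifier(s).softmax(-1)  # [batch, 2] class probabilities
    return prob[:, 1] / (prob[:, 0] + reg)  # reg guards against division by zero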
def debug_oracle():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_e_name = '456_-99501_cartpole_best'
    pi_other_name = '400_-99761_cartpole'
    load = False
    est_policy = True

    env = CartpoleEnvironment(reward_reshape=False)
    init_state_sampler = CartpoleInitStateSampler(env)
    pi_e = load_cartpole_policy("cartpole_weights_1/" + pi_e_name + ".pt",
                                temp, env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        "cartpole_weights_1/" + pi_other_name + ".pt", temp, env.state_dim,
        hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)

    if load:
        w = load_continuous_w_oracle(env, hidden_dim,
                                     './continuous_w_oracle.pt')
    else:
        w = calculate_continuous_w_oracle(
            env, pi_b, pi_e, gamma, hidden_dim,
            prefix=[pi_e_name + "_gamma" + str(gamma) + '_',
                    pi_other_name + "_train_"],
            epoch=300)
        # load=False, num_tau=20, tau_len=100000, burn_in=1000,

    if est_policy:
        from estimators.infinite_horizon_estimators import w_estimator, oracle_w_estimator
        from estimators.benchmark_estimators import on_policy_estimate

        tau_e_path = 'tau_e_cartpole/'
        tau_b_path = 'tau_b_cartpole/'
        tau_len = 200000
        burn_in = 100000
        test_data = TauListDataset.load(tau_b_path, pi_other_name + '_train_')
        test_data_pi_e = TauListDataset.load(
            tau_e_path, prefix=pi_e_name + "_gamma" + str(gamma) + '_')
        test_data_loader = test_data.get_data_loader(1024)
        policy_val_estimate = w_estimator(
            tau_list_data_loader=test_data_loader, pi_e=pi_e, pi_b=pi_b, w=w)
        # policy_val_estimate = oracle_w_estimator(
        #     tau_list_data_loader=test_data_loader, pi_e=pi_e,
        #     pi_b=pi_b, w_oracle=w)
        # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e,
        #                                        gamma=gamma, num_tau=1,
        #                                        tau_len=1000000,
        #                                        load_path=tau_e_path)
        policy_val_oracle = float(test_data_pi_e.r.mean().detach())
        squared_error = (policy_val_estimate - policy_val_oracle) ** 2
        print('W oracle estimate & true policy value ',
              policy_val_estimate, policy_val_oracle)
        print("Test policy val squared error:", squared_error)
def debug_RKHS_method():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_e_name = '456_-99501_cartpole_best'
    pi_other_name = '400_-99761_cartpole'
    load = True
    est_policy = False
    # Rollout sizes, used only when load is False. These were undefined in
    # the original; the values below match the other debug functions here.
    tau_len = 200000
    burn_in = 100000
    load_path = 'tau_b_cartpole/'
    prefix = pi_other_name + '_train_'

    env = CartpoleEnvironment(reward_reshape=False)
    init_state_sampler = CartpoleInitStateSampler(env)
    pi_e = load_cartpole_policy("cartpole_weights_1/" + pi_e_name + ".pt",
                                temp, env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        "cartpole_weights_1/" + pi_other_name + ".pt", temp, env.state_dim,
        hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)

    if load:
        tau_list_b = TauListDataset.load(load_path, prefix)
    else:
        tau_list_b = env.generate_roll_out(pi=pi_b, num_tau=1,
                                           tau_len=tau_len, burn_in=burn_in)
        # TODO: figure out current file and path naming rule
        tau_list_b.save(load_path)

    tau_e_path = 'tau_e_cartpole/'
    test_data_pi_e = TauListDataset.load(
        tau_e_path, prefix=pi_e_name + "_gamma" + str(gamma) + '_')
    policy_val_oracle = float(test_data_pi_e.r.mean().detach())
    # logger = ContinuousWLogger(env, pi_e, pi_b, gamma,
    #                            False, False, None, policy_val_oracle)
    logger = SimplestWLogger(env, pi_e, pi_b, gamma,
                             False, False, None, policy_val_oracle)

    w = RKHS_method(env, pi_b, pi_e, tau_list_b, gamma, hidden_dim,
                    init_state_sampler, gaussian_kernel, logger=logger)
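
# For reference, a minimal sketch of the kind of kernel RKHS_method is passed
# above. The repo's actual gaussian_kernel may differ in signature and
# bandwidth handling; this is an assumption, not the project's implementation.
def gaussian_kernel_sketch(x1, x2, sigma=1.0):
    # Pairwise squared Euclidean distances between two state batches:
    # x1: [n, state_dim], x2: [m, state_dim] -> K: [n, m]
    sq_dists = ((x1.unsqueeze(1) - x2.unsqueeze(0)) ** 2).sum(-1)
    return torch.exp(-sq_dists / (2.0 * sigma ** 2))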