def debug():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_other_name = 'cartpole_180_-99900.0'
    load = False
    est_policy = True

    env = CartpoleEnvironment()
    pi_e = load_cartpole_policy("cartpole_weights/cartpole_best.pt", temp, env.state_dim,
                                hidden_dim, env.num_a)
    pi_other = load_cartpole_policy("cartpole_weights/"+pi_other_name+".pt", temp,
                                    env.state_dim, hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)
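    # GenericMixturePolicy presumably mixes the two policies so that
    # pi_b(a|s) = alpha * pi_e(a|s) + (1 - alpha) * pi_other(a|s).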

    if load:
        w = load_continuous_w_oracle(
            env, hidden_dim, './continuous_w_oracle.pt')
    else:
        w = calculate_continuous_w_oracle(
            env, pi_b, pi_e, gamma, hidden_dim, prefix=pi_other_name+"_train_",
            load=True, epoch=10)

    if est_policy:
        from dataset.tau_list_dataset import TauListDataset
        from estimators.infinite_horizon_estimators import w_estimator, oracle_w_estimator
        from estimators.benchmark_estimators import on_policy_estimate
        tau_e_path = 'tau_e_cartpole/'
        tau_b_path = 'tau_b_cartpole/'
        tau_len = 200000
        burn_in = 100000
        test_data = TauListDataset.load(tau_b_path, pi_other_name+'_test_')
        test_data_pi_e = TauListDataset.load(tau_e_path, prefix="gamma_")
        test_data_loader = test_data.get_data_loader(1024)
        policy_val_estimate = w_estimator(tau_list_data_loader=test_data_loader,
                                          pi_e=pi_e, pi_b=pi_b, w=w)
        # policy_val_estimate = oracle_w_estimator(
        #     tau_list_data_loader=test_data_loader, pi_e=pi_e,
        #     pi_b=pi_b, w_oracle=w)

        # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
        #                                        num_tau=1, tau_len=1000000,
        #                                        load_path=tau_e_path)
        policy_val_oracle = float(test_data_pi_e.r.mean().detach())
        squared_error = (policy_val_estimate - policy_val_oracle) ** 2
        print('W oracle estimate & true policy value ',
              policy_val_estimate, policy_val_oracle)
        print("Test policy val squared error:", squared_error)
def debug():
    # TaxiEnvironment, RandomPolicy, and TauListDataset are assumed to be
    # imported at module level from the repo's environments, policies, and
    # dataset packages.
    env = TaxiEnvironment()
    pi = RandomPolicy(num_a=6, state_rank=0)
    dataset = env.generate_roll_out(pi=pi, num_tau=2, tau_len=20)
    print(dataset.s)
    print(dataset.lens)
    dataset.save("tmp")
    dataset_loaded = TauListDataset.load("tmp")
    print(dataset_loaded.s)
    print(dataset_loaded.lens)
Example #3
 def generate_roll_out(self, pi, num_tau, tau_len, burn_in=0, gamma=None):
     """
     Generates roll out (list of trajectories, each of which is a tuple
         (s, a, s', r) of pytorch arrays
     :param pi: policy for rollout, semantics are for any state s,
         pi(s) should return array of action probabilities
     :param num_tau: number of trajectories to roll out
     :param tau_len: length of each roll out trajectory
     :param burn_in: number of actions to perform at the start of each
         trajectory before we begin logging
     :param gamma: (optional) if provided, at each time step with probability
         (1 - gamma) we reset to an initial state
     :return:
     """
     self.tau_list = []
     for _ in range(num_tau):
         states = []
         actions = []
         rewards = []
         successor_states = []
         for i in range(tau_len + burn_in):
             if i == 0 or (gamma and np.random.rand() > gamma):
                 s = self.reset()
             else:
                 s = successor_states[-1]
             p = np.array(pi(s))
             p = p / p.sum()
             a = np.random.choice(list(range(self.num_a)), p=p)
             s_prime, r, done = self.step(a)
             states.append(s)
             actions.append(a)
             rewards.append(r)
             successor_states.append(s_prime)
         if self.is_s_discrete:
             s_tensor = torch.LongTensor(states[burn_in:])
             ss_tensor = torch.LongTensor(successor_states[burn_in:])
         else:
             s_tensor = torch.stack(states[burn_in:])
             ss_tensor = torch.stack(successor_states[burn_in:])
         a_tensor = torch.LongTensor(actions[burn_in:])
         r_tensor = torch.FloatTensor(rewards[burn_in:])
         self.tau_list.append((s_tensor, a_tensor, ss_tensor, r_tensor))
     return TauListDataset(self.tau_list)
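
 # Note: when gamma is provided, the trajectory resets to an initial state with
 # probability (1 - gamma) at each step, so the logged states are approximately
 # samples from the discounted state-visitation distribution of pi. The plain
 # mean of the logged rewards then estimates the (normalized) discounted value
 # of pi, which is how the debug functions in this listing compute
 # policy_val_oracle. Hypothetical usage sketch:
 # pi_e_data = env.generate_roll_out(pi=pi_e, num_tau=1, tau_len=100000,
 #                                   burn_in=1000, gamma=0.98)
 # policy_val_oracle = float(pi_e_data.r.mean())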
def debug():
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    env = CartpoleEnvironment(reward_reshape=args.reward_reshape)
    now = datetime.datetime.now()
    timestamp = now.isoformat()
    log_path = '{}/{}/'.format(
        args.save_folder,
        '_'.join([timestamp] + [i.replace("--", "")
                                for i in sys.argv[1:]] + [args.script]))
    print(log_path)

    # set up environment and policies
    pi_e = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_e_name + ".pt"), args.temp,
        env.state_dim, args.hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_other_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, args.alpha)
    init_state_sampler = CartpoleInitStateSampler(env)
    now = datetime.datetime.now()
    if not args.rollout_dataset:
        print('Loading datasets for training...')
        train_data = TauListDataset.load(args.tau_b_path,
                                         prefix=args.pi_other_name + '_train_')
        val_data = TauListDataset.load(args.tau_b_path,
                                       prefix=args.pi_other_name + '_val_')
        pi_e_data_discounted = TauListDataset.load(
            args.tau_e_path,
            prefix=args.pi_e_name + '_gamma' + str(args.gamma) + '_')
        if args.oracle_val is None:
            args.oracle_val = float(pi_e_data_discounted.r.mean())

    else:
        # Generate train, val, and pi_e data; this is very slow, so loading existing data is preferred.
        # pi_b data is for training so no gamma.
        print('Not loading pi_b data, so generating')
        train_data = env.generate_roll_out(pi=pi_b,
                                           num_tau=args.num_tau,
                                           tau_len=args.tau_len,
                                           burn_in=args.burn_in)
        print('Finished generating train data of', args.pi_other_name)
        val_data = env.generate_roll_out(pi=pi_b,
                                         num_tau=args.num_tau,
                                         tau_len=args.tau_len,
                                         burn_in=args.burn_in)
        print('Finished generating val data of', args.pi_other_name)
        train_data.save(args.tau_b_path, prefix=args.pi_other_name + '_train_')
        val_data.save(args.tau_b_path, prefix=args.pi_other_name + '_val_')

        # pi_e data with gamma is only for calculating oracle policy value, so has the gamma.
        print('Not loading pi_e data, so generating')
        pi_e_data_discounted = env.generate_roll_out(
            pi=pi_e,
            num_tau=args.num_tau,
            tau_len=args.oracle_tau_len,
            burn_in=args.burn_in,
            gamma=args.gamma)
        print('Finished generating data of pi_e with gamma')
        pi_e_data_discounted.save(args.tau_e_path,
                                  prefix=args.pi_e_name + '_gamma' +
                                  str(args.gamma) + '_')
        args.oracle_val = float(pi_e_data_discounted.r.mean())

    print('Oracle policy val', args.oracle_val)
    q_oracle = QOracleModel.load_continuous_q_oracle(env, args.hidden_dim,
                                                     env.num_a,
                                                     args.oracle_path)
    logger = ContinuousQLogger(env=env,
                               pi_e=pi_e,
                               gamma=args.gamma,
                               tensorboard=args.no_tensorboard,
                               save_model=args.no_save_model,
                               init_state_sampler=init_state_sampler,
                               log_path=log_path,
                               policy_val_oracle=args.oracle_val,
                               q_oracle=q_oracle,
                               pi_e_data_discounted=pi_e_data_discounted)
    #     logger = SimplestQLogger(
    #         env, pi_e, args.gamma, init_state_sampler, True, True, log_path, args.oracle_val)

    with open(os.path.join(log_path, 'meta.txt'), 'w') as f:
        print(args, file=f)

    q = train_q_network_cartpole(args, env, train_data, val_data, pi_e, pi_b,
                                 init_state_sampler, logger)

    # calculate final performance, optional
    policy_val_est = q_estimator(pi_e=pi_e,
                                 gamma=args.gamma,
                                 q=q,
                                 init_state_sampler=init_state_sampler)
    squared_error = (policy_val_est - logger.policy_val_oracle)**2
    print('Policy_val_oracle', logger.policy_val_oracle)
    print("Test policy val squared error:", squared_error)
def calculate_continuous_w_oracle(env,
                                  pi_b,
                                  pi_e,
                                  gamma,
                                  hidden_dim,
                                  cuda=False,
                                  lr=1e-3,
                                  tau_len=1000000,
                                  burn_in=100000,
                                  batch_size=1024,
                                  epoch=100,
                                  num_tau=1,
                                  load=True,
                                  load_path=('tau_e_cartpole/',
                                             'tau_b_cartpole/'),
                                  prefix=''):
    """
    :param env: environment (should be AbstractEnvironment)
    :param pi_b: behavior policy (should be from policies module)
    :param pi_e: evaluation policy (should be from policies module)
    :param gamma: discount factor
    :param hidden_dim: hidden layer dimension of the w oracle network
    :param tau_len: length of each trajectory used for the monte-carlo estimate
    :param burn_in: burn-in period for monte-carlo sampling
    :return w: network with the same architecture as the training model

    Observation:
        Type: Box(4)
        Num	Observation                 Min         Max
        0	Cart Position             -4.8            4.8
        1	Cart Velocity             -Inf            Inf
        2	Pole Angle                 -24°           24°
        3	Pole Velocity At Tip      -Inf            Inf

    Action:
        Type: Discrete(2)
        Num	Action
        0	Push cart to the left
        1	Push cart to the right
    """
    if load:
        # The stored s_e is much longer than s_b since s_e was used only for estimating policy value.
        tau_list_b = TauListDataset.load(load_path[1], prefix[1])
        s_b = tau_list_b.s
        tau_list_e = TauListDataset.load(load_path[0], prefix[0])
        s_e = tau_list_e.s[:s_b.size(0)]

    else:
        tau_list_e = env.generate_roll_out(pi=pi_e,
                                           num_tau=num_tau,
                                           tau_len=tau_len,
                                           gamma=gamma,
                                           burn_in=burn_in)
        tau_list_b = env.generate_roll_out(pi=pi_b,
                                           num_tau=num_tau,
                                           tau_len=tau_len,
                                           burn_in=burn_in)
        # tau_list_e.save(load_path[0])
        # tau_list_b.save(load_path[1])
        s_e = tau_list_e.s  # [:][0]
        s_b = tau_list_b.s  # [:][0]

    x = StateClassifierModel(env.state_dim, 128, out_dim=2)
    train_data_loader, val_data_loader, val_len = create_oracle_datasets(
        s_e, s_b, batch_size, train_ratio=0.8)
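    # x is a binary classifier trained to tell apart states drawn from the pi_e
    # roll-outs and states drawn from the pi_b roll-outs (labels are presumably
    # assigned by create_oracle_datasets); its class probabilities are later
    # converted into the density-ratio oracle w.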

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(x.model.parameters()), lr=lr)
    optimizer_lbfgs = optim.LBFGS(x.model.parameters())
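    # Adam handles the mini-batch training epochs below; LBFGS is applied once
    # afterwards as a full-batch fine-tuning step via closure_.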

    lowest_loss = np.inf
    for i in range(epoch):
        print(i)
        x.train()

        loss_sum_train = 0.0
        norm_train = 0.0
        num_correct_sum_train = 0
        for batch_idx, (data, labels) in enumerate(train_data_loader):
            if cuda:
                data, labels = data.cuda(), labels.cuda()
            logits = x(data)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prob = logits.softmax(-1)
            num_correct = float((prob.argmax(-1) == labels).sum())
            num_correct_sum_train += num_correct
            loss_sum_train += (float(loss.detach()) * len(data))
            norm_train += len(data)

        train_acc = num_correct_sum_train / norm_train
        avg_loss = loss_sum_train / norm_train
        print('train accuracy ', train_acc)
        print('avg loss ', avg_loss)

        if i % 1 == 0:
            x.eval()
            num_correct_sum = 0.0
            loss_sum = 0.0
            norm = 0.0
            for batch_idx, (data, labels) in enumerate(val_data_loader):
                if cuda:
                    data, labels = data.cuda(), labels.cuda()

                logits = x(data)
                prob = logits.softmax(-1)
                loss = criterion(logits, labels)
                num_correct = float((prob.argmax(-1) == labels).sum())
                num_correct_sum += num_correct
                loss_sum += (float(loss.detach()) * len(data))
                norm += len(data)
                # import pdb
                # pdb.set_trace()
                if batch_idx == 0:
                    print(prob[:50], labels[:50])

            val_acc = num_correct_sum / norm
            avg_loss = loss_sum / norm
            print('val accuracy ', val_acc)
            print('avg loss ', avg_loss)
            if avg_loss < lowest_loss:
                lowest_loss = avg_loss
                torch.save(x.state_dict(), './continuous_w_oracle.pt')
                print('best model saved with loss ', lowest_loss)

    train_data = train_data_loader.dataset[:]

    def closure_():
        optimizer_lbfgs.zero_grad()
        s_ = train_data[0]
        labels_ = train_data[1]
        logits_ = x(s_)
        loss_ = criterion(logits_, labels_)
        loss_.backward()
        return loss_

    optimizer_lbfgs.step(closure_)

    return WOracleModel(state_classifier=x,
                        reg=1e-7,
                        train_data_loader=train_data_loader)
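

# WOracleModel presumably turns the trained state classifier into a density-ratio
# model via the standard classification trick: with (roughly) balanced samples of
# states from the pi_e and pi_b roll-outs, w(s) = d_pi_e(s) / d_pi_b(s) is
# approximated by the ratio of the two class probabilities. A hypothetical sketch
# of that conversion (the class-column order and the reg usage are assumptions):
def classifier_to_w_sketch(state_classifier, s, reg=1e-7):
    prob = state_classifier(s).softmax(-1)  # assumed columns: [P(pi_b | s), P(pi_e | s)]
    return prob[:, 1] / (prob[:, 0] + reg)

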
def debug_oracle():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_e_name = '456_-99501_cartpole_best'
    pi_other_name = '400_-99761_cartpole'
    load = False
    est_policy = True

    env = CartpoleEnvironment(reward_reshape=False)
    init_state_sampler = CartpoleInitStateSampler(env)
    pi_e = load_cartpole_policy("cartpole_weights_1/" + pi_e_name + ".pt",
                                temp, env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        "cartpole_weights_1/" + pi_other_name + ".pt", temp, env.state_dim,
        hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)

    if load:
        w = load_continuous_w_oracle(env, hidden_dim,
                                     './continuous_w_oracle.pt')
    else:
        w = calculate_continuous_w_oracle(
            env,
            pi_b,
            pi_e,
            gamma,
            hidden_dim,
            prefix=[
                pi_e_name + "_gamma" + str(gamma) + '_',
                pi_other_name + "_train_"
            ],
            epoch=300)
        # load=False, num_tau=20, tau_len=100000, burn_in=1000,

    if est_policy:
        from estimators.infinite_horizon_estimators import w_estimator, oracle_w_estimator
        from estimators.benchmark_estimators import on_policy_estimate
        tau_e_path = 'tau_e_cartpole/'
        tau_b_path = 'tau_b_cartpole/'
        tau_len = 200000
        burn_in = 100000
        test_data = TauListDataset.load(tau_b_path, pi_other_name + '_train_')
        test_data_pi_e = TauListDataset.load(tau_e_path,
                                             prefix=pi_e_name + "_gamma" +
                                             str(gamma) + '_')
        test_data_loader = test_data.get_data_loader(1024)
        policy_val_estimate = w_estimator(
            tau_list_data_loader=test_data_loader, pi_e=pi_e, pi_b=pi_b, w=w)
        # policy_val_estimate = oracle_w_estimator(
        #     tau_list_data_loader=test_data_loader, pi_e=pi_e,
        #     pi_b=pi_b, w_oracle=w)

        # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
        #                                        num_tau=1, tau_len=1000000,
        #                                        load_path=tau_e_path)
        policy_val_oracle = float(test_data_pi_e.r.mean().detach())
        squared_error = (policy_val_estimate - policy_val_oracle)**2
        print('W oracle estimate & true policy value ', policy_val_estimate,
              policy_val_oracle)
        print("Test policy val squared error:", squared_error)
def debug_RKHS_method():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_e_name = '456_-99501_cartpole_best'
    pi_other_name = '400_-99761_cartpole'
    load = True
    est_policy = False
    load_path = 'tau_b_cartpole/'
    prefix = pi_other_name + '_train_'

    env = CartpoleEnvironment(reward_reshape=False)
    init_state_sampler = CartpoleInitStateSampler(env)
    pi_e = load_cartpole_policy("cartpole_weights_1/" + pi_e_name + ".pt",
                                temp, env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        "cartpole_weights_1/" + pi_other_name + ".pt", temp, env.state_dim,
        hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)
    if load:
        # s_b = torch.load(open(os.path.join(load_path, prefix+'s.pt'), 'rb'))
        # a_b = torch.load(open(os.path.join(load_path, prefix+'a.pt'), 'rb'))
        # s_prime_b = torch.load(
        #     open(os.path.join(load_path, prefix+'s_prime.pt'), 'rb'))
        tau_list_b = TauListDataset.load(load_path, prefix)
    else:
        # tau_len and burn_in were not defined in this function; the values below
        # mirror the other debug functions in this listing.
        tau_len = 200000
        burn_in = 100000
        tau_list_b = env.generate_roll_out(pi=pi_b,
                                           num_tau=1,
                                           tau_len=tau_len,
                                           burn_in=burn_in)
        # TODO: figure out current file and path naming rule
        tau_list_b.save(load_path)
        # s_b = tau_list_b[:][0]
        # # a_b = tau_list_b[:][1]
        # s_prime_b = tau_list_b[:][2]

    tau_e_path = 'tau_e_cartpole/'
    test_data_pi_e = TauListDataset.load(tau_e_path,
                                         prefix=pi_e_name + "_gamma" +
                                         str(gamma) + '_')
    policy_val_oracle = float(test_data_pi_e.r.mean().detach())
    # logger = ContinuousWLogger(env, pi_e, pi_b, gamma,
    #                            False, False, None, policy_val_oracle)
    logger = SimplestWLogger(env, pi_e, pi_b, gamma, False, False, None,
                             policy_val_oracle)

    # if load:
    #     w = load_continuous_w_oracle(
    #         env, hidden_dim, './continuous_w_oracle.pt')
    # else:
    #     w = calculate_continuous_w_oracle(
    #         env, pi_b, pi_e, gamma, hidden_dim, prefix=pi_other_name+"_train_",
    #         load=True, epoch=10)
    w = RKHS_method(env,
                    pi_b,
                    pi_e,
                    tau_list_b,
                    gamma,
                    hidden_dim,
                    init_state_sampler,
                    gaussian_kernel,
                    logger=logger)
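
# gaussian_kernel comes from the repo; a standard RBF kernel it plausibly
# corresponds to looks like the hypothetical sketch below (the bandwidth is an
# arbitrary choice here).
def gaussian_kernel_sketch(x, y, bandwidth=1.0):
    # x: (n, d), y: (m, d) -> (n, m) matrix of exp(-||x_i - y_j||^2 / (2 * bandwidth^2))
    sq_dist = ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(-1)
    return torch.exp(-sq_dist / (2 * bandwidth ** 2))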