Code example #1
def debug():
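    """Sanity check for a continuous-state taxi policy: roll out a short
    trajectory and inspect the probabilities the policy assigns to the
    actions that were actually taken."""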
    env = TaxiEnvironment(discrete_state=False)
    pi = load_taxi_policy_continuous("taxi_data/saved_policies/pi19.npy", env)
    tau_list = env.generate_roll_out(pi=pi, num_tau=1, tau_len=10)
    s, a, s_prime, r = tau_list[0]
    print("policy probs")
    print(pi.policy_probs)
    print("")
    print("policy probs on batch")
    policy_probs = pi(s)
    print(policy_probs)
    a_probs = torch.gather(policy_probs, 1, a.view(-1, 1)).view(-1)
    print("probs of selected actions:", a_probs)
    print("mean prob:", a_probs.mean())
Code example #2
def debug():
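    """End-to-end check of the q pipeline: train q on behavior-policy
    rollouts, estimate the value of pi_e with q_estimator, and compare
    against an on-policy Monte Carlo estimate."""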
    # set up environment and policies
    env = TaxiEnvironment()
    gamma = 0.98
    alpha = 0.6
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_other = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_other, pi_1_weight=alpha)

    # set up logger
    init_state_dist_path = "taxi_data/init_state_dist.npy"
    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    logger = DiscreteQLogger(env=env,
                             pi_e=pi_e,
                             gamma=gamma,
                             tensorboard=True,
                             save_model=True,
                             init_state_sampler=init_state_sampler)
    # generate train and val data
    train_data = env.generate_roll_out(pi=pi_b,
                                       num_tau=1,
                                       tau_len=200000,
                                       burn_in=100000)
    val_data = env.generate_roll_out(pi=pi_b,
                                     num_tau=1,
                                     tau_len=200000,
                                     burn_in=100000)

    q = train_q_taxi(env, train_data, val_data, pi_e, pi_b, init_state_sampler,
                     logger, gamma)
    # calculate final performance
    policy_val_est = q_estimator(pi_e=pi_e,
                                 gamma=gamma,
                                 q=q,
                                 init_state_sampler=init_state_sampler)
    policy_val_oracle = on_policy_estimate(env=env,
                                           pi_e=pi_e,
                                           gamma=gamma,
                                           num_tau=1,
                                           tau_len=1000000)
    squared_error = (policy_val_est - policy_val_oracle)**2
    print("Test policy val squared error:", squared_error)
Code example #3
def debug():
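    """Train the state density-ratio model w with the ERM / LBFGS routine
    on behavior-policy rollouts from the discrete taxi environment."""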
    env = TaxiEnvironment()
    gamma = 0.98
    alpha = 0.6
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)

    # set up logger
    oracle_tau_len = 1000000
    logger = DiscreteWLogger(env=env,
                             pi_e=pi_e,
                             pi_b=pi_b,
                             gamma=gamma,
                             tensorboard=False,
                             save_model=False,
                             oracle_tau_len=oracle_tau_len)

    # generate train, val, and test data
    tau_len = 200000
    burn_in = 100000
    train_data = env.generate_roll_out(pi=pi_b,
                                       num_tau=1,
                                       tau_len=tau_len,
                                       burn_in=burn_in)
    val_data = env.generate_roll_out(pi=pi_b,
                                     num_tau=1,
                                     tau_len=tau_len,
                                     burn_in=burn_in)

    # define networks and optimizers
    w = StateEmbeddingModel(num_s=env.num_s, num_out=1)
    w_lr = 1e-3
    w_optimizer = torch.optim.Adam(w.parameters(), lr=w_lr, betas=(0.5, 0.9))

    train_w_network_erm_lbfgs(train_data=train_data,
                              pi_e=pi_e,
                              pi_b=pi_b,
                              w=w,
                              gamma=gamma,
                              val_data=val_data,
                              logger=logger)
Code example #4
def debug():
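    """Smoke test for q_game_objective: print a tiny (s, a, s', r) batch
    and the resulting q and f objective values."""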
    env = TaxiEnvironment()
    gamma = 0.98
    alpha = 0.6
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_other = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_other, pi_1_weight=alpha)
    q = StateEmbeddingModel(num_s=env.num_s, num_out=env.num_a)
    f = StateEmbeddingModel(num_s=env.num_s, num_out=env.num_a)

    # generate train and val data
    tau_list = env.generate_roll_out(pi=pi_b, num_tau=1, tau_len=4)
    s, a, s_prime, r = tau_list[0]
    print("s:", s)
    print("a:", a)
    print("s_prime:", s_prime)
    print("r:", r)
    print("")
    q_obj, f_obj = q_game_objective(q, f, s, a, s_prime, r, pi_e, gamma)
    print("q_obj:", q_obj)
    print("f_obj:", f_obj)
Code example #5
def load_eval():
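    """Load previously saved w and q checkpoints and compute the double
    (DRL) estimate of the value of pi_e on freshly generated
    behavior-policy data."""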
    env = TaxiEnvironment(discrete_state=True)
    gamma = 0.98
    alpha = 0.6
    # tau_lens = [50000, 100000, 200000, 400000]  # // 10000
    tau_len = 200000  # args.tau_len
    burn_in = 0  # 100000  # // 10000

    init_state_dist_path = "taxi_data/init_state_dist.npy"
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)

    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    policy_val_oracle = -0.74118131399
    print('policy_val_oracle', policy_val_oracle)

    print('args.seed', args.seed)
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    random.seed(args.seed)

    train_data = env.generate_roll_out(
        pi=pi_b, num_tau=1, tau_len=tau_len, burn_in=burn_in)
    train_data_loader = train_data.get_data_loader(1024)
    print('train_data generated')

    w_path = 'logs/2020-09-13T18:33:35.909858_w_2_200000/290_w_.pt'
    q_path = 'logs/2020-09-11T11:04:53.345691_q_2_200000/790_q_.pt'
    q_table = torch.zeros((env.num_s, env.num_a))
    q = QTableModel(q_table)
    q.model.load_state_dict(torch.load(q_path))
    w = StateEmbeddingModel(num_s=env.num_s, num_out=1)
    w.load_state_dict(torch.load(w_path))
    print('[ours] w_oracle and q loaded!')
    double_est = double_estimator(
        train_data_loader, pi_e, pi_b, w, q, gamma, init_state_sampler)
    print('[ours] drl_est', double_est)
Code example #6
def debug():
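    """For several seeds, optionally train q and/or w networks (controlled
    by args.q and args.w), compute the corresponding value estimates plus
    the double (DRL) estimate, and write the results under logs/."""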
    env = TaxiEnvironment(discrete_state=True)
    gamma = 0.98
    alpha = 0.6
    # tau_lens = [50000, 100000, 200000, 400000]  # // 10000
    tau_len = args.tau_len
    burn_in = 0  # 100000  # // 10000

    init_state_dist_path = "taxi_data/init_state_dist.npy"
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)

    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    policy_val_oracle = -0.74118131399
    print('policy_val_oracle', policy_val_oracle)

    # for j in tau_lens:
    # tau_len = j
    # preds = []
    # for i in range(100):
    #     print(j, i)
    for i in range(3):
        args.seed = i
        print('args.seed', args.seed)
        np.random.seed(args.seed)
        torch.random.manual_seed(args.seed)
        random.seed(args.seed)

        train_data = env.generate_roll_out(
            pi=pi_b, num_tau=1, tau_len=tau_len, burn_in=burn_in)
        train_data_loader = train_data.get_data_loader(1024)
        val_data = None
        print('train_data generated')

        now = datetime.datetime.now()
        if args.q:
            q_path = os.path.join('logs', '_'.join(
                [str(now.isoformat()), 'q', str(args.seed), str(args.tau_len)]))
            q_logger = SimplestQLogger(
                env, pi_e, gamma, init_state_sampler, True, True, q_path, policy_val_oracle)
            q = train_q_taxi(env, train_data, val_data, pi_e, pi_b,
                             init_state_sampler, q_logger, gamma, ERM_epoch=args.q_ERM_epoch, epoch=args.q_epoch)
            q_est = q_estimator(pi_e, gamma, q, init_state_sampler)
            # SASR = train_data.restore_strcture(discrete=False)
            # dataloader doesn't preserve trajectory structure, but SASR does.
            # drl1 = double_importance_sampling_estimator(SASR, pi_b, pi_e, args.gamma, q, split_shape=[4, 1, 4, 1])
            print('[ours] q_est', q_est)
            with open(os.path.join(q_path, 'results.txt'), 'w') as f:
                f.write(str(q_est))
            print('q_est written in', q_path)

        if args.w:
            w_path = os.path.join('logs', '_'.join(
                [str(now.isoformat()), 'w', str(args.seed), str(args.tau_len)]))
            w_logger = SimplestWLogger(
                env, pi_e, pi_b, gamma, True, True, w_path, policy_val_oracle)
            w = train_w_taxi(env, train_data, val_data, pi_e, pi_b,
                             init_state_sampler, w_logger, gamma, ERM_epoch=args.w_ERM_epoch, epoch=args.w_epoch)
            w_est = w_estimator(train_data_loader, pi_e, pi_b, w)
            print('[ours] w_est', w_est)
            with open(os.path.join(w_path, 'results.txt'), 'w') as f:
                f.write(str(w_est))
            print('w_est written in', w_path)

        if args.w and args.q:
            double_est = double_estimator(
                train_data_loader, pi_e, pi_b, w, q, gamma, init_state_sampler)
            print('[ours] drl_est', double_est)
            with open(os.path.join(q_path, 'results.txt'), 'a') as f:
                f.write(str(double_est))
            print('double_est written in', q_path)
Code example #7
def debug():
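    """Benchmark baseline estimators (importance sampling, stepwise IS,
    tabular q, tabular w, and the double estimator) over several
    trajectory lengths and 100 seeds, then save predictions and MSEs
    under estimators/."""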
    env = TaxiEnvironment(discrete_state=True)
    gamma = 0.98
    alpha = 0.6
    tau_lens = [50000, 100000, 200000, 400000]  # // 10000
    burn_in = 0  # 100000  # // 10000
    init_state_dist_path = "taxi_data/init_state_dist.npy"

    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)

    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    policy_val_oracle = -0.74118131399
    print('policy_val_oracle', policy_val_oracle)

    for j in tau_lens:
        tau_len = j
        preds = []
        for i in range(100):
            print(j, i)
            np.random.seed(i)
            torch.random.manual_seed(i)
            random.seed(i)

            train_data = env.generate_roll_out(pi=pi_b,
                                               num_tau=1,
                                               tau_len=tau_len,
                                               burn_in=burn_in)
            train_data_loader = train_data.get_data_loader(1024)
            SASR = train_data.restore_strcture(discrete=True)
            # print('finished data generation.')
            # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
            #                                        num_tau=1, tau_len=10000000)

            is_est = importance_sampling_estimator(SASR, pi_b, pi_e, gamma)
            is_ess = importance_sampling_estimator_stepwise(
                SASR, pi_b, pi_e, gamma)
            # print('is_est', is_est, 'is_ess', is_ess)

            # Masa's w method, only for discrete
            den_discrete = Density_Ratio_discounted(env.n_state, gamma)
            _, w_table = train_density_ratio(SASR, pi_b, pi_e, den_discrete,
                                             gamma)
            w_table = w_table.reshape(-1)
            w = StateEmbeddingModel(num_s=env.num_s, num_out=1)
            w.set_weights(torch.from_numpy(w_table))
            # print('[Masa] fitted w')

            # Masa's q method is the same as fit_q_tabular
            q = fit_q_tabular(train_data, pi_e, gamma)
            # print('[Masa] fitted q')

            q_est = q_estimator(pi_e, gamma, q, init_state_sampler)
            drl_est = double_estimator(train_data_loader, pi_e, pi_b, w, q,
                                       gamma, init_state_sampler)
            w_est = w_estimator(train_data_loader, pi_e, pi_b, w)
            # print('[Masa] q_est', q_est, '[Masa] w_est',
            #   w_est, '[Masa] drl_est', drl_est)

            preds.append([is_est, is_ess, q_est, drl_est, w_est])

        preds = np.array(preds)
        errors = (preds - policy_val_oracle)**2
        mse = np.mean(errors, axis=0)
        print('[is_est, is_ess, q_est, drl_est, w_est] \n', preds)
        print('MSE for [is_est, is_ess, q_est, drl_est, w_est] \n', mse)
        np.save('estimators/masa_preds_' + str(j), preds)
        np.save('estimators/masa_mse_' + str(j), mse)
Code example #8
def debug():
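    """End-to-end check of the w pipeline: train w on behavior-policy
    rollouts, estimate the value of pi_e with w_estimator on held-out
    test data, and compare against an on-policy Monte Carlo estimate."""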
    # set up environment and policies
    env = TaxiEnvironment(discrete_state=True)
    # env = CartpoleEnvironment()
    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    state_dim = 4
    pi_e = load_taxi_policy("taxi_data/saved_policies/pi19.npy")
    pi_s = load_taxi_policy("taxi_data/saved_policies/pi3.npy")
    pi_b = MixtureDiscretePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)
    # pi_e = load_taxi_policy_continuous(
    #     "taxi_data/saved_policies/pi19.npy", env)
    # pi_s = load_taxi_policy_continuous("taxi_data/saved_policies/pi3.npy", env)
    # pi_b = GenericMixturePolicy(pi_1=pi_e, pi_2=pi_s, pi_1_weight=alpha)

    # set up logger
    oracle_tau_len = 100000  # // 10000
    init_state_dist_path = "taxi_data/init_state_dist.npy"
    init_state_dist = load_tensor_from_npy(init_state_dist_path).view(-1)
    init_state_sampler = DiscreteInitStateSampler(init_state_dist)
    # init_state_sampler = DecodingDiscreteInitStateSampler(init_state_dist,
    #                                                       env.decode_state)
    # init_state_sampler = CartpoleInitStateSampler(env)
    # logger = SimpleDiscretePrintWLogger(env=env, pi_e=pi_e, pi_b=pi_b,
    #                                     gamma=gamma,
    #                                     oracle_tau_len=oracle_tau_len)
    logger = DiscreteWLogger(env=env,
                             pi_e=pi_e,
                             pi_b=pi_b,
                             gamma=gamma,
                             tensorboard=True,
                             save_model=True,
                             oracle_tau_len=oracle_tau_len)

    # generate train, val, and test data
    tau_len = 200000  # // 10000
    burn_in = 100000  # // 10000
    train_data = env.generate_roll_out(pi=pi_b,
                                       num_tau=1,
                                       tau_len=tau_len,
                                       burn_in=burn_in)
    print('finish train')
    val_data = env.generate_roll_out(pi=pi_b,
                                     num_tau=1,
                                     tau_len=tau_len,
                                     burn_in=burn_in)
    print('finish val')
    test_data = env.generate_roll_out(pi=pi_b,
                                      num_tau=1,
                                      tau_len=tau_len,
                                      burn_in=burn_in)
    print('finish test')

    w = train_w_taxi(env, train_data, val_data, pi_e, pi_b, init_state_sampler,
                     logger, gamma)
    # calculate final performance
    test_data_loader = test_data.get_data_loader(1024)
    policy_val_est = w_estimator(tau_list_data_loader=test_data_loader,
                                 pi_e=pi_e,
                                 pi_b=pi_b,
                                 w=w)
    policy_val_oracle = on_policy_estimate(env=env,
                                           pi_e=pi_e,
                                           gamma=gamma,
                                           num_tau=1,
                                           tau_len=1000000)
    squared_error = (policy_val_est - policy_val_oracle)**2
    print("Test policy val squared error:", squared_error)