def debug():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy
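    # NOTE: load_continuous_w_oracle and calculate_continuous_w_oracle (used below)
    # are assumed to come from the original script's module-level imports.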

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_other_name = 'cartpole_180_-99900.0'
    load = False
    est_policy = True

    env = CartpoleEnvironment()
    pi_e = load_cartpole_policy("cartpole_weights/cartpole_best.pt", temp, env.state_dim,
                                hidden_dim, env.num_a)
    pi_other = load_cartpole_policy("cartpole_weights/"+pi_other_name+".pt", temp,
                                    env.state_dim, hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)
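    # pi_b: behavior policy formed by mixing pi_e with pi_other (mixing weight alpha).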

    if load:
        w = load_continuous_w_oracle(
            env, hidden_dim, './continuous_w_oracle.pt')
    else:
        w = calculate_continuous_w_oracle(
            env, pi_b, pi_e, gamma, hidden_dim, prefix=pi_other_name+"_train_",
            load=True, epoch=10)
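    # w estimates the stationary state-density ratio between pi_e and pi_b; the
    # w-estimator below uses it to re-weight rewards from the behavior-policy data.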

    if est_policy:
        from dataset.tau_list_dataset import TauListDataset
        from estimators.infinite_horizon_estimators import w_estimator, oracle_w_estimator
        from estimators.benchmark_estimators import on_policy_estimate
        tau_e_path = 'tau_e_cartpole/'
        tau_b_path = 'tau_b_cartpole/'
        tau_len = 200000
        burn_in = 100000
        test_data = TauListDataset.load(tau_b_path, pi_other_name+'_test_')
        test_data_pi_e = TauListDataset.load(tau_e_path, prefix="gamma_")
        test_data_loader = test_data.get_data_loader(1024)
        policy_val_estimate = w_estimator(tau_list_data_loader=test_data_loader,
                                          pi_e=pi_e, pi_b=pi_b, w=w)
        # policy_val_estimate = oracle_w_estimator(
        #     tau_list_data_loader=test_data_loader, pi_e=pi_e,
        #     pi_b=pi_b, w_oracle=w)

        # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
        #                                        num_tau=1, tau_len=1000000,
        #                                        load_path=tau_e_path)
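        # Oracle value: empirical mean of the rewards in the pre-generated pi_e
        # rollout, used as the ground truth for the squared error below.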
        policy_val_oracle = float(test_data_pi_e.r.mean().detach())
        squared_error = (policy_val_estimate - policy_val_oracle) ** 2
        print('W oracle estimate & true policy value ',
              policy_val_estimate, policy_val_oracle)
        print("Test policy val squared error:", squared_error)
def debug():
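    # NOTE: this snippet relies on the original script's module-level context:
    # torch, numpy (np), random, os, sys, datetime, the parsed `args`, and the
    # environment/policy/dataset/logger/training helpers used below.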
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    env = CartpoleEnvironment(reward_reshape=args.reward_reshape)
    now = datetime.datetime.now()
    timestamp = now.isoformat()
    log_path = '{}/{}/'.format(
        args.save_folder,
        '_'.join([timestamp] + [i.replace("--", "")
                                for i in sys.argv[1:]] + [args.script]))
    print(log_path)

    # set up environment and policies
    pi_e = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_e_name + ".pt"), args.temp,
        env.state_dim, args.hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_other_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, args.alpha)
    init_state_sampler = CartpoleInitStateSampler(env)
    now = datetime.datetime.now()
    if not args.rollout_dataset:
        print('Loading datasets for training...')
        train_data = TauListDataset.load(args.tau_b_path,
                                         prefix=args.pi_other_name + '_train_')
        val_data = TauListDataset.load(args.tau_b_path,
                                       prefix=args.pi_other_name + '_val_')
        pi_e_data_discounted = TauListDataset.load(
            args.tau_e_path,
            prefix=args.pi_e_name + '_gamma' + str(args.gamma) + '_')
        if args.oracle_val is None:
            args.oracle_val = float(pi_e_data_discounted.r.mean())

    else:
        # Generating train, val, and test data is very slow, so loading existing data is preferred.
        # pi_b data is used for training, so no discounting (gamma) is applied.
        print('Not loading pi_b data, so generating')
        train_data = env.generate_roll_out(pi=pi_b,
                                           num_tau=args.num_tau,
                                           tau_len=args.tau_len,
                                           burn_in=args.burn_in)
        print('Finished generating train data of', args.pi_other_name)
        val_data = env.generate_roll_out(pi=pi_b,
                                         num_tau=args.num_tau,
                                         tau_len=args.tau_len,
                                         burn_in=args.burn_in)
        print('Finished generating val data of', args.pi_other_name)
        train_data.save(args.tau_b_path, prefix=args.pi_other_name + '_train_')
        val_data.save(args.tau_b_path, prefix=args.pi_other_name + '_val_')

        # pi_e data is only used to compute the oracle policy value, so it is generated with gamma.
        print('Not loading pi_e data, so generating')
        pi_e_data_discounted = env.generate_roll_out(
            pi=pi_e,
            num_tau=args.num_tau,
            tau_len=args.oracle_tau_len,
            burn_in=args.burn_in,
            gamma=args.gamma)
        print('Finished generating data of pi_e with gamma')
        pi_e_data_discounted.save(args.tau_e_path,
                                  prefix=args.pi_e_name + '_gamma' +
                                  str(args.gamma) + '_')
        args.oracle_val = float(pi_e_data_discounted.r.mean())

    print('Oracle policy val', args.oracle_val)
    q_oracle = QOracleModel.load_continuous_q_oracle(env, args.hidden_dim,
                                                     env.num_a,
                                                     args.oracle_path)
    logger = ContinuousQLogger(env=env,
                               pi_e=pi_e,
                               gamma=args.gamma,
                               tensorboard=args.no_tensorboard,
                               save_model=args.no_save_model,
                               init_state_sampler=init_state_sampler,
                               log_path=log_path,
                               policy_val_oracle=args.oracle_val,
                               q_oracle=q_oracle,
                               pi_e_data_discounted=pi_e_data_discounted)
    #     logger = SimplestQLogger(
    #         env, pi_e, args.gamma, init_state_sampler, True, True, log_path, args.oracle_val)
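
    # Minimal safeguard so the meta.txt write below cannot fail on a missing
    # directory (assumption: the original script may already create log_path,
    # e.g. inside the logger; exist_ok makes this a no-op in that case).
    os.makedirs(log_path, exist_ok=True)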

    with open(os.path.join(log_path, 'meta.txt'), 'w') as f:
        print(args, file=f)

    q = train_q_network_cartpole(args, env, train_data, val_data, pi_e, pi_b,
                                 init_state_sampler, logger)

    # calculate final performance, optional
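    # (q_estimator presumably evaluates pi_e by averaging the learned Q over initial
    #  states from init_state_sampler, weighting actions by pi_e's probabilities.)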
    policy_val_est = q_estimator(pi_e=pi_e,
                                 gamma=args.gamma,
                                 q=q,
                                 init_state_sampler=init_state_sampler)
    squared_error = (policy_val_est - logger.policy_val_oracle)**2
    print('Policy_val_oracle', logger.policy_val_oracle)
    print("Test policy val squared error:", squared_error)
def debug_oracle():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy
    from dataset.tau_list_dataset import TauListDataset
    # NOTE: CartpoleInitStateSampler, load_continuous_w_oracle, and
    # calculate_continuous_w_oracle are assumed to come from module-level imports.

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_e_name = '456_-99501_cartpole_best'
    pi_other_name = '400_-99761_cartpole'
    load = False
    est_policy = True

    env = CartpoleEnvironment(reward_reshape=False)
    init_state_sampler = CartpoleInitStateSampler(env)
    pi_e = load_cartpole_policy("cartpole_weights_1/" + pi_e_name + ".pt",
                                temp, env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        "cartpole_weights_1/" + pi_other_name + ".pt", temp, env.state_dim,
        hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)

    if load:
        w = load_continuous_w_oracle(env, hidden_dim,
                                     './continuous_w_oracle.pt')
    else:
        w = calculate_continuous_w_oracle(
            env,
            pi_b,
            pi_e,
            gamma,
            hidden_dim,
            prefix=[
                pi_e_name + "_gamma" + str(gamma) + '_',
                pi_other_name + "_train_"
            ],
            epoch=300)
        # load=False, num_tau=20, tau_len=100000, burn_in=1000,

    if est_policy:
        from estimators.infinite_horizon_estimators import w_estimator, oracle_w_estimator
        from estimators.benchmark_estimators import on_policy_estimate
        tau_e_path = 'tau_e_cartpole/'
        tau_b_path = 'tau_b_cartpole/'
        tau_len = 200000
        burn_in = 100000
        test_data = TauListDataset.load(tau_b_path, pi_other_name + '_train_')
        test_data_pi_e = TauListDataset.load(tau_e_path,
                                             prefix=pi_e_name + "_gamma" +
                                             str(gamma) + '_')
        test_data_loader = test_data.get_data_loader(1024)
        policy_val_estimate = w_estimator(
            tau_list_data_loader=test_data_loader, pi_e=pi_e, pi_b=pi_b, w=w)
        # policy_val_estimate = oracle_w_estimator(
        #     tau_list_data_loader=test_data_loader, pi_e=pi_e,
        #     pi_b=pi_b, w_oracle=w)

        # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma,
        #                                        num_tau=1, tau_len=1000000,
        #                                        load_path=tau_e_path)
        policy_val_oracle = float(test_data_pi_e.r.mean().detach())
        squared_error = (policy_val_estimate - policy_val_oracle)**2
        print('W oracle estimate & true policy value ', policy_val_estimate,
              policy_val_oracle)
        print("Test policy val squared error:", squared_error)
def debug_RKHS_method():
    from environments.cartpole_environment import CartpoleEnvironment
    from policies.mixture_policies import GenericMixturePolicy
    from policies.cartpole_policies import load_cartpole_policy
    from dataset.tau_list_dataset import TauListDataset
    # NOTE: CartpoleInitStateSampler, SimplestWLogger, RKHS_method, and
    # gaussian_kernel are assumed to come from module-level imports.

    gamma = 0.98
    alpha = 0.6
    temp = 2.0
    hidden_dim = 50
    pi_e_name = '456_-99501_cartpole_best'
    pi_other_name = '400_-99761_cartpole'
    load = True
    est_policy = False
    load_path = 'tau_b_cartpole/'
    prefix = pi_other_name + '_train_'

    env = CartpoleEnvironment(reward_reshape=False)
    init_state_sampler = CartpoleInitStateSampler(env)
    pi_e = load_cartpole_policy("cartpole_weights_1/" + pi_e_name + ".pt",
                                temp, env.state_dim, hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        "cartpole_weights_1/" + pi_other_name + ".pt", temp, env.state_dim,
        hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, alpha)
    if load:
        # s_b = torch.load(open(os.path.join(load_path, prefix+'s.pt'), 'rb'))
        # a_b = torch.load(open(os.path.join(load_path, prefix+'a.pt'), 'rb'))
        # s_prime_b = torch.load(
        #     open(os.path.join(load_path, prefix+'s_prime.pt'), 'rb'))
        tau_list_b = TauListDataset.load(load_path, prefix)
    else:
        # tau_len and burn_in are assumed to be defined elsewhere; this branch is
        # not taken here since load=True above.
        tau_list_b = env.generate_roll_out(pi=pi_b,
                                           num_tau=1,
                                           tau_len=tau_len,
                                           burn_in=burn_in)
        # TODO: figure out current file and path naming rule
        tau_list_b.save(load_path)
        # s_b = tau_list_b[:][0]
        # # a_b = tau_list_b[:][1]
        # s_prime_b = tau_list_b[:][2]

    tau_e_path = 'tau_e_cartpole/'
    test_data_pi_e = TauListDataset.load(tau_e_path,
                                         prefix=pi_e_name + "_gamma" +
                                         str(gamma) + '_')
    policy_val_oracle = float(test_data_pi_e.r.mean().detach())
    # logger = ContinuousWLogger(env, pi_e, pi_b, gamma,
    #                            False, False, None, policy_val_oracle)
    logger = SimplestWLogger(env, pi_e, pi_b, gamma, False, False, None,
                             policy_val_oracle)

    # if load:
    #     w = load_continuous_w_oracle(
    #         env, hidden_dim, './continuous_w_oracle.pt')
    # else:
    #     w = calculate_continuous_w_oracle(
    #         env, pi_b, pi_e, gamma, hidden_dim, prefix=pi_other_name+"_train_",
    #         load=True, epoch=10)
    w = RKHS_method(env,
                    pi_b,
                    pi_e,
                    tau_list_b,
                    gamma,
                    hidden_dim,
                    init_state_sampler,
                    gaussian_kernel,
                    logger=logger)
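    # w is the RKHS/kernel-based estimate of the stationary density ratio; it can be
    # plugged into w_estimator in the same way as the oracle w in the snippets above.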
def debug():
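    # NOTE: as with the previous snippets, module-level context is assumed here:
    # the parsed `args`, numpy/torch/random/datetime/sys, the estimator and training
    # helpers used below, and the loop bound `end` for the seed loop.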
    env = CartpoleEnvironment(reward_reshape=args.reward_reshape)
    # [50000, 100000, 200000, 400000] // 10000
    tau_lens = [400000]
    burn_in = 100000  # // 10000

    pi_e = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_e_name + ".pt"), args.temp,
        env.state_dim, args.hidden_dim, env.num_a)
    pi_other = load_cartpole_policy(
        os.path.join(args.cartpole_weights, args.pi_other_name + ".pt"),
        args.temp, env.state_dim, args.hidden_dim, env.num_a)
    pi_b = GenericMixturePolicy(pi_e, pi_other, args.alpha)
    init_state_sampler = CartpoleInitStateSampler(env)
    print('policy_val_oracle', args.oracle_val)

    # combinations=['is_est', 'is_ess', 'q_ERM', 'w_RKHS', 'drl_q_ERM_w_RKHS', 'q_GMM', 'w_GMM', 'drl_q_GMM_w_RKHS', 'drl_q_ERM_w_GMM', 'drl_q_GMM_w_GMM']

    for j in tau_lens:
        tau_len = j
        preds = []
        for i in range(10, end):
            print(j, i)
            np.random.seed(i)
            torch.random.manual_seed(i)
            random.seed(i)

            train_data = env.generate_roll_out(pi=pi_b,
                                               num_tau=1,
                                               tau_len=tau_len,
                                               burn_in=args.burn_in)
            # train_data = TauListDataset.load(
            #     args.tau_b_path, prefix=args.pi_other_name+'_train_')
            train_data_loader = train_data.get_data_loader(args.batch_size)
            SASR = train_data.restore_strcture(discrete=False)
            print('finished data generation.')
            # policy_val_oracle = on_policy_estimate(env=env, pi_e=pi_e, gamma=gamma, num_tau=1, tau_len=10000000)

            is_est = importance_sampling_estimator(SASR,
                                                   pi_b,
                                                   pi_e,
                                                   args.gamma,
                                                   split_shape=[4, 1, 4, 1])
            is_ess = importance_sampling_estimator_stepwise(
                SASR, pi_b, pi_e, args.gamma, split_shape=[4, 1, 4, 1])
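            # is_ess is the stepwise variant of the plain importance-sampling estimate
            # is_est; split_shape=[4, 1, 4, 1] presumably matches cartpole's
            # (state, action, next_state, reward) layout restored by restore_strcture.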
            print('is_est', is_est, 'is_ess', is_ess)

            # RKHS for w
            now = datetime.datetime.now()
            timestamp = now.isoformat()
            log_path = '{}/{}/'.format(
                args.save_folder,
                '_'.join([timestamp] +
                         [i.replace("--", "") for i in sys.argv[1:]] +
                         [args.script + '_W_' + str(j) + str(i)]))
            w_logger = SimplestWLogger(env, pi_e, pi_b, args.gamma, True, True,
                                       log_path, args.oracle_val)
            # w_logger = None

            w_rkhs = RKHS_method(
                env,
                pi_b,
                pi_e,
                train_data,
                args.gamma,
                args.hidden_dim,
                init_state_sampler,
                gaussian_kernel,
                logger=w_logger
            )  # , batch_size=args.batch_size, epoch=args.w_RKHS_epoch, lr=args.w_rkhs_lr)

            w_rkhs_est = w_estimator(train_data_loader, pi_e, pi_b, w_rkhs)
            print('w_rkhs_est', w_rkhs_est)

            # ERM
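            # (direct empirical-risk-minimization baseline for the Q-network, as
            #  opposed to the GMM-based training used further below; naming follows
            #  the surrounding code)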
            now = datetime.datetime.now()
            timestamp = now.isoformat()
            log_path = '{}/{}/'.format(
                args.save_folder,
                '_'.join([timestamp] +
                         [i.replace("--", "") for i in sys.argv[1:]] +
                         [args.script + '_Q_' + str(j) + str(i)]))
            q_logger = SimplestQLogger(env, pi_e, args.gamma,
                                       init_state_sampler, True, True,
                                       log_path, args.oracle_val)
            # q_logger = None

            q_erm = QNetworkModelSimple(env.state_dim,
                                        args.hidden_dim,
                                        out_dim=env.num_a,
                                        neg_output=True)
            q_erm_optimizer = OAdam(q_erm.parameters(),
                                    lr=args.q_lr,
                                    betas=(0.5, 0.9))
            train_q_network_erm(train_data,
                                pi_e,
                                args.q_GMM_epoch + args.q_ERM_epoch,
                                args.batch_size,
                                q_erm,
                                q_erm_optimizer,
                                args.gamma,
                                val_data=None,
                                val_freq=10,
                                q_scheduler=None,
                                logger=q_logger)
            q_erm_est = q_estimator(pi_e, args.gamma, q_erm,
                                    init_state_sampler)
            print('q_erm_est', q_erm_est)

            # q GMM
            q_GMM = train_q_network_cartpole(args, env, train_data, None, pi_e,
                                             pi_b, init_state_sampler,
                                             q_logger)
            q_GMM_est = q_estimator(pi_e, args.gamma, q_GMM,
                                    init_state_sampler)
            print('q_GMM_est', q_GMM_est)

            # w GMM
            w_GMM = train_w_network_cartpole(args, env, train_data, None, pi_e,
                                             pi_b, init_state_sampler,
                                             w_logger)
            w_GMM_est = w_estimator(train_data_loader, pi_e, pi_b, w_GMM)
            print('w_GMM_est', w_GMM_est)
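
            # double_estimator presumably combines a q-function with a w (density
            # ratio) estimate into a doubly-robust value estimate, one per q/w pairing.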

            drl_q_ERM_w_RKHS = double_estimator(train_data_loader, pi_e, pi_b,
                                                w_rkhs, q_erm, args.gamma,
                                                init_state_sampler)
            drl_q_ERM_w_GMM = double_estimator(train_data_loader, pi_e, pi_b,
                                               w_GMM, q_erm, args.gamma,
                                               init_state_sampler)
            drl_q_GMM_w_RKHS = double_estimator(train_data_loader, pi_e, pi_b,
                                                w_rkhs, q_GMM, args.gamma,
                                                init_state_sampler)
            drl_q_GMM_w_GMM = double_estimator(train_data_loader, pi_e, pi_b,
                                               w_GMM, q_GMM, args.gamma,
                                               init_state_sampler)
            print('drl_q_ERM_w_RKHS', drl_q_ERM_w_RKHS, 'drl_q_ERM_w_GMM',
                  drl_q_ERM_w_GMM, 'drl_q_GMM_w_RKHS', drl_q_GMM_w_RKHS,
                  'drl_q_GMM_w_GMM', drl_q_GMM_w_GMM)

            preds.append([
                is_est, is_ess, w_rkhs_est, q_erm_est, q_GMM_est, w_GMM_est,
                drl_q_ERM_w_RKHS, drl_q_ERM_w_GMM, drl_q_GMM_w_RKHS,
                drl_q_GMM_w_GMM
            ])

        print(
            '[is_est, is_ess, w_rkhs_est, q_erm_est, q_GMM_est, w_GMM_est, drl_q_ERM_w_RKHS, drl_q_ERM_w_GMM, drl_q_GMM_w_RKHS, drl_q_GMM_w_GMM] \n',
            preds)
        preds = np.array(preds)
        errors = (preds - args.oracle_val)**2
        mse = np.mean(errors, axis=0)
        print(
            'MSE for [is_est, is_ess, w_rkhs_est, q_erm_est, q_GMM_est, w_GMM_est, drl_q_ERM_w_RKHS, drl_q_ERM_w_GMM, drl_q_GMM_w_RKHS, drl_q_GMM_w_GMM] \n',
            mse)
        np.save('estimators/benchmarks_GMM_cont_' + str(j) + str(end), preds)
        # distinct filename so the MSE array does not overwrite the per-seed predictions
        np.save('estimators/benchmarks_GMM_cont_' + str(j) + str(end) + '_mse', mse)