def generate_datasets(folder_name,
                      epsilons,
                      data_size=int(1e5),
                      count=1,
                      threshold=1.2):
    dir_path = os.path.dirname(
        os.path.realpath(__file__)) + F"/../../data/{folder_name}/"
    onlyfiles = set(
        [f for f in listdir(dir_path) if isfile(join(dir_path, f))])
    models_to_eval = []
    for file in onlyfiles:
        if "VALUE" in file and float(file.split('_')[-1][:-4]) > threshold:
            models_to_eval.append(file)
    for epsilon in epsilons:
        for model_name in random.sample(models_to_eval, count):
            model_path = dir_path + model_name
            env = taxi(5)
            print(
                F"generate data with model {model_name} and epsilon{epsilon}")
            Q = np.load(model_path)
            start = time.time()
            data = generate_data(env, Q, data_size, epsilon)
            data_name = "DATA_{}_EPS_{:.4}.npy".format(np.random.randint(1e5),
                                                       epsilon)
            np.save(dir_path + data_name, np.array(data, dtype=int))
            print(F"saved to {data_name}, took {time.time() - start} seconds")
def group_roll_outs():
    length = 5
    env = taxi(length)
    n_state = env.n_state
    n_action = env.n_action
    num_trajectory = 200
    truncate_size = 200


    dir_path = os.path.dirname(os.path.realpath(__file__))
    rewards = np.load(dir_path + '/taxi-q/rewards.npy')
    for i in range(0, 30):
        print(i)
        agent = Q_learning(n_state, n_action, 0.005, 0.99)
        agent.Q = np.load(dir_path + '/taxi-q/q{}.npy'.format(i))
        SAS, f, avr_reward = roll_out(n_state, env, agent.get_pi(2.0),
                                      num_trajectory, truncate_size)
        rewards[i] = avr_reward
        np.save(dir_path + '/taxi-q/rewards.npy', rewards)
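

# `agent.get_pi(2.0)` above turns the learned Q table into a stochastic
# policy. The sketch below shows one common convention, a temperature-scaled
# softmax over Q values; this is an assumption, and the actual
# Q_learning.get_pi used in this project may differ.
def boltzmann_pi_sketch(Q, temperature):
    logits = Q / temperature
    logits -= logits.max(axis=1, keepdims=True)        # numerical stability
    pi = np.exp(logits)
    return pi / pi.sum(axis=1, keepdims=True)           # shape: (n_state, n_action)
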
def find_stable():
    length = 5
    env = taxi(length)
    n_state = env.n_state
    n_action = env.n_action
    pi_eval = np.load('taxi-policy/pi19.npy')
    gamma = 1.0

    for i in [200, 400, 600, 800, 1000]:
        for j in range(10):
            SASR_e, _, rrr = roll_out_old(n_state, env, pi_eval, i, 1000)
            roll_out_estimate = on_policy_old(SASR_e, gamma)
            print(i, "estimate using evaluation policy roll-out:",
                  roll_out_estimate)
            with open('find_stable.txt', 'a') as f:
                f.write(
                    str(i) + ' ' + str(j) + ' ' + str(roll_out_estimate) +
                    '\n')
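

# A minimal sketch of the `on_policy_old` estimator used above, assuming SASR
# is a list of trajectories of (state, action, next_state, reward) tuples and
# the estimate is the discount-weighted average reward. The actual
# implementation lives elsewhere in this project; this is only illustrative.
def on_policy_sketch(SASR, gamma):
    total, norm = 0.0, 0.0
    for trajectory in SASR:
        discount = 1.0
        for s, a, s_next, r in trajectory:
            total += discount * r
            norm += discount
            discount *= gamma
    return total / norm
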
def eval_directory(folder_name, episodes=1000, truncate_size=500,
                   gamma=0.99):  # discount factor; default value is assumed
    dir_path = os.path.dirname(
        os.path.realpath(__file__)) + F"/../../data/{folder_name}/"
    onlyfiles = set(
        [f for f in listdir(dir_path) if isfile(join(dir_path, f))])
    models_to_eval = []
    for file in onlyfiles:
        if "VALUE" not in file and "DATA" not in file and "ENV" not in file and ".npy" in file:
            models_to_eval.append(file)
    print(F"{len(models_to_eval)} models to be evaluated")
    for model_name in models_to_eval:
        model_path = dir_path + model_name
        env = taxi(5)
        print(F"evaluating model {model_name}")
        Q = np.load(model_path)
        start = time.time()
        value = roll_out(env, Q, gamma, episodes, truncate_size)
        new_name = "{}_VALUE_{:.5}.npy".format(model_name[:-4], value)
        new_path = dir_path + new_name
        os.rename(model_path, new_path)
        print(
            F"renamed {model_name} as {new_name}, took {time.time() - start} seconds"
        )
# Truncated tail of a separate combined-estimator example whose beginning is
# not part of this collection; kept as a comment for reference:
#         np.sum(Q_table * pi1, 1).reshape(-1) *
#         ddd) * (1 - gamma) + double_evaluation_density_ratio2(
#             SASR, estimate_behavior, pi1, w_mis, gamma, Q_table, pi1)
#
#     return (est_DENR2, est_model_based, est_model_double, est_model_double2,
#             est_model_based_mis_q, est_model_double_mis_q, est_DENR_mis_w,
#             est_model_double_mis_w, est_model_double_mis_q2,
#             est_model_double_mis_w2)


# if __name__ == '__main__':

estimator_name = [
    'On Policy', 'Density Ratio', 'Naive Average', 'IST', 'ISS', "DRL", 'WIST',
    'WISS', 'Model Based', 'double', "Model Base misq", "Double misq", "mis_q",
    "DR wmis", "Double misq"
]
length = 5
env = taxi(length)
n_state = env.n_state
n_action = env.n_action
"""
parser = argparse.ArgumentParser(description='taxi environment')
parser.add_argument('--nt', type = int, required = False, default = num_trajectory)
parser.add_argument('--ts', type = int, required = False, default = truncate_size)
parser.add_argument('--gm', type = float, required = False, default = gamma)
args = parser.parse_args()
"""
behavior_ID = 4
target_ID = 5

pi_behavior = np.load('taxi-policy/pi19.npy')
all_constants = np.ones([2000, 6]) * 1.0 / 6
for i in range(2000):
    pass  # loop body not included in this snippet


def main():
    global global_epoch
    parser = argparse.ArgumentParser(description='taxi environment')
    parser.add_argument('--nt', type = int, required = False, default = 200)
    parser.add_argument('--ts', type = int, required = False, default = 400)
    parser.add_argument('--gm', type = float, required = False, default = 1.0)
    parser.add_argument('--save_file', type = str, required = True)
    parser.add_argument('--constraint_alpha', type = float, required = False, default = 0.3)
    parser.add_argument('--batch-size', default=128, type=int)
    args = parser.parse_args()
    args.val_writer = SummaryWriter(os.path.join(args.save_file, 'val'))


    length = 5
    env = taxi(length)
    n_state = env.n_state
    n_action = env.n_action
    print_freq = 20

    alpha = 0.6
    pi_eval = np.load('taxi-policy/pi19.npy')
    pi_behavior = np.load('taxi-policy/pi3.npy')
    pi_behavior = alpha * pi_eval + (1-alpha) * pi_behavior
    
    random.seed(100)
    np.random.seed(100)
    torch.manual_seed(100)
    SASR_b_train_raw, _, _ = roll_out(n_state, env, pi_behavior, args.nt, args.ts)
    random.seed(200)
    np.random.seed(200)
    torch.manual_seed(200)
    SASR_b_val_raw, _, _ = roll_out(n_state, env, pi_behavior, args.nt, args.ts)
    random.seed(300)
    np.random.seed(300)
    torch.manual_seed(300)
    SASR_e, _, rrr = roll_out(n_state, env, pi_eval, args.nt, args.ts) #pi_e doesn't need loader
    #rrr (reward) is -0.1495
    
    SASR_b_train = SASR_Dataset(SASR_b_train_raw)
    SASR_b_val = SASR_Dataset(SASR_b_val_raw)  
    
    train_loader = torch.utils.data.DataLoader(
        SASR_b_train, batch_size=args.batch_size, shuffle=False)
    val_loader = torch.utils.data.DataLoader(
        SASR_b_val, batch_size=args.batch_size, shuffle=False)
    
    eta = get_eta(pi_eval, pi_behavior, n_state, n_action)
    gt_reward = on_policy(np.array(SASR_e), args.gm)
    print("estimate using evaluation policy roll-out (discounted):", gt_reward)
    # The values below supersede the ones above: eta is recomputed as the
    # per-(state, action) ratio pi_e / pi_b, and gt_reward as the undiscounted
    # mean reward of the evaluation roll-outs.
    eta = torch.FloatTensor(pi_eval / pi_behavior).cuda()
    r_e = torch.FloatTensor([r_ for _, _, _, r_ in SASR_e]).cuda()
    gt_reward = float(r_e.mean())
    print("estimate using evaluation policy roll-out (mean reward):", gt_reward)

    
    
    w = StateEmbedding()
    f = StateEmbedding()
    for param in w.parameters():
        param.requires_grad = True
    for param in f.parameters():
        param.requires_grad = True
    if torch.cuda.is_available():
        w = w.cuda()
        f = f.cuda()
        eta = eta.cuda()
    w_optimizer = OAdam(w.parameters(), lr=2e-6, betas=(0.5, 0.9)) 
    f_optimizer = OAdam(f.parameters(), lr=1e-6, betas=(0.5, 0.9)) 
    w_optimizer_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(w_optimizer, patience=20, factor=0.5)
    f_optimizer_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(f_optimizer, patience=20, factor=0.5)

    
    b_freq_distrib = np.load('pi3_distrib.npy')
    e_freq_distrib = np.load('pi19_distrib.npy')
    gt_w = torch.Tensor(e_freq_distrib / b_freq_distrib).cuda()  # ground-truth ratio d_pi_e / d_pi_b
    s = torch.linspace(0, 1999, steps=2000).cuda().long()
    
    
    s = torch.LongTensor([s_ for s_, _, _, _ in SASR_b_train]).cuda()
    a = torch.LongTensor([a_ for _, a_, _, _ in SASR_b_train]).cuda()
    r = torch.FloatTensor([r_ for _, _, _, r_ in SASR_b_train]).cuda()
    pi_b = torch.FloatTensor(pi_behavior).cuda()
    pi_e = torch.FloatTensor(pi_eval).cuda()
    gt_w_estimate = float((gt_w[s] * pi_e[s, a] / pi_b[s, a] * r).mean())
    print("HERE estimate using ground truth w:", gt_w_estimate)

    for epoch in range(5000):
#         print(epoch)
        global_epoch += 1
        dev_obj, dev_mse = validate(val_loader, w, f, args.gm, eta, gt_reward, args, gt_w)
        if epoch % print_freq == 0:
            print("epoch %d, dev objective = %f, dev mse = %f"
                      % (epoch, dev_obj, dev_mse))
            print(w(s).flatten()[:20], '\n', gt_w[:20], '\n mean', w(s).mean())
            
            torch.save({'w_model': w.state_dict(),
                       'f_model': f.state_dict()}, 
                       os.path.join(args.save_file, str(epoch)+'.pth'))
        train(train_loader, w, f, eta, args.constraint_alpha, w_optimizer, f_optimizer)  
        w_optimizer_scheduler.step(dev_mse)
        f_optimizer_scheduler.step(dev_mse)
def main():
    global global_epoch
    parser = argparse.ArgumentParser(description='taxi environment')
    parser.add_argument('--nt', type=int, required=False, default=1000)
    parser.add_argument('--ts', type=int, required=False, default=1000)
    parser.add_argument('--gm', type=float, required=False, default=0.98)
    parser.add_argument('--save_file', type=str, required=True)
    parser.add_argument('--batch-size', default=1024, type=int)
    args = parser.parse_args()
    args.val_writer = SummaryWriter(args.save_file)

    length = 5
    env = taxi(length)
    n_state = env.n_state
    n_action = env.n_action
    args.print_freq = 10
    gamma = args.gm

    alpha = 0.6
    pi_eval = np.load('taxi-policy/pi19.npy')
    # shape: (2000, 6)
    pi_behavior = np.load('taxi-policy/pi3.npy')
    pi_behavior = alpha * pi_eval + (1 - alpha) * pi_behavior

    random.seed(100)
    np.random.seed(100)
    torch.manual_seed(100)
    SASR_b_train_raw, _, _ = roll_out(n_state, env, pi_behavior, args.nt,
                                      args.ts)
    SASR_b_val_raw, _, _ = roll_out(n_state, env, pi_behavior, args.nt,
                                    args.ts)
    SASR_e, _, rrr = roll_out(n_state, env, pi_eval, args.nt,
                              args.ts)  # pi_e doesn't need loader
    # rrr (reward) is -0.1495

    SASR_b_train = SASR_Dataset(SASR_b_train_raw)
    SASR_b_val = SASR_Dataset(SASR_b_val_raw)

    train_loader = torch.utils.data.DataLoader(SASR_b_train,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(SASR_b_val,
                                             batch_size=args.batch_size,
                                             shuffle=True)

    #     eta = torch.FloatTensor(pi_eval / pi_behavior)

    q = StateEmbedding(embedding_dim=n_action)
    f = StateEmbeddingAdversary(embedding_dim=n_action)

    for param in q.parameters():
        param.requires_grad = True
    for param in f.parameters():
        param.requires_grad = True


#     if torch.cuda.is_available():
#         w = w.cuda()
#         f = f.cuda()
#         eta = eta.cuda()

    q_optimizer = OAdam(q.parameters(), lr=1e-5, betas=(0.5, 0.9))
    f_optimizer = OAdam(f.parameters(), lr=5e-5, betas=(0.5, 0.9))
    #     q_optimizer_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #         q_optimizer, patience=40, factor=0.5)
    #     f_optimizer_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #         f_optimizer, patience=40, factor=0.5)

    # SASR_b, b_freq, _ = roll_out(n_state, env, pi_behavior, 1, 10000000)
    # SASR_e, e_freq, _ = roll_out(n_state, env, pi_eval, 1, 10000000)
    # np.save('taxi7/taxi-policy/pi3_distrib.npy', b_freq/np.sum(b_freq))
    # np.save('taxi7/taxi-policy/pi19_distrib.npy', e_freq/np.sum(e_freq))

    #     b_freq_distrib = np.load('pi3_distrib.npy')
    #     e_freq_distrib = np.load('pi19_distrib.npy')

    #     gt_w = torch.Tensor(e_freq_distrib / b_freq_distrib)

    #     # estimate policy value using ground truth w
    #     s = torch.LongTensor([s_ for s_, _, _, _ in SASR_b_train])
    #     a = torch.LongTensor([a_ for _, a_, _, _ in SASR_b_train])
    #     r = torch.FloatTensor([r_ for _, _, _, r_ in SASR_b_train])

    #TODO: get roll_out_estimate and est_model_based very close; if stuck, try all_compare.py
    pi_b = torch.FloatTensor(pi_behavior)
    pi_e = torch.FloatTensor(pi_eval)

    #     gt_w_estimate = float((gt_w[s] * pi_e[s, a] / pi_b[s, a] * r).mean())
    #     print("estimate using ground truth w:", gt_w_estimate)

    # estimate policy value from evaluation policy roll out
    #     r_e = torch.FloatTensor([r_ for _, _, _, r_ in SASR_e])
    #     roll_out_estimate = float(r_e.mean())
    roll_out_estimate = -0.26128581  # precomputed value of on_policy(SASR_e, gamma)
    print("estimate using evaluation policy roll-out:", roll_out_estimate)

    nu0 = np.load("emp_hist.npy").reshape(-1)
    est_model_based, gt_q_table = model_based(n_state, n_action, SASR_e,
                                              pi_eval, gamma, nu0)
    print("Reward estimate using model based gt_q_table: ", est_model_based)
    #     old_est_model_based, old_gt_q_table = model_based_old(n_state, n_action, SASR_e, pi_eval, gamma)
    #     print("Reward estimate using model based old gt_q_table: ",old_est_model_based)
    #     import pdb;pdb.set_trace()

    for epoch in range(5000):
        # print(epoch)
        dev_obj, pred_reward_mse = validate(val_loader, q, f, roll_out_estimate, args, pi_e, gamma, \
                                    est_model_based, gt_q_table, nu0)
        # torch.save({'w_model': w.state_dict(),
        #             'f_model': f.state_dict()},
        #            os.path.join(args.save_file, str(epoch) + '.pth'))
        train(train_loader, q, f, q_optimizer, f_optimizer, pi_e, gamma)
        #         q_optimizer_scheduler.step(pred_reward_mse)
        #         f_optimizer_scheduler.step(pred_reward_mse)
        global_epoch += 1
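

# A minimal sketch of a tabular model-based estimator like the `model_based`
# call above (an assumption about its behaviour, not the project's actual
# implementation). Assuming SASR is a flat iterable of (s, a, s_next, r)
# transitions, it fits empirical transition and reward tables, evaluates Q for
# the policy pi by dynamic programming, then returns the initial-state value
# under nu0, normalized by (1 - gamma) so it is comparable to a per-step
# reward (this normalization presumes gamma < 1).
def model_based_sketch(n_state, n_action, SASR, pi, gamma, nu0, n_iter=2000):
    T = np.zeros((n_state, n_action, n_state))
    R_sum = np.zeros((n_state, n_action))
    counts = np.zeros((n_state, n_action))
    for s, a, s_next, r in SASR:
        s, a, s_next = int(s), int(a), int(s_next)
        T[s, a, s_next] += 1.0
        R_sum[s, a] += r
        counts[s, a] += 1.0
    visited = counts > 0
    T[visited] /= counts[visited][:, None]                 # empirical transitions
    R_hat = np.where(visited, R_sum / np.maximum(counts, 1.0), 0.0)
    # Policy evaluation: Q(s, a) = R_hat(s, a) + gamma * sum_s' T(s, a, s') V(s')
    Q = np.zeros((n_state, n_action))
    for _ in range(n_iter):
        V = (pi * Q).sum(axis=1)
        Q = R_hat + gamma * (T @ V)
    V = (pi * Q).sum(axis=1)
    return (1.0 - gamma) * float(nu0 @ V), Q
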
Example #8
def main():
    global global_epoch
    parser = argparse.ArgumentParser(description='taxi environment')
    parser.add_argument('--nt', type=int, required=False, default=1)
    parser.add_argument('--ts', type=int, required=False, default=100000)
    parser.add_argument('--gm', type=float, required=False, default=1.0)
    # parser.add_argument('--save_file', type=str, required=True)
    parser.add_argument('--batch-size', default=1024, type=int)
    args = parser.parse_args()
    # args.val_writer = SummaryWriter(os.path.join(args.save_file, 'val'))

    length = 5
    env = taxi(length)
    n_state = env.n_state
    n_action = env.n_action
    print_freq = 10

    alpha = 0.6
    pi_eval = np.load('./taxi-policy/pi19.npy')
    # shape: (2000, 6)
    pi_behavior = np.load('./taxi-policy/pi3.npy')
    pi_behavior = alpha * pi_eval + (1 - alpha) * pi_behavior

    random.seed(100)
    np.random.seed(100)
    torch.manual_seed(100)
    #Why delete 200 and 300?
    SASR_b_train_raw, _, _ = roll_out(n_state, env, pi_behavior, args.nt,
                                      args.ts)
    SASR_b_val_raw, b_freq, _ = roll_out(n_state, env, pi_behavior, args.nt,
                                         args.ts)
    SASR_e, e_freq, rrr = roll_out(n_state, env, pi_eval, args.nt,
                                   args.ts)  # pi_e doesn't need loader

    SASR_b_train = SASR_Dataset(SASR_b_train_raw)
    SASR_b_val = SASR_Dataset(SASR_b_val_raw)

    train_loader = torch.utils.data.DataLoader(SASR_b_train,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(SASR_b_val,
                                             batch_size=args.batch_size,
                                             shuffle=True)

    eta = torch.FloatTensor(pi_eval / pi_behavior)

    w = StateEmbedding()
    f = StateEmbeddingAdversary()

    for param in w.parameters():
        param.requires_grad = True
    for param in f.parameters():
        param.requires_grad = True


#     if torch.cuda.is_available():
#         w = w.cuda()
#         f = f.cuda()
#         eta = eta.cuda()
    w_optimizer = OAdam(w.parameters(), lr=1e-3, betas=(0.5, 0.9))
    f_optimizer = OAdam(f.parameters(), lr=5e-3, betas=(0.5, 0.9))
    # w_optimizer_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     w_optimizer, patience=20, factor=0.5)
    # f_optimizer_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     f_optimizer, patience=20, factor=0.5)

    # SASR_b, b_freq, _ = roll_out(n_state, env, pi_behavior, 1, 10000000)
    # SASR_e, e_freq, _ = roll_out(n_state, env, pi_eval, 1, 10000000)
    # np.save('taxi7/taxi-policy/pi3_distrib.npy', b_freq/np.sum(b_freq))
    # np.save('taxi7/taxi-policy/pi19_distrib.npy', e_freq/np.sum(e_freq))

    b_freq_distrib = np.load('./pi3_distrib.npy')
    e_freq_distrib = np.load('./pi19_distrib.npy')
    #     b_freq_distrib = b_freq/np.sum(b_freq)#np.load('pi3_distrib.npy')#
    #     e_freq_distrib = e_freq/np.sum(e_freq)

    gt_w = torch.Tensor(e_freq_distrib / (1e-5 + b_freq_distrib))

    # estimate policy value using ground truth w
    s = torch.LongTensor([s_ for s_, _, _, _ in SASR_b_train])
    a = torch.LongTensor([a_ for _, a_, _, _ in SASR_b_train])
    r = torch.FloatTensor([r_ for _, _, _, r_ in SASR_b_train])
    pi_b = torch.FloatTensor(pi_behavior)
    pi_e = torch.FloatTensor(pi_eval)
    gt_w_estimate = float((gt_w[s] * pi_e[s, a] / pi_b[s, a] * r).mean())
    print("estimate using ground truth w:", gt_w_estimate)

    # estimate policy value from evaluation policy roll out
    r_e = torch.FloatTensor([r_ for _, _, _, r_ in SASR_e])
    roll_out_estimate = float(r_e.mean())
    print("estimate using evaluation policy roll-out:", roll_out_estimate)

    s_all = torch.LongTensor(list(range(2000)))

    for epoch in range(5000):
        # print(epoch)
        global_epoch += 1
        dev_obj, dev_mse = validate(val_loader, w, f, args.gm, eta,
                                    roll_out_estimate, args, gt_w)
        if epoch % print_freq == 0:
            w_all = w(s_all).detach().flatten()
            w_rmse = float(((w_all - gt_w)**2).mean()**0.5)
            print("epoch %d, dev objective = %f, dev mse = %f, w rmse %f" %
                  (epoch, dev_obj, dev_mse, w_rmse))
            print("pred w:")
            print(w_all[:20])
            print("gt w:")
            print(gt_w[:20])
            print("mean w:", float(w(s).mean()))

            # torch.save({'w_model': w.state_dict(),
            #             'f_model': f.state_dict()},
            #            os.path.join(args.save_file, str(epoch) + '.pth'))
        train(train_loader, w, f, eta, w_optimizer, f_optimizer)
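

# A minimal sketch of a minimax training step consistent with the w / f
# adversary setup above. This is an assumption: the real `train` and the batch
# format of SASR_Dataset are defined elsewhere in this project, and batches
# are assumed here to be (s, a, s_next, r) tensors. It performs simultaneous
# gradient descent on w and ascent on f for the moment condition
# E_b[(eta(s, a) * w(s) - w(s')) * f(s')], with a soft penalty keeping the
# mean of w near 1.
def train_step_sketch(loader, w, f, eta, w_optimizer, f_optimizer):
    for s, a, s_next, _ in loader:
        s, a, s_next = s.long(), a.long(), s_next.long()
        w_s = w(s).flatten()
        w_next = w(s_next).flatten()
        f_next = f(s_next).flatten()
        obj = ((eta[s, a] * w_s - w_next) * f_next).mean()
        loss = obj + (w_s.mean() - 1.0) ** 2   # normalization penalty (affects w only)
        w_optimizer.zero_grad()
        f_optimizer.zero_grad()
        loss.backward()
        for p in f.parameters():               # flip gradients: the adversary ascends
            if p.grad is not None:
                p.grad.neg_()
        w_optimizer.step()
        f_optimizer.step()
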
import gym
import numpy as np
import os
from environment import taxi

env = taxi(5)
dir_path = os.path.dirname(os.path.realpath(__file__))
num_states = env.n_state
num_actions = env.n_action
max_iterations = 1e5
delta = 10**-6

transition_iteration = 1e5
transition_epsilon = 1e-3
T, R = env.get_T_R(transition_iteration, convergence=transition_epsilon)
np.save(dir_path + F'/../../data/taxi/ENV_T_epsilon_{transition_epsilon}.npy',
        T)
np.save(dir_path + F'/../../data/taxi/ENV_R_epsilon_{transition_epsilon}.npy',
        R)

T = np.load(dir_path +
            F'/../../data/taxi/ENV_T_epsilon_{transition_epsilon}.npy')
V = np.zeros([num_states])
Q = np.zeros([num_states, num_actions])
dones = []
gamma = 0.99

print("Taxi")
print("Actions: ", num_actions)
print("States: ", num_states)
Example #10
def main():
    global global_epoch
    parser = argparse.ArgumentParser(description='taxi environment')
    parser.add_argument('--nt', type=int, required=False, default=1)
    parser.add_argument('--ts', type=int, required=False, default=250000)
    parser.add_argument('--gm', type=float, required=False, default=1.0)
    parser.add_argument('--print_freq', type=int, default=20)
    parser.add_argument('--batch-size', default=512, type=int)
    parser.add_argument('--save_file', required=True, type=str)
    args = parser.parse_args()
    args.val_writer = SummaryWriter(args.save_file)

    length = 5
    env = taxi(length)
    n_state = env.n_state
    n_action = env.n_action

    alpha = 0.6
    pi_eval = np.load('taxi-policy/pi19.npy')
    # shape: (2000, 6)
    pi_behavior = np.load('taxi-policy/pi3.npy')
    pi_behavior = alpha * pi_eval + (1 - alpha) * pi_behavior

    SASR_b_train = SASR_Dataset('matrix_b/train')
    SASR_b_val = SASR_Dataset('matrix_b/val')
    SASR_e = SASR_Dataset('matrix_e/test')

    train_loader = torch.utils.data.DataLoader(SASR_b_train,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(SASR_b_val,
                                             batch_size=args.batch_size,
                                             shuffle=True)

    eta = torch.FloatTensor(pi_eval / pi_behavior)

    w = SimpleCNN()
    f = SimpleCNN()
    for param in w.parameters():
        param.requires_grad = True
    for param in f.parameters():
        param.requires_grad = True
    if torch.cuda.is_available():
        w = w.cuda()
        f = f.cuda()
        eta = eta.cuda()

    # TODO: lower lr. lr depends on model
    # Want the objective to decrease to 0 stably (non-chaotically); mean_mse is cheating
    w_optimizer = OAdam(w.parameters(), lr=1e-5, betas=(0.5, 0.9))
    f_optimizer = OAdam(f.parameters(), lr=5e-5, betas=(0.5, 0.9))
    w_optimizer_scheduler = torch.optim.lr_scheduler.StepLR(w_optimizer,
                                                            step_size=800,
                                                            gamma=0.1)
    f_optimizer_scheduler = torch.optim.lr_scheduler.StepLR(f_optimizer,
                                                            step_size=800,
                                                            gamma=0.1)

    # SASR_b, b_freq, _ = roll_out(n_state, env, pi_behavior, 1, 10000000)
    # SASR_e, e_freq, _ = roll_out(n_state, env, pi_eval, 1, 10000000)
    # np.save('taxi7/taxi-policy/pi3_distrib.npy', b_freq/np.sum(b_freq))
    # np.save('taxi7/taxi-policy/pi19_distrib.npy', e_freq/np.sum(e_freq))

    b_freq_distrib = np.load('pi3_distrib.npy')
    e_freq_distrib = np.load('pi19_distrib.npy')

    gt_w = torch.Tensor(e_freq_distrib / b_freq_distrib)

    # estimate policy value using ground truth w
    s = torch.LongTensor([s_ for s_, _, _, _ in SASR_b_train.get_SASR()])
    a = torch.LongTensor([a_ for _, a_, _, _ in SASR_b_train.get_SASR()])
    r = torch.FloatTensor([r_ for _, _, _, r_ in SASR_b_train.get_SASR()])
    pi_b = torch.FloatTensor(pi_behavior)
    pi_e = torch.FloatTensor(pi_eval)
    gt_w_estimate = float((gt_w[s] * pi_e[s, a] / pi_b[s, a] * r).mean())
    print("estimate using ground truth w:", gt_w_estimate)
    del pi_b
    del pi_e

    # estimate policy value from evaluation policy roll out
    r_e = torch.FloatTensor([r_ for _, _, _, r_ in SASR_e.get_SASR()])
    roll_out_estimate = float(r_e.mean())
    print("estimate using evaluation policy roll-out:", roll_out_estimate)
    del r_e

    for epoch in range(5000):
        print(epoch)
        dev_obj, dev_mse = validate(val_loader, w, f, args.gm, eta,
                                    roll_out_estimate, args, gt_w)
        #         if epoch % args.print_freq == 0:
        #             w_all = w(s_all).detach().flatten()
        #             w_rmse = float(((w_all - gt_w) ** 2).mean() ** 0.5)
        #             print("epoch %d, dev objective = %f, dev mse = %f, w rmse %f"
        #                   % (epoch, dev_obj, dev_mse, w_rmse))
        #             print("pred w:")
        #             print(w_all[:20])
        #             print("gt w:")
        #             print(gt_w[:20])
        #             print("mean w:", float(w(s).mean()))
        # torch.save({'w_model': w.state_dict(),
        #             'f_model': f.state_dict()},
        #            os.path.join(args.save_file, str(epoch) + '.pth'))
        train(train_loader, w, f, eta, w_optimizer, f_optimizer)
        w_optimizer_scheduler.step()
        f_optimizer_scheduler.step()
        #         import pdb;pdb.set_trace()
        args.val_writer.add_scalar('lr',
                                   w_optimizer_scheduler.get_last_lr()[0],
                                   global_epoch)
        global_epoch += 1