import collections
import datetime

import numpy as np
import torch

# hvenv, DQN, fresh_exp_buffer, batch_sample, calc_loss, eval, and the
# module-level constants step_num, SYC_NUM, SAVE_PATH, BUFFER_START_NUM and
# BUFFER_MAX_NUM are assumed to be defined elsewhere in the project.


# Earlier training-loop variant with an inline evaluation pass; the
# refactored main() further below is the intended entry point.
def main_v1():
    cuda0 = torch.device('cuda:0')
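    # Hyperparameters. epsilon starts at 0 (fully greedy) and is decayed
    # further by epsilon_cut at every target-network sync below.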
    v_num = 32
    # discount_gamma = 0.9
    epsilon = 0
    learning_rate = 0.0001
    epsilon_cut = 0.99
    epsilon_min = 0

    # Number of update rounds
    iteration_num = 2

    v_state_num = 2
    batch_size = 100
    start_time = datetime.datetime.now()

    hv_env_num = 100
    envs = hvenv('../data', v_state_num, v_num, hv_env_num, step_num,
                 iteration_num)

    hv_env_num_val= 1000
    envs_val = hvenv('../data', v_state_num, v_num, hv_env_num_val, step_num,
                     iteration_num)


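    # Online and target networks. The input width 2 * v_num**2 presumably
    # matches the flattened observation from hvenv, and the output enumerates
    # all 2**step_num binary action patterns for one step block.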
    rlnnet = DQN(2 * v_num**2, 2**step_num).to(cuda0)
    rln_tgt_net = DQN(2 * v_num**2, 2**step_num).to(cuda0)
    optimizer = torch.optim.Adam(params=rlnnet.parameters(), lr=learning_rate)
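    # Replay buffers for training and validation rollouts. No maxlen is set
    # here; trimming to BUFFER_MAX_NUM presumably happens inside
    # fresh_exp_buffer.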
    exp_buffer = collections.deque()
    exp_buffer_val = collections.deque()

    obs = envs.reset()
    obs_val = envs_val.reset()

    reward = envs.reward()
    reward_val = envs_val.reward()

    i = 0
    # Optionally resume from a previously saved checkpoint, e.g.:
    # rlnnet = torch.load('../weight/16step_multi_nobn_big_8192_large_ddqn_classical_init_-3learning_rate'+str(step_num)+'step_'+str(iteration_num)+'iteration_best_model.pth')


    while True:

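        # Every SYC_NUM updates: log the run configuration, decay epsilon,
        # and sync the target network with the online network.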
        if i % SYC_NUM == 0:
            print('----------------------------')
            print('start_time: ' + str(start_time))
            print('i/SYC_NUM: ' + str(i / SYC_NUM))
            print('syc epsilon: ' + str(epsilon))
            print('learning_rate: '+ str(learning_rate))
            print('v_state_num: '+str(v_state_num))
            print('v_num: '+str(v_num))
            print('hv_env_num: '+str(hv_env_num))
            print('batch_size: '+str(batch_size))
            print('step_num: '+str(step_num))
            print('iteration_num: '+str(iteration_num))
            print('BUFFER_START_NUM: '+str(BUFFER_START_NUM))
            print('BUFFER_length_NUM: '+str(len(exp_buffer)))
            print('----------------------------')


            epsilon *= epsilon_cut

            # Optionally checkpoint the current weights, e.g.:
            # torch.save(rlnnet,'../weight/16step_multi_nobn_big_8192_large_ddqn_classical_init_-3learning_rate'+str(step_num)+'step_'+str(iteration_num)+'iteration_best_model.pth')
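            # Standard DQN target sync: copy online weights into the target net.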
            rln_tgt_net.load_state_dict(rlnnet.state_dict())

        if epsilon < epsilon_min:
            epsilon = epsilon_min

        i += 1

        rlnnet.train()
        # Collect experience: roll out the current epsilon-greedy policy in
        # envs and grow the replay buffer to at least BUFFER_START_NUM
        # transitions (trimmed at BUFFER_MAX_NUM).
        exp_buffer, obs, reward = fresh_exp_buffer(exp_buffer, rlnnet, envs,
                                                   obs, epsilon, cuda0, reward,
                                                   BUFFER_START_NUM,
                                                   BUFFER_MAX_NUM)
        optimizer.zero_grad()
        exp_batch_index = np.random.choice(np.arange(len(exp_buffer)),
                                           size=batch_size,
                                           replace=False)
        batch = batch_sample(exp_batch_index, exp_buffer)

        loss_t = calc_loss(batch, rlnnet, rln_tgt_net, device=cuda0)
        loss_t.backward()
        optimizer.step()

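        # Periodic greedy check: every 100 updates, roll the policy with
        # epsilon = 0 on the validation environments.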
        if i % 100 == 0:
            rlnnet.eval()
            print('-------eval testing: epsilon = 0 -------')
            for _ in range(int(v_num * iteration_num / step_num)):
                exp_buffer_val, obs_val, reward_val = fresh_exp_buffer(
                    exp_buffer_val, rlnnet, envs_val, obs_val, 0, cuda0,
                    reward_val, 10000, 11000, True)
            print('length buffer: ' + str(len(exp_buffer_val)))
            print('-------eval end-------')
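

# calc_loss is referenced above but not shown in this listing. What follows is
# a minimal sketch of a double-DQN loss consistent with the call site
# calc_loss(batch, rlnnet, rln_tgt_net, device=cuda0). The batch layout
# (states, actions, rewards, next_states), the discount gamma (taken from the
# commented-out discount_gamma = 0.9 above), and the name calc_loss_sketch are
# assumptions, not the original implementation.
def calc_loss_sketch(batch, net, tgt_net, device, gamma=0.9):
    states, actions, rewards, next_states = batch
    states_v = torch.as_tensor(np.asarray(states), dtype=torch.float32,
                               device=device)
    next_states_v = torch.as_tensor(np.asarray(next_states),
                                    dtype=torch.float32, device=device)
    actions_v = torch.as_tensor(np.asarray(actions), dtype=torch.int64,
                                device=device)
    rewards_v = torch.as_tensor(np.asarray(rewards), dtype=torch.float32,
                                device=device)

    # Q(s, a) of the actions actually taken, from the online network.
    q_sa = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        # Double DQN: the online net selects the next action, the target net
        # evaluates it; no terminal mask is applied because the environments
        # above expose no done flag.
        next_actions = net(next_states_v).argmax(dim=1, keepdim=True)
        next_q = tgt_net(next_states_v).gather(1, next_actions).squeeze(-1)
        target = rewards_v + gamma * next_q
    return torch.nn.functional.mse_loss(q_sa, target)
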
def main():
    cuda0 = torch.device('cuda:0')
    v_num = 32
    # discount_gamma = 0.9
    epsilon = 0
    learning_rate = 0.0001
    epsilon_cut = 0.99
    epsilon_min = 0

    # Number of update rounds
    iteration_num = 1

    v_state_num = 2
    batch_size = 100
    start_time = datetime.datetime.now()

    hv_env_num = 50
    envs = hvenv('../data', v_state_num, v_num, hv_env_num, step_num,
                 iteration_num)

    hv_env_num_val = 100
    envs_val = hvenv('../data', v_state_num, v_num, hv_env_num_val, step_num,
                     iteration_num)

    rlnnet = DQN(2 * v_num**2, 2**step_num).to(cuda0)
    rln_tgt_net = DQN(2 * v_num**2, 2**step_num).to(cuda0)
    optimizer = torch.optim.Adam(params=rlnnet.parameters(), lr=learning_rate)
    exp_buffer = collections.deque()
    exp_buffer_val = collections.deque()

    obs = envs.reset()
    obs_val = envs_val.reset()

    reward = envs.reward()
    reward_val = envs_val.reward()

    i = 0

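    # Best validation reward seen so far; presumably updated (and used for
    # checkpointing) inside eval below.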
    reward_record = 0

    # Optionally resume from a saved checkpoint:
    # rlnnet = torch.load(SAVE_PATH)

    while True:

        if i % SYC_NUM == 0:
            print('----------------------------')
            print('start_time: ' + str(start_time))
            print('i/SYC_NUM: ' + str(i / SYC_NUM))
            print('syc epsilon: ' + str(epsilon))
            print('learning_rate: ' + str(learning_rate))
            print('v_state_num: ' + str(v_state_num))
            print('v_num: ' + str(v_num))
            print('hv_env_num: ' + str(hv_env_num))
            print('batch_size: ' + str(batch_size))
            print('step_num: ' + str(step_num))
            print('iteration_num: ' + str(iteration_num))
            print('BUFFER_START_NUM: ' + str(BUFFER_START_NUM))
            print('BUFFER_length_NUM: ' + str(len(exp_buffer)))
            print('----------------------------')

            epsilon *= epsilon_cut
            rln_tgt_net.load_state_dict(rlnnet.state_dict())

        if epsilon < epsilon_min:
            epsilon = epsilon_min

        i += 1

        rlnnet.train()
        # Collect experience with the current epsilon-greedy policy
        exp_buffer, obs, reward = fresh_exp_buffer(exp_buffer, rlnnet, envs,
                                                   obs, epsilon, cuda0, reward,
                                                   BUFFER_START_NUM,
                                                   BUFFER_MAX_NUM)
        optimizer.zero_grad()
        exp_batch_index = np.random.choice(np.arange(len(exp_buffer)),
                                           size=batch_size,
                                           replace=False)
        batch = batch_sample(exp_batch_index, exp_buffer)

        loss_t = calc_loss(batch, rlnnet, rln_tgt_net, device=cuda0)
        loss_t.backward()
        optimizer.step()

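        # Every 100 updates, run the refactored evaluation helper; it
        # presumably rolls out greedily on envs_val and updates reward_record.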
        if i % 100 == 0:
            obs_val, reward_val, reward_record = eval(rlnnet, envs_val,
                                                      obs_val, epsilon, cuda0,
                                                      reward_val,
                                                      reward_record)
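

# Run the refactored training loop when this file is executed as a script.
if __name__ == '__main__':
    main()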