Example #1

import gc
import time
from collections import deque

import numpy as np
import torch

# DQNAgent (the DQN / double-DQN / dueling-DQN agent wrapper) is assumed to be
# defined or imported elsewhere in the surrounding project.

def train_dqn(df,
              df_dense,
              df_wide,
              df_fail,
              state_dim,
              action_dim,
              memory_capacity,
              lr,
              betas,
              gamma,
              target_inter,
              epochs,
              model_path,
              is_double=False,
              is_dueling=False):
    dqn = DQNAgent(state_dim, action_dim, memory_capacity, lr, betas, gamma,
                   target_inter, is_double, is_dueling)
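    # Write all training progress to a plain-text log file stored next to the model checkpoints.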
    log_file = open(model_path + 'log_file', 'w+')

    print("backend:", dqn.device, file=log_file)
    print("is double DQN:", dqn.is_double, file=log_file)
    print("is dueling DQN:", dqn.is_dueling, file=log_file)

    cnt = 0
    loss_queue = deque()
    break_flag = False

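    # One full pass over the logged transition table per epoch.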
    for epoch in range(epochs):
        gc.collect()
        print("epoch start:", epoch, file=log_file)
        now = time.time()
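        # Each dataframe row is one logged step; the dense, wide and fail arrays hold its state features.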
        for index in range(df.shape[0] - 1):
            row = df.iloc[index]
            state_dense = df_dense[index]
            state_wide = df_wide[index]
            fail_state = df_fail[index]
            state = np.concatenate((state_dense, state_wide, fail_state))
            action = row['action']
            reward = row['reward']
            next_funds = row['next_funds_channel_id']
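            # A next_funds_channel_id of '-1' marks a terminal transition: zero next state, not_done = 0.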
            if next_funds == '-1':
                next_state = np.zeros(state_dim)
                not_done = 0
            else:
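                # Otherwise the next state is valid only if the following row continues the same user's funds channel.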
                next_row = df.iloc[index + 1]
                if next_row.uid == row.uid and next_row.funds_channel_id == next_funds:
                    next_state_dense = df_dense[index + 1]
                    next_state_wide = df_wide[index + 1]
                    next_fail_state = df_fail[index + 1]
                    next_state = np.concatenate(
                        (next_state_dense, next_state_wide, next_fail_state))
                    not_done = 1
                    cnt += 1
                else:
                    continue

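            # Store the transition in the replay buffer.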
            dqn.store_transition(state, action, reward, not_done, next_state)

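            # Start learning only after the replay buffer has been filled to capacity.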
            if dqn.memory_counter > memory_capacity:
                loss = dqn.learn()
                loss_queue.append(loss)
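                # Keep only the most recent 100,000 losses so the convergence check uses a sliding window.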
                while len(loss_queue) > 100000:
                    loss_queue.popleft()
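                # Every 10 learn steps (once epoch > 1), stop early if the windowed mean loss falls below 0.5.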
                if dqn.learn_step_counter % 10 == 0 and epoch > 1:
                    if len(loss_queue) >= 100000 and np.mean(loss_queue) < 0.5:
                        print('DQN has already converged, mean loss:',
                              np.mean(loss_queue),
                              file=log_file)
                        break_flag = True
                        break
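                # Log the running mean loss every 1,000 learn steps.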
                if dqn.learn_step_counter % 1000 == 0:
                    print('==============' + str(dqn.learn_step_counter) +
                          '-th step loss:',
                          np.mean(loss_queue),
                          file=log_file)
                    log_file.flush()
        print("time cost:", time.time() - now, file=log_file)
        if break_flag:
            break
        else:
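            # Checkpoint the evaluation network at the end of every epoch.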
            torch.save(
                dqn.eval_net.state_dict(), model_path +
                'dqn_model_20190901_20191204_add_round_add_fail_' +
                str(epoch) + '-th_epoch.pkl')

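    # Record how many rows had a matching next-state row, both as a count and as a fraction of all rows.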
    print(cnt, cnt / df.shape[0], file=log_file)
    log_file.flush()
    torch.save(
        dqn.eval_net.state_dict(),
        model_path + 'dqn_model_20190901_20191204_add_round_add_fail.pkl')
    log_file.close()
    return dqn.eval_net.cpu().eval()
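
A minimal usage sketch follows. It assumes df is the logged-transition dataframe (with uid, funds_channel_id, next_funds_channel_id, action and reward columns), that df_dense, df_wide and df_fail are 2-D NumPy arrays aligned row-for-row with df, and that every hyperparameter value shown is illustrative rather than taken from the original experiment.

# Illustrative call only; all numeric values and the model_path are placeholders.
state_dim = df_dense.shape[1] + df_wide.shape[1] + df_fail.shape[1]

eval_net = train_dqn(df, df_dense, df_wide, df_fail,
                     state_dim=state_dim,
                     action_dim=8,              # number of discrete actions (illustrative)
                     memory_capacity=100000,    # replay buffer size
                     lr=1e-4,
                     betas=(0.9, 0.999),        # Adam betas
                     gamma=0.99,                # discount factor
                     target_inter=1000,         # target-network sync interval
                     epochs=5,
                     model_path='./models/',
                     is_double=True,
                     is_dueling=True)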