def train_dqn(df, df_dense, df_wide, df_fail, state_dim, action_dim,
              memory_capacity, lr, betas, gamma, target_inter, epochs,
              model_path, is_double=False, is_dueling=False):
    """Train a DQN agent on pre-featurized transition rows and return the
    trained evaluation network (moved to CPU and set to eval mode).

    Parameters
    ----------
    df : pandas.DataFrame
        Transition log, consumed in row order.  Must provide the columns
        'action', 'reward', 'next_funds_channel_id' and the attributes
        ``uid`` / ``funds_channel_id`` used to check that two consecutive
        rows belong to the same episode.
    df_dense, df_wide, df_fail : array-like
        Per-row feature arrays indexed in lockstep with ``df``; a row's
        state vector is the concatenation of the three.
    state_dim, action_dim : int
        Network dimensions forwarded to ``DQNAgent``; ``state_dim`` also
        sizes the all-zero terminal next-state.
    memory_capacity : int
        Replay-buffer size; learning starts once the buffer has been
        overfilled (``memory_counter > memory_capacity``).
    lr, betas, gamma, target_inter :
        Optimizer / discount / target-sync hyperparameters forwarded
        verbatim to ``DQNAgent``.
    epochs : int
        Maximum number of passes over ``df``.
    model_path : str
        Prefix for the log file and checkpoints — presumably a directory
        path ending in '/'; TODO confirm the trailing separator.
    is_double, is_dueling : bool
        DQN-variant switches forwarded to ``DQNAgent``.

    Returns
    -------
    torch.nn.Module
        ``dqn.eval_net`` on CPU, in eval mode.
    """
    LOSS_WINDOW = 100000         # sliding window of recent losses
    CONVERGENCE_MEAN = 0.5       # mean windowed loss below this => converged

    dqn = DQNAgent(state_dim, action_dim, memory_capacity, lr, betas,
                   gamma, target_inter, is_double, is_dueling)

    cnt = 0                                  # transitions with a valid successor row
    # maxlen makes the deque self-trimming; replaces the manual popleft loop.
    loss_queue = deque(maxlen=LOSS_WINDOW)
    converged = False

    # Context manager guarantees the log file is closed even if training raises.
    with open(model_path + 'log_file', 'w+') as log_file:
        print("backend:", dqn.device, file=log_file)
        print("is double DQN:", dqn.is_double, file=log_file)
        print("is dueling DQN:", dqn.is_dueling, file=log_file)

        for epoch in range(epochs):
            gc.collect()  # large frames in play: reclaim garbage between epochs
            print("epoch start:", epoch, file=log_file)
            now = time.time()

            # Last row is skipped: building a transition needs row index + 1.
            for index in range(df.shape[0] - 1):
                row = df.iloc[index]
                state = np.concatenate(
                    (df_dense[index], df_wide[index], df_fail[index]))
                action = row['action']
                reward = row['reward']
                next_funds = row['next_funds_channel_id']

                if next_funds == '-1':
                    # '-1' marks a terminal transition — NOTE(review): the
                    # sentinel is compared as a string; confirm upstream dtype.
                    next_state = np.zeros(state_dim)
                    not_done = 0
                else:
                    next_row = df.iloc[index + 1]
                    # Accept the next row as the successor only if it belongs
                    # to the same user and matches the expected channel;
                    # otherwise the transition is broken — drop it.
                    if (next_row.uid != row.uid
                            or next_row.funds_channel_id != next_funds):
                        continue
                    next_state = np.concatenate(
                        (df_dense[index + 1], df_wide[index + 1],
                         df_fail[index + 1]))
                    not_done = 1
                    cnt += 1

                dqn.store_transition(state, action, reward, not_done, next_state)

                # Learn only after the replay buffer has been overfilled.
                if dqn.memory_counter > memory_capacity:
                    loss_queue.append(dqn.learn())

                    # Convergence check every 10 learn steps, after a warm-up
                    # of two full epochs: window full and windowed mean small.
                    if dqn.learn_step_counter % 10 == 0 and epoch > 1:
                        if (len(loss_queue) >= LOSS_WINDOW
                                and np.mean(loss_queue) < CONVERGENCE_MEAN):
                            print('dqn has already convergence',
                                  np.mean(loss_queue), file=log_file)
                            converged = True
                            break

                    # Periodic progress report; flush the log file too so the
                    # report is visible while the (long) run is in progress.
                    if dqn.learn_step_counter % 1000 == 0:
                        sys.stdout.flush()
                        print('==============' + str(dqn.learn_step_counter)
                              + '-th step loss:', np.mean(loss_queue),
                              file=log_file)
                        log_file.flush()

            print("time cost:", time.time() - now, file=log_file)
            if converged:
                break
            # Per-epoch checkpoint (skipped for the epoch that converged).
            torch.save(
                dqn.eval_net.state_dict(),
                model_path + 'dqn_model_20190901_20191204_add_round_add_fail_'
                + str(epoch) + '-th_epoch.pkl')

        # cnt / rows = fraction of rows that yielded a valid non-terminal
        # transition.
        print(cnt, cnt / df.shape[0], file=log_file)
        sys.stdout.flush()

    # Final snapshot after training ends (converged or epochs exhausted).
    torch.save(
        dqn.eval_net.state_dict(),
        model_path + 'dqn_model_20190901_20191204_add_round_add_fail.pkl')
    return dqn.eval_net.cpu().eval()