Example #1
 def local_train(replay_memory, valid_angles, valid_taptimes,
                 angle_estimator, taptime_estimator, angle_target_estimator,
                 taptime_target_estimator, sess, batch_size,
                 discount_factor, total_t, update_target_estimator_every,
                 saver, checkpoint_path, save):
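     """Worker-thread training loop (runs until the process is stopped).

     Each iteration trains the angle/taptime estimators on batches drawn from the
     shared replay_memory via dqn_utils.pretrain_parNN. Every
     update_target_estimator_every steps (when save is True) it writes a checkpoint
     and copies the online estimators into the target estimators, and every 50 such
     intervals it prints the current losses.
     """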
     while True:
         angle_loss, taptime_loss = dqn_utils.pretrain_parNN(
             replay_memory,
             valid_angles,
             valid_taptimes,
             angle_estimator,
             taptime_estimator,
             angle_target_estimator,
             taptime_target_estimator,
             sess,
             batch_size,
             discount_factor,
             angle_feed=True)
         total_t += 1  # each time this runs, keep incrementing from the value restored from the checkpoint
         # pdb.set_trace()
         # total_t = sess.run(tf.train.get_global_step())  # computing total_t this way should keep it in sync with the main thread, but it keeps throwing errors
         if total_t % update_target_estimator_every == 0 and save:
             saver.save(sess, checkpoint_path)
             dqn_utils.copy_model_parameters(sess, angle_estimator,
                                             angle_target_estimator)
             dqn_utils.copy_model_parameters(sess, taptime_estimator,
                                             taptime_target_estimator)
             # print("local_train || total_t:", total_t, "angle_loss:", angle_loss, "taptime_loss:", taptime_loss)
         if total_t % (update_target_estimator_every * 50) == 0:
             print("local_train || total_t:", total_t, "angle_loss:",
                   angle_loss, "taptime_loss:", taptime_loss)
             sys.stdout.flush()
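The commented-out line in the loop above suggests deriving total_t from the graph's global step rather than keeping a per-thread counter. Below is a minimal sketch of that idea in the same TF 1.x style as the snippet; it assumes a global_step variable exists (created e.g. with tf.train.create_global_step()) and is incremented by the main thread's optimizer, and read_total_t is a hypothetical helper, not part of the repository.

import tensorflow as tf

def read_total_t(sess):
    # Read the shared step counter that the optimizer increments, so every
    # worker thread sees the same value as the main thread.
    global_step_tensor = tf.train.get_global_step(sess.graph)
    if global_step_tensor is None:
        # One plausible cause of the error mentioned in the comment above:
        # no global_step variable has been created in this graph.
        raise RuntimeError("call tf.train.create_global_step() before training")
    return sess.run(global_step_tensor)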
Example #2
                  update_target_estimator_every, saver, checkpoint_path))
        threads.append(t)
        t.start()

    # If using a network pre-trained on the replay memory, load it here
    ## --> In reinforcement learning this is not the right approach: building a
    ## replay memory in advance is fine, but training has to happen while the game is being played...
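    ## The usual DQN pattern interleaves the two: act in the game with the current
    ## policy, append each transition to replay_memory, then sample a batch and run
    ## one training step, syncing the target network every
    ## update_target_estimator_every steps.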

    ####################################################################################
    print('Start Learning!')  ### Play the game, train, and update the policy ##########
    ####################################################################################

    i_episode = 0  # total number of episodes
    i_episodes = [0] * 21  # per-level episode counts

    # Initialize the target network from the online network's parameters before training starts
    dqn_utils.copy_model_parameters(sess, estimator, target_estimator)

    while True:

        # Query the current game state (via the comm module) and clear old screenshots
        game_state = comm.comm_get_state(s, silent=False)
        dqn_utils.clear_screenshot(SCR_PATH + "/")

        if game_state == 'UNKNOWN':
            print("########################################################")
            print("Unknown state")
            pass

        elif game_state == 'MAIN_MENU':
            print("########################################################")
            print("Main menu state")
            pass