コード例 #1
0
def main_loop(handle, possible_actions: list, model: Model,
              target_model: Model):
    """Run the capture / act / train loop until 'q' is pressed.

    Each iteration grabs a screen frame. On every ``FRAMES_SKIP``-th frame
    the loop reads the reward, records the transition in a replay buffer,
    picks and executes the next action, optionally trains ``model`` and
    periodically saves it to ``MODEL_PATH`` and copies its weights into
    ``target_model``.

    Reads the module-level names ``active`` (loop idles while it is falsy)
    and ``monitor`` (passed to ``grab_screen``).

    Args:
        handle: Opaque handle forwarded to ``get_reward``.
        possible_actions: Actions the agent may take; the chosen index
            selects the entry handed to ``execute_actions``.
        model: Network that is optimized and used to choose actions.
        target_model: Network that receives a copy of ``model``'s weights
            at start-up and every ``TARGET_MODEL_UPDATE_FREQ`` steps.
    """
    exp_schedule = ExplorationScheduler()
    target_model.load_state_dict(model.state_dict())
    optimizer = torch.optim.RMSprop(model.parameters())
    with mss() as sct:
        counter = 0
        frame_skip_counter = 0
        score = 0
        lives = 3
        # Ring buffer of the last four frame durations (seconds); kept for
        # ad-hoc timing diagnostics.
        frame_times = [0, 0, 0, 0]
        replay_buffer = ReplayBuffer(
            REPLAY_BUFFER_SIZE,
            (3 * FRAMES_FEED, RESIZE_HEIGHT, RESIZE_WIDTH),
            FRAMES_FEED,
            baseline_priority=1,
            gamma=GAMMA,
            reward_steps=N_STEP_REWARD)
        t = 0  # Counts every captured frame, including skipped ones.
        action = 0
        while True:
            if not active:
                # Paused: wait a bit, then check again whether to resume.
                time.sleep(0.5)
                continue

            frame_start = time.time()  # Seconds, not milliseconds.

            # Grab frames
            frame, frame_cv2 = grab_screen(monitor, sct)

            # Show frame
            if DEBUG:
                cv2.imshow('window1', frame_cv2)

            # Only act on frames where the skip counter has wrapped to 0.
            if frame_skip_counter == 0:
                reward, score, lives = get_reward(handle, lives, score)

                # Attach the outcome of the previous action before pushing
                # the new frame.
                if replay_buffer.waiting_for_effect:
                    replay_buffer.add_effects(action, reward)
                replay_buffer.push_frame(frame)

                # Use the model once the buffer is initialised and the
                # random draw exceeds the scheduled value; otherwise pick
                # a uniformly random action.
                if (replay_buffer.buffer_init()
                        and np.random.random() > exp_schedule.value(t)):
                    action = choose_action(replay_buffer.encode_last_frame(),
                                           model)
                else:
                    action = np.random.randint(0, len(possible_actions))

                execute_actions([possible_actions[int(action)]])

                # NOTE(review): t advances on every frame but is only
                # inspected here, so training and saving fire only when t
                # is a multiple of both FRAMES_SKIP and the respective
                # frequency -- confirm this is intended.
                if replay_buffer.can_sample(
                        BATCH_SIZE) and t % TRAIN_FREQ == 0:
                    # pause_game() is called before and after training;
                    # presumably it toggles the pause state -- TODO confirm.
                    if PAUSE_ON_TRAIN:
                        pause_game()
                    for _ in range(BATCHES_PER_TRAIN):
                        optimize_model(model,
                                       target_model,
                                       replay_buffer,
                                       optimizer,
                                       num_actions=len(possible_actions))
                    if PAUSE_ON_TRAIN:
                        pause_game()

                # Checkpoint the model and copy its weights to the target.
                if t % TARGET_MODEL_UPDATE_FREQ == 0:
                    print("Saving model")
                    state_dict = model.state_dict()
                    torch.save(state_dict, MODEL_PATH)
                    print("done pickling")
                    target_model.load_state_dict(state_dict)
                    target_model.eval()

            frame_skip_counter = (frame_skip_counter + 1) % FRAMES_SKIP

            # Frame timings and other utility
            frame_times[counter % 4] = time.time() - frame_start
            t += 1
            counter += 1

            # waitKey also services the OpenCV window; 'q' ends the loop.
            if cv2.waitKey(25) & 0xFF == ord('q'):
                cv2.destroyAllWindows()
                break