def _train_on_replay(model, target_model, replay_buffer, optimizer, num_actions):
    """Run BATCHES_PER_TRAIN optimisation steps on ``replay_buffer``.

    When PAUSE_ON_TRAIN is set, ``pause_game()`` is called once before and once
    after training. It presumably toggles the game's pause state (TODO confirm),
    so the second call lives in ``finally``. That way the game is not left paused
    if an optimisation step raises.
    """
    if PAUSE_ON_TRAIN:
        pause_game()
    try:
        for _ in range(BATCHES_PER_TRAIN):
            optimize_model(model, target_model, replay_buffer, optimizer,
                           num_actions=num_actions)
    finally:
        if PAUSE_ON_TRAIN:
            pause_game()


def _save_and_sync_target(model, target_model):
    """Save the online model's weights to MODEL_PATH and copy them into ``target_model``."""
    print("Saving model")
    state_dict = model.state_dict()
    torch.save(state_dict, MODEL_PATH)
    print("done pickling")
    target_model.load_state_dict(state_dict)
    target_model.eval()


def main_loop(handle, possible_actions: list, model: Model, target_model: Model):
    """Play the game from screen captures while training ``model`` online.

    This is a DQN-style loop: an online ``model`` plus a ``target_model`` that is
    periodically overwritten with the online weights.

    Each iteration does the following:
      * grab a frame;
      * on non-skipped frames (every FRAMES_SKIP-th frame), read the reward,
        record it in the replay buffer, pick an action (epsilon-greedy via
        ``exp_schedule``) and send it to the game;
      * every TRAIN_FREQ steps, train from the replay buffer;
      * every TARGET_MODEL_UPDATE_FREQ steps, save and sync the target model.

    The loop idles while the module-level ``active`` flag is false, and exits
    when 'q' is pressed. ``cv2.waitKey`` only receives keys through an OpenCV
    window, which is shown only when DEBUG is set.

    Args:
        handle: Passed through to ``get_reward``. Presumably identifies the
            game process/window (TODO confirm).
        possible_actions: Actions passed to ``execute_actions``. Chosen
            actions are indices into this list.
        model: Online network. It is trained by ``optimize_model`` and saved
            to MODEL_PATH.
        target_model: Network initialised from, and periodically synced with,
            ``model``.

    Raises:
        ValueError: If ``possible_actions`` is empty, because no action could
            ever be chosen.
    """
    if not possible_actions:
        raise ValueError("possible_actions must contain at least one action")
    num_actions = len(possible_actions)

    exp_schedule = ExplorationScheduler()
    target_model.load_state_dict(model.state_dict())
    optimizer = torch.optim.RMSprop(model.parameters())

    with mss() as sct:
        frame_skip_counter = 0
        score = 0
        lives = 3
        # Rolling window of the last 4 frame durations, in seconds. Not read in
        # this function; it is kept for ad-hoc profiling.
        frame_times = [0, 0, 0, 0]
        replay_buffer = ReplayBuffer(
            REPLAY_BUFFER_SIZE,
            (3 * FRAMES_FEED, RESIZE_HEIGHT, RESIZE_WIDTH),
            FRAMES_FEED,
            baseline_priority=1,
            gamma=GAMMA,
            reward_steps=N_STEP_REWARD)
        t = 0
        action = 0

        while True:
            if not active:
                # Recording is paused; poll until it should resume.
                time.sleep(0.5)
                continue

            frame_start = time.perf_counter()

            frame, frame_cv2 = grab_screen(monitor, sct)
            if DEBUG:
                cv2.imshow('window1', frame_cv2)

            # Only every FRAMES_SKIP-th frame is processed (counter == 0).
            if frame_skip_counter == 0:
                reward, score, lives = get_reward(handle, lives, score)
                # Attach the reward to the previously executed action, if the
                # buffer is still waiting for its outcome.
                if replay_buffer.waiting_for_effect:
                    replay_buffer.add_effects(action, reward)
                replay_buffer.push_frame(frame)

                # Epsilon-greedy. The random draw is taken only once the
                # buffer is initialised, as in the original short-circuit.
                if (replay_buffer.buffer_init()
                        and np.random.random() > exp_schedule.value(t)):
                    action = choose_action(replay_buffer.encode_last_frame(), model)
                else:
                    action = np.random.randint(0, num_actions)

                execute_actions([possible_actions[int(action)]])

                if replay_buffer.can_sample(BATCH_SIZE) and t % TRAIN_FREQ == 0:
                    _train_on_replay(model, target_model, replay_buffer,
                                     optimizer, num_actions)

                if t % TARGET_MODEL_UPDATE_FREQ == 0:
                    _save_and_sync_target(model, target_model)

            frame_skip_counter = (frame_skip_counter + 1) % FRAMES_SKIP

            frame_times[t % len(frame_times)] = time.perf_counter() - frame_start
            t += 1

            if cv2.waitKey(25) & 0xFF == ord('q'):
                cv2.destroyAllWindows()
                break