Example #1
from copy import deepcopy
from datetime import datetime

import gym_super_mario_bros
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace
from omegaconf import DictConfig

# Project-specific pieces used below (BASE_DIR, MarioModel, MarioICM, ExperienceReplay,
# prepare_initial_state, prepare_multi_state, sample_action, get_intrinsic_reward,
# get_q_loss, _format_video) are defined elsewhere in the project.


def train():
    # Hyperparameters
    cfg = DictConfig({
        "epochs": 20,
        "lr": 1e-4,
        "use_extrinsic": True,
        "max_episode_len": 1000,
        "min_progress": 15,
        "frames_per_state": 3,
        "action_repeats": 6,
        "gamma_q": 0.85,
        "epsilon_random": 0.1,  # Sample random action with epsilon probability
        "epsilon_greedy_switch": 5,
        "q_loss_weight": 1,
        "inverse_loss_weight": 0.5,
        "forward_loss_weight": 0.5,
        "intrinsic_weight": 1.0,
        "extrinsic_weight": 1.0,
        "video_record_frequency": 3,
    })

    # ---- setting up variables -----

    q_model = MarioModel(cfg.frames_per_state)
    icm_model = MarioICM(cfg.frames_per_state)

    optim = torch.optim.Adam(list(q_model.parameters()) +
                             list(icm_model.parameters()),
                             lr=cfg.lr)

    replay = ExperienceReplay(buffer_size=500, batch_size=100)
    env = gym_super_mario_bros.make("SuperMarioBros-v0")
    env = JoypadSpace(env, COMPLEX_MOVEMENT)

    wandb.init(
        name=f"mario_icm_{str(datetime.now().timestamp())[5:10]}",
        project="rl_talk_mario",
        config={},
        save_code=True,
        group=None,
        tags=['icm'],  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )

    # ----- training loop ------

    for epoch in range(cfg.epochs):
        state = env.reset()
        done = False
        current_step = 0

        intrinsic_rewards = []
        extrinsic_rewards = []
        cumulative_intrinsic_reward = 0
        cumulative_extrinsic_reward = 0
        video_buffer = []
        must_record = epoch % cfg.video_record_frequency == 0

        # Episode loop: interact with the environment until the episode terminates
        while not done:

            # ------------ Q Learning --------------

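            # Build the stacked-frame state for the Q network: the first step
            # initialises the frame stack, later steps shift in the newest rendered frame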
            if current_step == 0:
                state = prepare_initial_state(env.render("rgb_array"))
            else:
                state = prepare_multi_state(state, env.render("rgb_array"))

            q_values = q_model(state)
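            # Pick an action from the Q values; epsilon-greedy exploration is only
            # enabled once epoch > cfg.epsilon_greedy_switch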
            action = sample_action(
                q_values,
                cfg.epsilon_random,
                apply_epsilon=epoch > cfg.epsilon_greedy_switch,
            )

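            # Simple frame skip: repeat the chosen action for cfg.action_repeats steps;
            # state2 keeps the first post-action frame, while reward/done/info come
            # from the last repeated step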
            action_count = 0
            state2 = None
            while True:
                state2_, extrinsic_reward, done, info = env.step(action)
                if state2 is None:
                    state2 = state2_
                # env.render()
                if action_count >= cfg.action_repeats or done:
                    break
                action_count += 1
            state2 = prepare_multi_state(state, state2)

            # Intrinsic (curiosity) reward from the ICM (scaled up here, clipped and weighted below)
            intrinsic_reward = 10000 * get_intrinsic_reward(
                state, action, state2, icm_model)

            reward = (cfg.intrinsic_weight * min(intrinsic_reward, 10)) + (
                cfg.extrinsic_weight * extrinsic_reward)

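            # get_q_loss presumably forms the usual one-step TD target
            # reward + gamma_q * max_a Q(state2, a) and compares it with Q(state, action)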
            q_loss = get_q_loss(q_values[0][action], reward, q_model, state2,
                                cfg.gamma_q)

            replay.add(state, action, reward, state2)
            state = state2

            # ------------- ICM -------------------

            state1_batch, action_batch, reward_batch, state2_batch = replay.get_batch()

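            # ICM forward pass: the inverse model predicts the taken action from the
            # two states, the forward model predicts the encoding of the next state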
            action_pred, state2_encoded, state2_pred = icm_model(
                state1_batch, action_batch, state2_batch)

            inverse_loss = F.cross_entropy(action_pred, action_batch)
            forward_loss = F.mse_loss(state2_pred, state2_encoded)

            # ------------ Learning ------------

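            # Joint update: Q-learning loss plus the ICM inverse- and forward-model
            # losses, each with its configured weight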
            final_loss = ((cfg.q_loss_weight * q_loss) +
                          (cfg.inverse_loss_weight * inverse_loss) +
                          (cfg.forward_loss_weight * forward_loss))

            optim.zero_grad()
            final_loss.backward()
            optim.step()

            # ------------ updates --------------

            max_episode_len_reached = current_step >= cfg.max_episode_len
            no_progress = False  # TODO: flag episodes where x_pos advances by less than cfg.min_progress
            done = done or max_episode_len_reached or no_progress

            intrinsic_rewards.append(intrinsic_reward.item())
            extrinsic_rewards.append(extrinsic_reward)

            cumulative_intrinsic_reward += intrinsic_reward.item()
            cumulative_extrinsic_reward += float(extrinsic_reward)

            if must_record:
                video_buffer.append(deepcopy(env.render("rgb_array")))

            # ------------ Logging ------------

            log = DictConfig({})
            log.episode = epoch
            log.step = current_step
            log.loss = final_loss.item()
            log.intrinsic_reward = intrinsic_reward.item()
            log.extrinsic_reward = float(extrinsic_reward)
            log.summed_reward = reward.item()
            log.cumulative_intrinsic_reward = cumulative_intrinsic_reward
            log.cumulative_extrinsic_reward = cumulative_extrinsic_reward
            log.x_pos = int(info.get('x_pos', -1))

            if done:
                log.max_episode_len_reached = max_episode_len_reached
                log.no_progress = no_progress

                log.episode_length = len(intrinsic_rewards)
                log.ep_x_pos = log.x_pos
                log.mean_intrinsic_rewards = float(np.mean(intrinsic_rewards))
                log.mean_extrinsic_rewards = float(np.mean(extrinsic_rewards))
                log.mean_summed_rewards = (log.mean_intrinsic_rewards +
                                           log.mean_extrinsic_rewards) / 2

                if must_record:
                    log = dict(log)
                    log[f"video_ep{epoch}_reward{reward.item()}"] = wandb.Video(
                        _format_video(video_buffer), fps=4, format="gif")

            wandb.log(log)
            current_step += 1
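
The helper get_intrinsic_reward used in train() is defined elsewhere in the project. For orientation only, here is a minimal sketch of how an ICM curiosity reward is typically computed: as the forward-model prediction error in feature space. The attribute names icm_model.encode and icm_model.forward_model are assumptions made for this sketch, not the actual API of MarioICM.

def get_intrinsic_reward_sketch(state, action, state2, icm_model, n_actions=12):
    # n_actions=12 matches COMPLEX_MOVEMENT; the encoder/forward-model attributes
    # below are hypothetical names, not the real MarioICM interface.
    with torch.no_grad():
        phi1 = icm_model.encode(state)    # assumed: feature encoding of the current state
        phi2 = icm_model.encode(state2)   # assumed: feature encoding of the next state
        action_one_hot = F.one_hot(torch.as_tensor([action]),
                                   num_classes=n_actions).float()
        phi2_pred = icm_model.forward_model(phi1, action_one_hot)  # assumed forward model
    # Mean squared prediction error in feature space = unscaled curiosity reward
    return F.mse_loss(phi2_pred, phi2, reduction="mean")

In train() this value is additionally scaled, clipped via min(intrinsic_reward, 10) and weighted before it is mixed with the extrinsic reward.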