from collections import defaultdict
from typing import Dict, List, Optional

import torch
from pytorch_lightning import Trainer
from torch import Tensor

# NOTE: The project-specific classes used below (RLSetting, Config, BaselineModel,
# MultiHeadModel, EpisodicA2C, Loss, ForwardPass) are assumed to be importable from
# the surrounding package; their exact module paths depend on the repo layout.


def test_multitask_rl_bug_with_PL(monkeypatch):
    """ TODO: `on_task_switch` is called on the new observation, but we still need
    to produce a loss for the output head that was being used *before* the switch!
    """
    # NOTE: Tasks here have nothing to do with the task schedule: a new random
    # task is sampled at the start of each episode.
    max_episode_steps = 5
    setting = RLSetting(
        dataset="cartpole",
        batch_size=1,
        nb_tasks=2,
        max_episode_steps=max_episode_steps,
        add_done_to_observations=True,
        observe_state_directly=True,
    )
    assert setting._new_random_task_on_reset

    # setting = RLSetting.load_benchmark("monsterkong")
    config = Config(debug=True, verbose=True, seed=123)
    config.seed_everything()
    model = BaselineModel(
        setting=setting,
        hparams=MultiHeadModel.HParams(
            multihead=True,
            output_head=EpisodicA2C.HParams(
                accumulate_losses_before_backward=True)),
        config=config,
    )
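
    # The `monkeypatch` fixture is otherwise unused. As a minimal sketch (assuming
    # the model exposes the `on_task_switch(task_id)` hook mentioned in the
    # docstring), we can spy on it to record exactly when it fires relative to the
    # episode boundaries:
    task_switch_calls: List[Optional[int]] = []
    _original_on_task_switch = model.on_task_switch

    def _spy_on_task_switch(task_id: Optional[int]) -> None:
        task_switch_calls.append(task_id)
        return _original_on_task_switch(task_id)

    monkeypatch.setattr(model, "on_task_switch", _spy_on_task_switch)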

    # TODO: Maybe add some kind of "hook" to check which losses get returned when?
    model.train()
    assert not model.automatic_optimization

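    # fast_dev_run runs a single batch of training (and validation, if any) as a
    # quick smoke test.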
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, train_dataloader=setting.train_dataloader())

    # Optimizer for the manual optimization steps sketched (commented out) below.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    episodes = 0
    max_episodes = 5

    # Dict mapping from step to loss at that step.
    losses: Dict[int, List[Loss]] = defaultdict(list)

    with setting.train_dataloader() as env:
        env.seed(123)

        # env = TimeLimit(env, max_episode_steps=max_episode_steps)
        # Iterate over the environment, which yields one observation at a time:
        for step, obs in enumerate(env):
            assert isinstance(obs, RLSetting.Observations)

            step_results = model.training_step(batch=obs, batch_idx=step)
            loss_tensor: Optional[Tensor] = None

            if step > 0 and step % max_episode_steps == 0:
                # The episode just ended (every `max_episode_steps` steps), so the
                # output head should have produced a loss.
                assert all(obs.done), step  # Since batch_size == 1 for now.
                assert step_results is not None, (step, obs.task_labels)
                loss_tensor = step_results["loss"]
                loss: Loss = step_results["loss_object"]
                print(f"Loss at step {step}: {loss}")
                losses[step].append(loss)
                episodes += sum(obs.done)

                # Manually performing the optimization step would look like this:
                # output_head_loss = loss.losses.get(model.output_head.name)
                # update_model = output_head_loss is not None and output_head_loss.requires_grad
                # assert update_model, (loss, output_head_loss, model.output_head.name)
                # model.manual_backward(loss_tensor, optimizer, retain_graph=not update_model)
                # optimizer.step()
                # optimizer.zero_grad()

            else:
                assert step_results is None

            print(
                f"Step {step}, episode {episodes}: x={obs[0]}, done={obs.done}, "
                f"task labels: {obs.task_labels}, loss_tensor: {loss_tensor}"
            )

            if step > 100 or episodes > max_episodes:
                break

    for step, step_losses in losses.items():
        print(f"Losses at step {step}:")
        for loss in step_losses:
            print(f"\t{loss}")


def test_multitask_rl_bug_without_PL(monkeypatch):
    """ TODO: Same bug as above, reproduced without the PyTorch-Lightning Trainer:
    `on_task_switch` is called on the new observation, but we still need to produce
    a loss for the output head that was being used *before* the switch!
    """
    # NOTE: Tasks here have nothing to do with the task schedule: a new random
    # task is sampled at the start of each episode.
    max_episode_steps = 5
    setting = RLSetting(
        dataset="cartpole",
        batch_size=1,
        nb_tasks=2,
        max_episode_steps=max_episode_steps,
        add_done_to_observations=True,
        observe_state_directly=True,
    )
    assert setting._new_random_task_on_reset

    # setting = RLSetting.load_benchmark("monsterkong")
    config = Config(debug=True, verbose=True, seed=123)
    config.seed_everything()
    model = BaselineModel(
        setting=setting,
        hparams=MultiHeadModel.HParams(
            multihead=True,
            output_head=EpisodicA2C.HParams(
                accumulate_losses_before_backward=True)),
        config=config,
    )
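    # (The same `on_task_switch` spy sketched in the PL test above could be
    # installed here via `monkeypatch` as well.)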
    # TODO: Maybe add some kind of "hook" to check which losses get returned when?
    model.train()
    # NOTE: Unlike the test above, this variant drives the training loop manually,
    # without going through a pytorch_lightning Trainer.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    episodes = 0
    max_episodes = 5

    # Dict mapping from step to loss at that step.
    losses: Dict[int, Loss] = {}

    with setting.train_dataloader() as env:
        env.seed(123)
        # env = TimeLimit(env, max_episode_steps=max_episode_steps)
        # Iterate over the environment, which yields one observation at a time:
        for step, obs in enumerate(env):
            assert isinstance(obs, RLSetting.Observations)

            if step == 0:
                assert not any(obs.done)
            # Task label of the first (and only) env in the batch; kept around for
            # debugging, currently unused.
            start_task_label = obs[1][0]

            # Snapshot how many steps each output head currently has buffered, so
            # that we can compare (below) which head accumulated this step.
            stored_steps_in_each_head_before = {
                task_key: output_head.num_stored_steps(0)
                for task_key, output_head in model.output_heads.items()
            }
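            # Active env protocol: the forward pass produces actions, which are
            # sent back to the env in exchange for the rewards for this step.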
            forward_pass: ForwardPass = model.forward(observations=obs)
            rewards = env.send(forward_pass.actions)

            loss: Loss = model.get_loss(forward_pass=forward_pass,
                                        rewards=rewards,
                                        loss_name="debug")
            stored_steps_in_each_head_after = {
                task_key: output_head.num_stored_steps(0)
                for task_key, output_head in model.output_heads.items()
            }
            # if step == 5:
            #     assert False, (loss, stored_steps_in_each_head_before, stored_steps_in_each_head_after)

            if any(obs.done):
                assert loss.loss != 0., step
                assert loss.loss.requires_grad

                # Backpropagate the loss, update the models, etc etc.
                loss.loss.backward()
                model.on_after_backward()
                optimizer.step()
                model.on_before_zero_grad(optimizer)
                optimizer.zero_grad()
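                # `on_after_backward` / `on_before_zero_grad` are the hooks that
                # the PL Trainer would normally call around the optimizer step.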

                # TODO: Need to let the model know that an update is happening, so
                # that it can clear its buffers, etc.

                episodes += sum(obs.done)
                losses[step] = loss
            else:
                assert loss.loss == 0.
            print(
                f"Step {step}, episode {episodes}: x={obs[0]}, done={obs.done}, "
                f"reward={rewards}, task labels: {obs.task_labels}, "
                f"losses: {list(loss.losses.keys())}: {loss.loss}"
            )

            if episodes > max_episodes:
                break
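
    # Mirror the per-step loss summary from the PL variant above:
    for step, loss in losses.items():
        print(f"Loss at step {step}: {loss}")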