Exemplo n.º 1
0
def test_multiple_tasks_within_same_batch(mixed_samples: Dict[int,
                                                              Tuple[Tensor,
                                                                    Tensor,
                                                                    Tensor]],
                                          indices: slice, monkeypatch,
                                          config: Config):
    """Check that a multi-headed model uses the right output head for each
    sample of a batch that mixes data from several tasks.

    The samples of all tasks in ``mixed_samples`` are concatenated into one
    batch, sub-selected with ``indices``, and passed (together with their task
    labels) through a ``MultiHeadModel`` whose encoder is replaced by a stub
    returning all-ones features and whose output heads are ``MockOutputHead``
    instances for task ids 0-4. Each prediction is expected to equal
    ``task_id * mean(x)`` for its sample.

    NOTE(review): this presumes ``MockOutputHead`` computes
    ``task_id * x.mean()`` from the observations (the encoder output is
    constant) — confirm against ``MockOutputHead``.

    Args:
        mixed_samples: mapping whose values are ``(x, y, task_label)`` tensor
            triples; all triples are concatenated into a single batch.
        indices: slice selecting which samples of the mixed batch to use.
        monkeypatch: pytest fixture; currently unused, kept in the signature
            for compatibility.
        config: configuration passed to the model.
    """
    setting = ClassIncrementalSetting()
    model = MultiHeadModel(
        setting=setting,
        hparams=MultiHeadModel.HParams(batch_size=30, multihead=True),
        config=config,
    )

    class MockEncoder(nn.Module):
        """Encoder stub that emits an all-ones feature vector per sample."""

        def forward(self, x: Tensor):
            return x.new_ones([x.shape[0], model.hidden_size])

    model.encoder = MockEncoder()

    # One mock head per task id, so the head that produced a prediction can be
    # identified from the prediction's value.
    for task_id in range(5):
        model.output_heads[str(task_id)] = MockOutputHead(
            input_space=spaces.Box(0, 1, [model.hidden_size]),
            Actions=setting.Actions,
            action_space=spaces.Discrete(2),
            task_id=task_id,
        )
    model.output_head = model.output_heads["0"]

    # Merge the samples of every task into one batch; the labels are unused.
    xs, _, ts = map(torch.cat, zip(*mixed_samples.values()))
    xs = xs[indices]
    ts = ts[indices].int()

    obs = setting.Observations(x=xs, task_labels=ts)
    with torch.no_grad():
        forward_pass = model(obs)
        y_preds = forward_pass["y_pred"]

    assert y_preds.shape == ts.shape
    # Expected: mean of each flattened input, scaled by its sample's task id.
    expected = ts * xs.reshape([xs.shape[0], -1]).mean(1)
    # These are float results, so compare with a tolerance rather than exact
    # `==`. `allclose` does not promote dtypes, hence the explicit cast.
    assert torch.allclose(y_preds.to(expected.dtype), expected)
Exemplo n.º 2
0
def test_task_inference_sl(
    mixed_samples: Dict[int, Tuple[Tensor, Tensor, Tensor]],
    indices: slice,
    config: Config,
):
    """Check that a multi-headed model can do a forward pass on a batch mixing
    several tasks when the task labels are withheld.

    The samples of all tasks in ``mixed_samples`` are concatenated into one
    batch and sub-selected with ``indices``. The observations are built with
    ``task_labels=None``, presumably forcing the model to infer which output
    head to use for each sample (TODO confirm against ``MultiHeadModel``).
    The encoder is replaced by a stub returning all-ones features, and the
    output heads by ``MockOutputHead`` instances for task ids 0-4.

    Currently only the shape of the predictions is checked (one prediction
    per sample).
    TODO: also assert which head was selected for each sample.

    Args:
        mixed_samples: mapping whose values are ``(x, y, task_label)`` tensor
            triples; all triples are concatenated into a single batch.
        indices: slice selecting which samples of the mixed batch to use.
        config: configuration passed to the model.
    """
    # Build one mixed batch, deliberately *without* task labels; the true
    # labels `ts` are kept only as a reference for the expected output shape.
    xs, _, ts = map(torch.cat, zip(*mixed_samples.values()))
    xs = xs[indices]
    ts = ts[indices].int()
    obs = ClassIncrementalSetting.Observations(x=xs, task_labels=None)

    setting = ClassIncrementalSetting()
    model = MultiHeadModel(
        setting=setting,
        hparams=MultiHeadModel.HParams(batch_size=30, multihead=True),
        config=config,
    )

    class MockEncoder(nn.Module):
        """Encoder stub that emits an all-ones feature vector per sample."""

        def forward(self, x: Tensor):
            return x.new_ones([x.shape[0], model.hidden_size])

    model.encoder = MockEncoder()

    # One mock head per task id.
    for task_id in range(5):
        model.output_heads[str(task_id)] = MockOutputHead(
            input_space=spaces.Box(0, 1, [model.hidden_size]),
            action_space=spaces.Discrete(setting.action_space.n),
            Actions=setting.Actions,
            task_id=task_id,
        )
    model.output_head = model.output_heads["0"]

    # Inference only: no autograd graph is needed for this check.
    with torch.no_grad():
        forward_pass = model(obs)
        y_preds = forward_pass.actions.y_pred

    assert y_preds.shape == ts.shape