Example #1
def test_task_incremental_mnist(monkeypatch):
    setting = TaskIncrementalSLSetting(dataset="mnist",
                                       monitor_training_performance=True)
    total_ewc_losses_per_task = np.zeros(setting.nb_tasks)

    _training_step = EwcModel.training_step

    def wrapped_training_step(self: EwcModel, batch, batch_idx: int, *args,
                              **kwargs):
        # Run the original training step, then pull out the EWC loss term.
        step_results = _training_step(self, batch, batch_idx, *args, **kwargs)
        loss_object: Loss = step_results["loss_object"]
        if "ewc" in loss_object.losses:
            ewc_loss_obj = loss_object.losses["ewc"]
            ewc_loss = ewc_loss_obj.total_loss
            if isinstance(ewc_loss, Tensor):
                ewc_loss = ewc_loss.detach().cpu().numpy()
            total_ewc_losses_per_task[self.current_task] += ewc_loss
        return step_results

    monkeypatch.setattr(EwcModel, "training_step", wrapped_training_step)

    _fit = EwcMethod.fit

    at_all_points_in_time = []

    def wrapped_fit(self, train_env, valid_env):
        print(
            f"starting task {self.model.current_task}: {total_ewc_losses_per_task}"
        )
        # Reset the per-task totals before each task, then snapshot them after.
        total_ewc_losses_per_task[:] = 0
        _fit(self, train_env, valid_env)
        at_all_points_in_time.append(total_ewc_losses_per_task.copy())

    monkeypatch.setattr(EwcMethod, "fit", wrapped_fit)

    method = EwcMethod(max_epochs=1)
    results = setting.apply(method)
    # No EWC loss during the first task: there are no previous weights to anchor to.
    assert (at_all_points_in_time[0] == 0).all()
    # From the second task onward, the EWC term should be non-zero:
    assert at_all_points_in_time[1][1] != 0
    assert at_all_points_in_time[2][2] != 0
    assert at_all_points_in_time[3][3] != 0
    assert at_all_points_in_time[4][4] != 0

    assert 0.95 <= results.average_online_performance.objective
    assert 0.15 <= results.average_final_performance.objective
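The assertions above check that the EWC regularizer is inactive during the first task and active from the second task onward. For context, here is a minimal sketch of the quadratic EWC penalty (not Sequoia's actual implementation; old_params and fisher stand in for hypothetical snapshots taken at the previous task boundary):

import torch
from torch import nn


def ewc_penalty(model: nn.Module, old_params: dict, fisher: dict,
                coefficient: float = 1.0) -> torch.Tensor:
    # Quadratic penalty: (coefficient / 2) * sum_i F_i * (theta_i - theta_i*)^2.
    # Before the first task boundary there are no snapshots, so this is zero.
    loss = torch.zeros(())
    for name, param in model.named_parameters():
        if name in old_params:
            loss = loss + (fisher[name] * (param - old_params[name]) ** 2).sum()
    return coefficient / 2 * loss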
Example #2
def main_sl():
    """ Applies the PnnMethod in a SL Setting. """
    parser = ArgumentParser(description=__doc__,
                            add_dest_to_option_strings=False)

    # Add arguments for the Setting
    # TODO: PNN is coded for the DomainIncrementalSetting, where the action space
    # is the same for each task.
    # parser.add_arguments(DomainIncrementalSetting, dest="setting")
    parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    Config.add_argparse_args(parser, dest="config")

    # Add arguments for the Method:
    PnnMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    setting: TaskIncrementalSLSetting = TaskIncrementalSLSetting.from_argparse_args(
        args, dest="setting",
    )
    config: Config = Config.from_argparse_args(args, dest="config")

    method: PnnMethod = PnnMethod.from_argparse_args(args, dest="method")

    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
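For quick experiments, the same pipeline can also be assembled without the command-line parser. A minimal sketch using only calls that appear in these examples (constructor defaults are assumed to be valid):

def main_sl_programmatic():
    # Sketch: build the Setting, Config and Method directly instead of parsing args.
    setting = TaskIncrementalSLSetting(dataset="mnist", nb_tasks=5)
    config = Config()
    method = PnnMethod()
    method.config = config
    results = setting.apply(method, config=config)
    print(results.summary())
    return results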
Example #3
    def test_short_task_incremental_setting(
        self,
        model_type: Type[Module],
        short_task_incremental_setting: TaskIncrementalSLSetting,
        config: Config,
    ):
        method = self.Method(model=model_type)
        results = short_task_incremental_setting.apply(method, config)
        assert 0.05 < results.average_final_performance.objective
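Both config and short_task_incremental_setting are pytest fixtures; the latter is defined in Example #6 below. A hypothetical conftest sketch for the config fixture (the debug flag is an assumption about the Config class):

import pytest


@pytest.fixture
def config():
    # Hypothetical sketch: assumes Config (as used above) accepts a debug flag.
    return Config(debug=True)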
Example #4
    def configure(self, setting: TaskIncrementalSLSetting):
        """ Called before the method is applied on a setting (before training). 

        You can use this to instantiate your model, for instance, since this is
        where you get access to the observation & action spaces.
        """
        setting.batch_size = self.hparams.batch_size
        assert (
            setting.increment == setting.test_increment
        ), "Assuming same number of classes per task for training and testing."
        n_classes_per_task = {
            i: setting.num_classes_in_task(i, train=True)
            for i in range(setting.nb_tasks)
        }
        image_space: Image = setting.observation_space["x"]
        self.model = HatNet(
            image_space=image_space,
            n_classes_per_task=n_classes_per_task,
            s_hat=self.hparams.s_hat,
        )
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.hparams.learning_rate,
        )
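HatNet receives the per-task class counts so that it can size an output for every task. As a simplified illustration of task-conditioned output heads (a sketch only: HAT's actual mechanism learns per-task attention masks over shared units, not just separate linear heads):

from typing import Dict

import torch
from torch import Tensor, nn


class MultiHeadClassifier(nn.Module):
    """Sketch: one linear output head per task (hypothetical, not HatNet)."""

    def __init__(self, in_features: int, n_classes_per_task: Dict[int, int]):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Linear(in_features, n_classes_per_task[t])
            for t in sorted(n_classes_per_task)
        )

    def forward(self, features: Tensor, task_id: int) -> Tensor:
        # Route the shared features through the head for the current task.
        return self.heads[task_id](features)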
Example #5
                break

        if old_gdumb_plugin_index is None:
            raise RuntimeError("Couldn't find the Strategy's GDumb plugin!")

        old_gdumb_plugin: _GDumbPlugin = strategy.plugins.pop(
            old_gdumb_plugin_index)
        logger.info("Replacing the GDumbPlugin with our 'patched' version.")

        new_gdumb_plugin = GDumbPlugin(mem_size=old_gdumb_plugin.mem_size)
        # NOTE: This might not be necessary, since these should be empty, but we
        # also copy the state from the old plugin over to the new one.
        new_gdumb_plugin.ext_mem = old_gdumb_plugin.ext_mem
        new_gdumb_plugin.counter = old_gdumb_plugin.counter

        strategy.plugins.insert(old_gdumb_plugin_index, new_gdumb_plugin)
        return strategy


if __name__ == "__main__":
    setting = TaskIncrementalSLSetting(dataset="mnist",
                                       nb_tasks=5,
                                       monitor_training_performance=True)
    # Create the Method, either manually or through the command-line:
    parser = ArgumentParser(__doc__)
    parser.add_arguments(GDumbMethod, "method")
    args = parser.parse_args()
    method: GDumbMethod = args.method

    results = setting.apply(method)
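Note that this snippet begins mid-loop: the truncated lines above presumably scan strategy.plugins for the index of the plugin to replace. That search might look like this (an assumption, since the original lines are cut off; strategy and _GDumbPlugin come from the surrounding example):

old_gdumb_plugin_index = None
for index, plugin in enumerate(strategy.plugins):
    if isinstance(plugin, _GDumbPlugin):
        old_gdumb_plugin_index = index
        break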
Example #6
@pytest.fixture
def short_task_incremental_setting(config: Config):
    setting = TaskIncrementalSLSetting(
        dataset="mnist", nb_tasks=5, monitor_training_performance=True,
    )
    setting.config = config
    setting.prepare_data()

    setting.setup()
    # Shorten each split's task datasets to 100 samples to keep the test fast:
    setting.train_datasets = [
        random_subset(task_dataset, 100) for task_dataset in setting.train_datasets
    ]
    setting.val_datasets = [
        random_subset(task_dataset, 100) for task_dataset in setting.val_datasets
    ]
    setting.test_datasets = [
        random_subset(task_dataset, 100) for task_dataset in setting.test_datasets
    ]
    assert len(setting.train_datasets) == 5
    assert len(setting.val_datasets) == 5
    assert len(setting.test_datasets) == 5
    assert all(len(dataset) == 100 for dataset in setting.train_datasets)
    assert all(len(dataset) == 100 for dataset in setting.val_datasets)
    assert all(len(dataset) == 100 for dataset in setting.test_datasets)

    # Assert that calling setup doesn't overwrite the datasets.
    setting.setup()
    assert len(setting.train_datasets) == 5
    assert len(setting.val_datasets) == 5
    assert len(setting.test_datasets) == 5
    assert all(len(dataset) == 100 for dataset in setting.train_datasets)
    assert all(len(dataset) == 100 for dataset in setting.val_datasets)
    assert all(len(dataset) == 100 for dataset in setting.test_datasets)

    return setting
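random_subset is the small helper used above to shrink each task dataset. A minimal sketch of such a helper (an assumption about its behavior; the project's actual implementation may differ):

import torch
from torch.utils.data import Dataset, Subset


def random_subset(dataset: Dataset, n_samples: int) -> Subset:
    # Keep n_samples examples, drawn uniformly at random without replacement.
    indices = torch.randperm(len(dataset))[:n_samples]
    return Subset(dataset, indices.tolist())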