def test_task_incremental_mnist(monkeypatch):
    """EWC on task-incremental MNIST: the EWC loss should only appear after task 0.

    ``EwcModel.training_step`` is wrapped to accumulate the "ewc" loss term per
    task index, and ``EwcMethod.fit`` is wrapped to reset those totals before
    each fit call and snapshot them afterwards. The test then checks:

    * the snapshot from the first fit call is all zeros;
    * the snapshot from fit call ``i`` (for every later task ``i``) has a
      non-zero EWC loss at index ``i``;
    * online / final performance objectives clear fixed thresholds.
    """
    setting = TaskIncrementalSLSetting(dataset="mnist", monitor_training_performance=True)
    total_ewc_losses_per_task = np.zeros(setting.nb_tasks)

    _training_step = EwcModel.training_step

    def wrapped_training_step(self: EwcModel, batch, batch_idx: int, *args, **kwargs):
        step_results = _training_step(self, batch, batch_idx, *args, **kwargs)
        loss_object: Loss = step_results["loss_object"]
        if "ewc" in loss_object.losses:
            ewc_loss = loss_object.losses["ewc"].total_loss
            if isinstance(ewc_loss, Tensor):
                # Detach so accumulating into the numpy array doesn't touch the graph.
                ewc_loss = ewc_loss.detach().cpu().numpy()
            total_ewc_losses_per_task[self.current_task] += ewc_loss
        return step_results

    monkeypatch.setattr(EwcModel, "training_step", wrapped_training_step)

    _fit = EwcMethod.fit
    # One snapshot of the per-task EWC loss totals per call to `fit`.
    at_all_points_in_time = []

    def wrapped_fit(self, train_env, valid_env):
        print(f"starting task {self.model.current_task}: {total_ewc_losses_per_task}")
        # Reset so each snapshot only holds the losses accumulated during this call.
        total_ewc_losses_per_task[:] = 0
        _fit(self, train_env, valid_env)
        at_all_points_in_time.append(total_ewc_losses_per_task.copy())

    monkeypatch.setattr(EwcMethod, "fit", wrapped_fit)

    method = EwcMethod(max_epochs=1)
    results = setting.apply(method)

    # While training the first task there should be no EWC penalty at all.
    assert (at_all_points_in_time[0] == 0).all()
    # Derive the checked tasks from the setting instead of hard-coding 1..4;
    # assumes fit call `i` trains task `i`.
    for task_id in range(1, setting.nb_tasks):
        assert at_all_points_in_time[task_id][task_id] != 0

    assert 0.95 <= results.average_online_performance.objective
    assert 0.15 <= results.average_final_performance.objective
def main_sl():
    """Run the PnnMethod on a TaskIncrementalSLSetting built from command-line options.

    Setting, Config and PnnMethod options are all registered on one parser and
    parsed from the command line. The config is attached to the method and also
    passed to ``setting.apply``. A summary of the results is printed, and the
    results object is returned.
    """
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # TODO: PNN is coded for the DomainIncrementalSetting, where the action space
    # is the same for each task; the TaskIncrementalSLSetting is used here.
    parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    namespace = parser.parse_args()

    setting: TaskIncrementalSLSetting = TaskIncrementalSLSetting.from_argparse_args(
        namespace,
        dest="setting",
    )
    config: Config = Config.from_argparse_args(namespace, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(namespace, dest="method")
    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
def test_short_task_incremental_setting(
    self,
    model_type: Type[Module],
    short_task_incremental_setting: TaskIncrementalSLSetting,
    config: Config,
):
    """Apply ``self.Method`` (using ``model_type``) to the shortened task-incremental setting.

    Passes when the average final performance objective exceeds 0.05.
    """
    method_under_test = self.Method(model=model_type)
    results = short_task_incremental_setting.apply(method_under_test, config)
    final_objective = results.average_final_performance.objective
    assert 0.05 < final_objective
def configure(self, setting: TaskIncrementalSLSetting):
    """Set up the HatNet model and its Adam optimizer for ``setting``.

    Runs before the method is applied to the setting (i.e. before training).
    The observation & action spaces become available here, so the model is
    built here. The method's batch size is also pushed onto the setting.
    """
    setting.batch_size = self.hparams.batch_size
    assert (
        setting.increment == setting.test_increment
    ), "Assuming same number of classes per task for training and testing."

    # Number of classes in each task's training split, keyed by task index.
    n_classes_per_task = dict(
        (task_index, setting.num_classes_in_task(task_index, train=True))
        for task_index in range(setting.nb_tasks)
    )
    image_space: Image = setting.observation_space["x"]

    self.model = HatNet(
        image_space=image_space,
        n_classes_per_task=n_classes_per_task,
        s_hat=self.hparams.s_hat,
    )
    self.optimizer = torch.optim.Adam(
        self.model.parameters(),
        lr=self.hparams.learning_rate,
    )
break if old_gdumb_plugin_index is None: raise RuntimeError("Couldn't find the Strategy's GDumb plugin!") old_gdumb_plugin: _GDumbPlugin = strategy.plugins.pop( old_gdumb_plugin_index) logger.info("Replacing the GDumbPlugin with our 'patched' version.") new_gdumb_plugin = GDumbPlugin(mem_size=old_gdumb_plugin.mem_size) # NOTE: Might not be necessarily, since those should be empty, but here we also # copy the state from the old plugin to the new one. new_gdumb_plugin.ext_mem = old_gdumb_plugin.ext_mem new_gdumb_plugin.counter = old_gdumb_plugin.counter strategy.plugins.insert(old_gdumb_plugin_index, new_gdumb_plugin) return strategy if __name__ == "__main__": setting = TaskIncrementalSLSetting(dataset="mnist", nb_tasks=5, monitor_training_performance=True) # Create the Method, either manually or through the command-line: parser = ArgumentParser(__doc__) parser.add_arguments(GDumbMethod, "method") args = parser.parse_args() method: GDumbMethod = args.method results = setting.apply(method)
def short_task_incremental_setting(config: Config):
    """Build a 5-task MNIST TaskIncrementalSLSetting whose datasets are cut to 100 samples.

    Each task's train/val/test dataset is replaced with a ``random_subset`` of
    100 samples so tests using this setting run quickly. The function then calls
    ``setup()`` a second time and re-checks the sizes, to verify that ``setup``
    does not overwrite the shortened datasets.

    Returns the prepared setting.
    """
    nb_tasks = 5
    samples_per_dataset = 100

    setting = TaskIncrementalSLSetting(
        dataset="mnist",
        nb_tasks=nb_tasks,
        monitor_training_performance=True,
    )
    setting.config = config
    setting.prepare_data()
    setting.setup()

    def shortened(datasets):
        """Return a random `samples_per_dataset`-sized subset of each dataset."""
        return [random_subset(dataset, samples_per_dataset) for dataset in datasets]

    def assert_all_shortened():
        """Check every split has one dataset per task, each of the shortened size."""
        splits = (setting.train_datasets, setting.val_datasets, setting.test_datasets)
        # Same order as before: all split lengths first, then all dataset sizes.
        for datasets in splits:
            assert len(datasets) == nb_tasks
        for datasets in splits:
            assert all(len(dataset) == samples_per_dataset for dataset in datasets)

    # Shorten the datasets (train, then val, then test, so the RNG call order
    # stays the same as before).
    setting.train_datasets = shortened(setting.train_datasets)
    setting.val_datasets = shortened(setting.val_datasets)
    setting.test_datasets = shortened(setting.test_datasets)
    assert_all_shortened()

    # Assert that calling setup doesn't overwrite the datasets.
    setting.setup()
    assert_all_shortened()

    return setting