Example 1
def short_task_incremental_setting(config: Config):
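    """Create a TaskIncrementalSLSetting on MNIST with 5 tasks, then shorten each
    train/val/test task dataset to 100 samples so that tests run quickly."""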
    setting = TaskIncrementalSLSetting(
        dataset="mnist", nb_tasks=5, monitor_training_performance=True,
    )
    setting.config = config
    setting.prepare_data()

    setting.setup()
    # Testing this out: shortening the train/val/test datasets for each task:
    setting.train_datasets = [
        random_subset(task_dataset, 100) for task_dataset in setting.train_datasets
    ]
    setting.val_datasets = [
        random_subset(task_dataset, 100) for task_dataset in setting.val_datasets
    ]
    setting.test_datasets = [
        random_subset(task_dataset, 100) for task_dataset in setting.test_datasets
    ]
    assert len(setting.train_datasets) == 5
    assert len(setting.val_datasets) == 5
    assert len(setting.test_datasets) == 5
    assert all(len(dataset) == 100 for dataset in setting.train_datasets)
    assert all(len(dataset) == 100 for dataset in setting.val_datasets)
    assert all(len(dataset) == 100 for dataset in setting.test_datasets)

    # Assert that calling setup doesn't overwrite the datasets.
    setting.setup()
    assert len(setting.train_datasets) == 5
    assert len(setting.val_datasets) == 5
    assert len(setting.test_datasets) == 5
    assert all(len(dataset) == 100 for dataset in setting.train_datasets)
    assert all(len(dataset) == 100 for dataset in setting.val_datasets)
    assert all(len(dataset) == 100 for dataset in setting.test_datasets)

    return setting
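
The `random_subset` helper used above is not shown in the snippet; below is a minimal sketch of what such a helper could look like, assuming a map-style PyTorch dataset (i.e. one with `__len__`). The real helper comes from the project's own utilities and may differ in signature and behaviour.

import numpy as np
from torch.utils.data import Dataset, Subset

def random_subset(dataset: Dataset, n_samples: int, seed: int = 0) -> Subset:
    """Return a Subset of `dataset` made of `n_samples` randomly chosen indices."""
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(dataset), size=n_samples, replace=False)
    return Subset(dataset, indices.tolist())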
Example 2
def short_sl_track_setting(config: Config):
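    """Create a shortened version of the SL track setting (a ClassIncrementalSetting
    behind a SettingProxy), with 100 samples per task and a batch size of 10, so
    that tests run quickly."""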
    setting = SettingProxy(
        ClassIncrementalSetting,
        "sl_track",
        # dataset="synbols",
        # nb_tasks=12,
        # class_order=class_order,
        # monitor_training_performance=True,
    )
    setting.config = config
    # TODO: This could be a bit more convenient.
    setting.data_dir = config.data_dir
    assert setting.config == config
    assert setting.data_dir == config.data_dir
    assert setting.nb_tasks == 12

    # For now we'll just shorten the tests by shortening the datasets.
    samples_per_task = 100
    setting.batch_size = 10

    setting.setup()
    # Testing this out: shortening the train/val/test datasets for each task:
    setting.train_datasets = [
        random_subset(task_dataset, samples_per_task)
        for task_dataset in setting.train_datasets
    ]
    setting.val_datasets = [
        random_subset(task_dataset, samples_per_task)
        for task_dataset in setting.val_datasets
    ]
    setting.test_datasets = [
        random_subset(task_dataset, samples_per_task)
        for task_dataset in setting.test_datasets
    ]
    assert len(setting.train_datasets) == setting.nb_tasks
    assert len(setting.val_datasets) == setting.nb_tasks
    assert len(setting.test_datasets) == setting.nb_tasks
    assert all(len(dataset) == samples_per_task for dataset in setting.train_datasets)
    assert all(len(dataset) == samples_per_task for dataset in setting.val_datasets)
    assert all(len(dataset) == samples_per_task for dataset in setting.test_datasets)

    # Assert that calling setup doesn't overwrite the datasets.
    setting.setup()

    assert len(setting.train_datasets) == setting.nb_tasks
    assert len(setting.val_datasets) == setting.nb_tasks
    assert len(setting.test_datasets) == setting.nb_tasks
    assert all(len(dataset) == samples_per_task for dataset in setting.train_datasets)
    assert all(len(dataset) == samples_per_task for dataset in setting.val_datasets)
    assert all(len(dataset) == samples_per_task for dataset in setting.test_datasets)

    return setting
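
Both fixtures repeat the same shortening logic, so it could be factored into a small helper. The sketch below is not part of the original source; it only assumes that the setting object exposes `train_datasets`, `val_datasets`, and `test_datasets` lists, as in the two examples above, and that `random_subset` is available.

def shorten_datasets(setting, samples_per_task: int = 100):
    """Replace each task dataset on `setting` with a random subset of fixed size."""
    for attr in ("train_datasets", "val_datasets", "test_datasets"):
        datasets = getattr(setting, attr)
        setattr(setting, attr, [random_subset(d, samples_per_task) for d in datasets])
    return setting

With such a helper, each fixture reduces to creating the setting, calling `setting.setup()`, and then `shorten_datasets(setting, samples_per_task)` before the assertions.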