예제 #1
0
    def setup(self, stage: str = None):
        if not self.has_prepared_data:
            self.prepare_data()
        super().setup(stage=stage)

        if stage not in (None, "fit", "test"):
            raise RuntimeError(f"`stage` should be 'fit', 'test' or None.")

        if stage in (None, "fit"):
            self.train_cl_dataset = self.train_cl_dataset or self.make_dataset(
                self.config.data_dir, download=False, train=True
            )
            self.train_cl_loader = self.train_cl_loader or ClassIncremental(
                cl_dataset=self.train_cl_dataset,
                nb_tasks=self.nb_tasks,
                increment=self.increment,
                initial_increment=self.initial_increment,
                transformations=self.train_transforms,
                class_order=self.class_order,
            )
            if not self.train_datasets and not self.val_datasets:
                for task_id, train_taskset in enumerate(self.train_cl_loader):
                    train_taskset, valid_taskset = split_train_val(
                        train_taskset, val_split=0.1
                    )
                    self.train_datasets.append(train_taskset)
                    self.val_datasets.append(valid_taskset)
                # IDEA: We could do the remapping here instead of adding a wrapper later.
                if self.shared_action_space and isinstance(
                    self.action_space, spaces.Discrete
                ):
                    # If we have a shared output space, then they are all mapped to [0, n_per_task]
                    self.train_datasets = list(map(relabel, self.train_datasets))
                    self.val_datasets = list(map(relabel, self.val_datasets))

        if stage in (None, "test"):
            self.test_cl_dataset = self.test_cl_dataset or self.make_dataset(
                self.config.data_dir, download=False, train=False
            )
            self.test_cl_loader = self.test_cl_loader or ClassIncremental(
                cl_dataset=self.test_cl_dataset,
                nb_tasks=self.nb_tasks,
                increment=self.test_increment,
                initial_increment=self.test_initial_increment,
                transformations=self.test_transforms,
                class_order=self.test_class_order,
            )
            if not self.test_datasets:
                # TODO: If we decide to 'shuffle' the test tasks, then store the sequence of
                # task ids in a new property, probably here.
                # self.test_task_order = list(range(len(self.test_datasets)))
                self.test_datasets = list(self.test_cl_loader)
                # IDEA: We could do the remapping here instead of adding a wrapper later.
                if self.shared_action_space and isinstance(
                    self.action_space, spaces.Discrete
                ):
                    # If we have a shared output space, then they are all mapped to [0, n_per_task]
                    self.test_datasets = list(map(relabel, self.test_datasets))
예제 #2
0
def test_split_train_val(val_split):
    """Each task's train/val split partitions the task with the requested val share."""
    train_data, _ = gen_data()
    scenario = ClassIncremental(InMemoryDataset(*train_data), increment=5)

    for task_set in scenario:
        tr_set, va_set = split_train_val(task_set, val_split=val_split)
        total = len(task_set)
        # The validation part gets floor(val_split * total) samples ...
        assert int(val_split * total) == len(va_set)
        # ... and together with the train part covers the whole task.
        assert len(va_set) + len(tr_set) == total
예제 #3
0
def test_split_train_val(val_split, nb_val):
    """Splitting a 10-sample TaskSet yields ``nb_val`` validation samples and drops none."""
    n_samples = 10
    images = np.random.rand(n_samples, 2, 2, 3)
    labels = np.ones((n_samples,))
    task_ids = np.ones((n_samples,))

    full_set = TaskSet(images, labels, task_ids, None)
    tr_set, va_set = split_train_val(full_set, val_split)

    assert len(va_set) == nb_val
    assert len(tr_set) + len(va_set) == len(full_set)
예제 #4
0
def test_split_train_val_loading():
    """Both halves of a train/val split can be iterated through a DataLoader."""
    n_samples = 10
    images = np.random.rand(n_samples, 2, 2, 3)
    labels = np.ones((n_samples,))
    task_ids = np.ones((n_samples,))

    full_set = TaskSet(images, labels, task_ids, None)
    tr_set, va_set = split_train_val(full_set, 0.2)

    for subset in (tr_set, va_set):
        # Exhausting the loader is the check: batching must not raise.
        for _x, _y, _t in DataLoader(subset, batch_size=32):
            pass
예제 #5
0
    def setup(self, stage: Optional[str] = None, *args, **kwargs):
        """Create the train/val/test datasets for each task.

        Builds the train and test datasets from ``self.data_dir`` (with
        ``download=False``) and wraps them in scenarios via
        ``make_train_cl_loader`` / ``make_test_cl_loader``. Every train task
        is split into train/val using ``self.val_fraction``. The parent
        ``setup`` is then called. All splits are rebuilt on every call,
        whatever the value of ``stage``.

        TODO: Figure out a way of setting data_dir elsewhere maybe?

        Args:
            stage: forwarded unchanged to the parent ``setup``.
            *args: forwarded unchanged to the parent ``setup``.
            **kwargs: forwarded unchanged to the parent ``setup``.
        """
        # NOTE(review): assert is stripped under -O; consider raising instead.
        assert self.config
        logger.debug(
            f"data_dir: {self.data_dir}, setup args: {args} kwargs: {kwargs}")

        self.train_cl_dataset = self.make_dataset(self.data_dir,
                                                  download=False,
                                                  train=True)
        self.test_cl_dataset = self.make_dataset(self.data_dir,
                                                 download=False,
                                                 train=False)

        self.train_cl_loader: _BaseScenario = self.make_train_cl_loader(
            self.train_cl_dataset)
        self.test_cl_loader: _BaseScenario = self.make_test_cl_loader(
            self.test_cl_dataset)

        logger.info(f"Number of train tasks: {self.train_cl_loader.nb_tasks}.")
        # Fixed: this line previously reported the *train* loader's task count.
        logger.info(f"Number of test tasks: {self.test_cl_loader.nb_tasks}.")

        # Start from empty lists so repeated setup() calls don't accumulate tasks.
        self.train_datasets.clear()
        self.val_datasets.clear()
        self.test_datasets.clear()

        for train_dataset in self.train_cl_loader:
            train_dataset, val_dataset = split_train_val(
                train_dataset, val_split=self.val_fraction)
            self.train_datasets.append(train_dataset)
            self.val_datasets.append(val_dataset)

        for test_dataset in self.test_cl_loader:
            self.test_datasets.append(test_dataset)

        super().setup(stage, *args, **kwargs)

        # TODO: Adding this temporarily just for the competition
        # Start offset of each test task in the concatenated test stream:
        # [0, len(t0), len(t0) + len(t1), ...]
        self.test_boundary_steps = [0] + list(
            itertools.accumulate(map(len, self.test_datasets)))[:-1]
        # Total length of all test tasks combined.
        self.test_steps = sum(map(len, self.test_datasets))
예제 #6
0
def test_h5dataset_split_train_test(data, tmpdir):
    """Every task of an H5-backed scenario can be split, and both halves loaded."""
    x_, y_, t_ = data
    h5_path = os.path.join(tmpdir, "test_h5.hdf5")
    dataset = H5Dataset(x_, y_, t_, data_path=h5_path)

    expected_tasks = len(np.unique(t_))
    scenario = ContinualScenario(dataset)

    for task_set in scenario:
        train_part, val_part = split_train_val(task_set)
        # Iterate the train part fully, then the validation part.
        for part in (train_part, val_part):
            for _ in DataLoader(part):
                pass

    assert scenario.nb_tasks == expected_tasks
예제 #7
0
def test_concat_smooth_boundaries(config: Config):
    """Plot per-batch class and task-id counts of a smoothed task concatenation.

    Builds a 2-classes-per-task MNIST ``ClassIncremental`` scenario and keeps
    90% of each task (``val_split=0.1``). The train parts are concatenated
    with ``smooth_task_boundaries_concat``. The test then plots, for each
    batch of 100 samples in order, how many samples come from each class
    and from each task.
    """
    from continuum.datasets import MNIST
    from continuum.scenarios import ClassIncremental
    from continuum.tasks import split_train_val

    import matplotlib.pyplot as plt

    dataset = MNIST(config.data_dir, download=True, train=True)
    scenario = ClassIncremental(
        dataset,
        increment=2,
    )

    print(f"Number of classes: {scenario.nb_classes}.")
    print(f"Number of tasks: {scenario.nb_tasks}.")

    train_datasets = []
    for train_taskset in scenario:
        # Only the train part is used here; the validation part is discarded.
        train_taskset, _ = split_train_val(train_taskset, val_split=0.1)
        train_datasets.append(train_taskset)

    train_dataset = smooth_task_boundaries_concat(train_datasets, seed=123)

    y_counters: List[Counter] = []
    t_counters: List[Counter] = []
    dataloader = DataLoader(train_dataset, batch_size=100, shuffle=False)

    for _, y, t in dataloader:
        y_counters.append(Counter(y.tolist()))
        t_counters.append(Counter(t.tolist()))

    # Plot the labels actually observed rather than assuming 0..n-1.
    classes = sorted(set().union(*y_counters))
    batch_indices = np.arange(len(dataloader))

    fig, axes = plt.subplots(2)
    try:
        for label in classes:
            # Counter[...] is 0 for absent keys, so a batch without this
            # class plots as 0 instead of a gap (dict.get would give None).
            counts = [y_counter[label] for y_counter in y_counters]
            axes[0].plot(batch_indices, counts, label=f"class {label}")
        axes[0].legend()
        axes[0].set_title("y")
        axes[0].set_xlabel("Batch index")
        axes[0].set_ylabel("Count in batch")

        for task_id in range(scenario.nb_tasks):
            counts = [t_counter[task_id] for t_counter in t_counters]
            axes[1].plot(batch_indices, counts, label=f"Task id {task_id}")
        axes[1].legend()
        axes[1].set_title("task_id")
        axes[1].set_xlabel("Batch index")
        axes[1].set_ylabel("Count in batch")
    finally:
        # Release the figure so repeated test runs don't accumulate open figures.
        plt.close(fig)