Example #1
def _make_val_dataset(self) -> Dataset:
    """Builds the validation dataset according to the setting's flags."""
    if self.smooth_task_boundaries:
        # Mix samples across task boundaries so transitions are gradual.
        return smooth_task_boundaries_concat(
            self.val_datasets, seed=self.config.seed
        )
    if self.stationary_context:
        # Stationary setting: join all tasks and shuffle into one i.i.d. dataset.
        joined_dataset = concat(self.val_datasets)
        return shuffle(joined_dataset, seed=self.config.seed)
    if self.known_task_boundaries_at_train_time:
        # Boundaries are known: validate on the current task only.
        return self.val_datasets[self.current_task_id]
    return concat(self.val_datasets)
Example #2
def _make_train_dataset(self) -> Union[TaskSet, Dataset]:
    """Builds the training dataset according to the setting's flags."""
    # NOTE: Passing the same seed to `train`/`valid`/`test` is fine, because it's
    # only used for the shuffling that makes the task boundaries smooth.
    if self.smooth_task_boundaries:
        # Mix samples across task boundaries so transitions are gradual.
        return smooth_task_boundaries_concat(
            self.train_datasets, seed=self.config.seed if self.config else None
        )
    if self.stationary_context:
        # Stationary setting: join all tasks and shuffle into one i.i.d. dataset.
        joined_dataset = concat(self.train_datasets)
        return shuffle(joined_dataset, seed=self.config.seed)
    if self.known_task_boundaries_at_train_time:
        # Boundaries are known at train time: train on the current task only.
        return self.train_datasets[self.current_task_id]
    return concat(self.train_datasets)
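Note that the branch order gives these flags a precedence: smooth task boundaries win over a stationary context, which in turn wins over known task boundaries. A minimal sketch of that precedence (the `FakeSetting` class below is hypothetical and only mirrors the branch order, it is not part of the codebase):

from dataclasses import dataclass

@dataclass
class FakeSetting:
    smooth_task_boundaries: bool = False
    stationary_context: bool = False
    known_task_boundaries_at_train_time: bool = False

    def stream_kind(self) -> str:
        # Mirrors the branch order of `_make_train_dataset` above.
        if self.smooth_task_boundaries:
            return "smoothly-mixed stream over all tasks"
        if self.stationary_context:
            return "all tasks joined and shuffled (i.i.d.)"
        if self.known_task_boundaries_at_train_time:
            return "current task's dataset only"
        return "tasks concatenated back to back"

# With both flags set, the smooth-boundaries branch is taken first.
assert (
    FakeSetting(smooth_task_boundaries=True, stationary_context=True).stream_kind()
    == "smoothly-mixed stream over all tasks"
)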
Example #3
def _make_test_dataset(self) -> Dataset:
    """Builds the test dataset according to the setting's flags."""
    if self.smooth_task_boundaries:
        # Mix samples across task boundaries so transitions are gradual.
        return smooth_task_boundaries_concat(
            self.test_datasets, seed=self.config.seed
        )
    return concat(self.test_datasets)
Example #4
import numpy as np
import pytest
from continuum.tasks import TaskSet
from torch.utils.data import DataLoader


# The original parametrization isn't shown here; these values are illustrative.
@pytest.mark.parametrize("nb_others", [1, 2, 5])
def test_concat_function(nb_others):
    x = np.random.rand(10, 2, 2, 3)
    y = np.ones((10, ))
    t = np.ones((10, ))

    # Copy the arrays so that each TaskSet owns its own data.
    task_sets = [
        TaskSet(np.copy(x), np.copy(y), np.copy(t), None)
        for _ in range(nb_others)
    ]

    concatenation = concat(task_sets)
    assert len(concatenation) == nb_others * 10
    # Iterating with a DataLoader checks that the result is a usable torch Dataset.
    loader = DataLoader(concatenation)
    for x, y, t in loader:
        pass
Example #5
import itertools
from typing import Dict, List, Optional

import numpy as np
from continuum.tasks import TaskSet
from torch.utils.data import ConcatDataset, Dataset, Subset


def smooth_task_boundaries_concat(
    datasets: List[Dataset], seed: Optional[int] = None, window_length: float = 0.03
) -> Dataset:
    """Concatenates `datasets` while making the task boundaries smooth: samples
    near a boundary are mixed between neighbouring tasks rather than switching
    abruptly.

    TODO: Use a smarter way of mixing from one to the other?
    """
    lengths = [len(dataset) for dataset in datasets]
    total_length = sum(lengths)
    n_tasks = len(datasets)

    # `window_length` can be given as a fraction of the total dataset length.
    if not isinstance(window_length, int):
        window_length = int(total_length * window_length)
    assert (
        window_length > 1
    ), f"window_length should resolve to an integer greater than 1, got {window_length}."

    rng = np.random.default_rng(seed)

    def option1():
        # Shuffle indices within overlapping windows (stride = half a window), so
        # samples only move locally and the overall task ordering is preserved.
        shuffled_indices = np.arange(total_length)
        for start_index in range(
            0, total_length - window_length + 1, window_length // 2
        ):
            rng.shuffle(shuffled_indices[start_index : start_index + window_length])
        return shuffled_indices

    # Maybe do the same but backwards?

    # IDEA #2: Sample based on how close to the 'center' of the task we are.
    def option2():
        boundaries = np.array(list(itertools.accumulate(lengths, initial=0)))
        middles = [
            (start + end) / 2 for start, end in zip(boundaries[0:], boundaries[1:])
        ]
        samples_left: Dict[int, int] = {i: length for i, length in enumerate(lengths)}
        indices_left: Dict[int, List[int]] = {
            i: list(range(boundaries[i], boundaries[i] + length))
            for i, length in enumerate(lengths)
        }

        out_indices: List[int] = []
        last_dataset_index = n_tasks - 1
        for step in range(total_length):
            if step < middles[0] and samples_left[0]:
                # Before the middle of the first task: sample from task 0 only, to
                # prevent drawing samples from task 1 at the very start.
                eligible_dataset_ids = [0]
            elif step > middles[-1] and samples_left[last_dataset_index]:
                # Past the middle of the last task: sample from the last task only,
                # to prevent drawing samples from task N-1 at the end of task N.
                eligible_dataset_ids = [last_dataset_index]
            else:
                # Smooth mixing: any dataset with samples left is eligible, which
                # near a boundary can include two or three datasets (even from
                # future tasks).
                eligible_dataset_ids = [k for k, v in samples_left.items() if v > 0]
                # if len(eligible_dataset_ids) > 2:
                #     # Prevent sampling from future tasks (past the next task) when at a
                #     # boundary.
                #     left_dataset_index = min(eligible_dataset_ids)
                #     right_dataset_index = min(
                #         v for v in eligible_dataset_ids if v > left_dataset_index
                #     )
                #     eligible_dataset_ids = [left_dataset_index, right_dataset_index]

            options = np.array(eligible_dataset_ids, dtype=int)

            # Distance from the current step to the center of each eligible dataset.
            distances = np.abs(
                [step - middles[dataset_index] for dataset_index in options]
            )

            # NOTE: This exponent is somewhat arbitrary; 2 is used because it
            # sort of works for MNIST so far.
            probs = 1 / (1 + distances ** 2)
            probs /= sum(probs)

            chosen_dataset = rng.choice(options, p=probs)
            chosen_index = indices_left[chosen_dataset].pop()
            samples_left[chosen_dataset] -= 1
            out_indices.append(chosen_index)

        shuffled_indices = np.array(out_indices)
        return shuffled_indices

    def option3():
        # Same windowed shuffle as option1, followed by a second, backward pass,
        # which mixes the boundaries a bit more thoroughly.
        shuffled_indices = np.arange(total_length)
        for start_index in range(
            0, total_length - window_length + 1, window_length // 2
        ):
            rng.shuffle(shuffled_indices[start_index : start_index + window_length])
        for start_index in reversed(
            range(0, total_length - window_length + 1, window_length // 2)
        ):
            rng.shuffle(shuffled_indices[start_index : start_index + window_length])
        return shuffled_indices

    # option1 and option2 are kept above as alternatives; option3 is the one used.
    shuffled_indices = option3()

    if all(isinstance(dataset, TaskSet) for dataset in datasets):
        # Use the 'concat' from continuum, just to preserve the field/methods of a
        # TaskSet.
        joined_taskset = concat(datasets)
        return subset(joined_taskset, shuffled_indices)
    else:
        joined_dataset = ConcatDataset(datasets)
        return Subset(joined_dataset, shuffled_indices)

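As a rough usage sketch of the function above (assuming only numpy and torch are available; the two `TensorDataset`s stand in for tasks and are not part of the original code):

import torch
from torch.utils.data import TensorDataset

# Two "tasks" of 100 samples each, labelled 0 and 1 respectively.
task_0 = TensorDataset(torch.zeros(100, 1), torch.zeros(100, dtype=torch.long))
task_1 = TensorDataset(torch.ones(100, 1), torch.ones(100, dtype=torch.long))

# Neither dataset is a TaskSet, so this takes the Subset(ConcatDataset(...)) branch.
mixed = smooth_task_boundaries_concat([task_0, task_1], seed=123, window_length=20)
labels = [int(mixed[i][1]) for i in range(len(mixed))]

# Far from the boundary the stream is pure task 0 or task 1; around index 100
# the two labels interleave over roughly one window length.
print(labels[:10], labels[95:105], labels[-10:])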
Example #6
from continuum import ClassIncremental
from continuum.datasets import MNIST


def test_shuffle(config: Config):
    dataset = MNIST(data_path=config.data_dir, train=True)
    cl_dataset = concat(ClassIncremental(dataset, increment=2))
    shuffled_dataset = shuffle(cl_dataset)
    # After shuffling, well over half of the entries should have moved: both the
    # class labels and the task ids should differ from the unshuffled dataset.
    assert (shuffled_dataset._y != cl_dataset._y).sum() > len(cl_dataset) / 2
    assert (shuffled_dataset._t != cl_dataset._t).sum() > len(cl_dataset) / 2