Example #1
def get_multitask_experiment(name,
                             tasks,
                             slot,
                             shift,
                             data_dir="./store/datasets",
                             normalize=False,
                             augment=False,
                             only_config=False,
                             verbose=False,
                             exception=False,
                             only_test=False,
                             max_samples=None):
    '''Load, organize and return train- and test-datasets for the requested multi-task experiment.'''

    ## NOTE: options 'normalize' and 'augment' are only implemented for CIFAR-based experiments.

    # depending on experiment, get and organize the datasets
    if name == 'permMNIST':
        # configurations
        config = DATASET_CONFIGS['mnist']
        classes_per_task = 10
        if not only_config:
            # prepare dataset
            if not only_test:
                train_dataset = get_dataset('mnist',
                                            type="train",
                                            permutation=None,
                                            dir=data_dir,
                                            target_transform=None,
                                            verbose=verbose)
            test_dataset = get_dataset('mnist',
                                       type="test",
                                       permutation=None,
                                       dir=data_dir,
                                       target_transform=None,
                                       verbose=verbose)
            # generate permutations
            if exception:
                permutations = [None] + [
                    np.random.permutation(config['size']**2)
                    for _ in range(tasks - 1)
                ]
            else:
                permutations = [
                    np.random.permutation(config['size']**2)
                    for _ in range(tasks)
                ]
            # specify transformed datasets per task
            train_datasets = []
            test_datasets = []
            for task_id, perm in enumerate(permutations):
                target_transform = transforms.Lambda(
                    lambda y, x=task_id: y + x * classes_per_task)
                if not only_test:
                    train_datasets.append(
                        TransformedDataset(train_dataset,
                                           transform=transforms.Lambda(
                                               lambda x, p=perm:
                                               permutate_image_pixels(x, p)),
                                           target_transform=target_transform))
                test_datasets.append(
                    TransformedDataset(
                        test_dataset,
                        transform=transforms.Lambda(
                            lambda x, p=perm: permutate_image_pixels(x, p)),
                        target_transform=target_transform))
    elif name == 'splitMNIST':
        # check for number of tasks
        if tasks > 10:
            raise ValueError(
                "Experiment '{}' cannot have more than 10 tasks!".format(name))
        # configurations
        config = DATASET_CONFIGS['mnist28']
        classes_per_task = int(np.floor(10 / tasks))
        if not only_config:
            # prepare permutation to shuffle label-ids (to create different class batches for each random seed)
            permutation = np.array(list(
                range(10))) if exception else np.random.permutation(
                    list(range(10)))
            target_transform = transforms.Lambda(
                lambda y, p=permutation: int(p[y]))
            # prepare train and test datasets with all classes
            if not only_test:
                mnist_train = get_dataset('mnist28',
                                          type="train",
                                          dir=data_dir,
                                          target_transform=target_transform,
                                          verbose=verbose)
            mnist_test = get_dataset('mnist28',
                                     type="test",
                                     dir=data_dir,
                                     target_transform=target_transform,
                                     verbose=verbose)
            # generate labels-per-task
            labels_per_task = [
                list(
                    np.array(range(classes_per_task)) +
                    classes_per_task * task_id) for task_id in range(tasks)
            ]
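            # e.g., with tasks=5 this gives labels_per_task = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]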
            # split them up into sub-tasks
            train_datasets = []
            test_datasets = []
            for labels in labels_per_task:
                target_transform = None
                if not only_test:
                    train_datasets.append(
                        SubDataset(mnist_train,
                                   labels,
                                   target_transform=target_transform))
                test_datasets.append(
                    SubDataset(mnist_test,
                               labels,
                               target_transform=target_transform))
    elif name == 'CIFAR100':
        # check for number of tasks
        if tasks > 100:
            raise ValueError(
                "Experiment 'CIFAR100' cannot have more than 100 tasks!")
        # configurations
        config = DATASET_CONFIGS['cifar100']
        classes_per_task = int(np.floor(100 / tasks))
        if not only_config:
            # prepare permutation to shuffle label-ids (to create different class batches for each random seed)
            permutation = list(
                range(100))  #np.random.permutation(list(range(100)))
            target_transform = transforms.Lambda(
                lambda y, p=permutation: int(p[y]))
            # prepare train and test datasets with all classes
            if not only_test:
                cifar100_train = get_dataset('cifar100',
                                             shift=shift,
                                             slot=slot,
                                             type="train",
                                             dir=data_dir,
                                             normalize=normalize,
                                             augment=augment,
                                             target_transform=target_transform,
                                             verbose=verbose)
            cifar100_test = get_dataset('cifar100',
                                        shift=shift,
                                        slot=slot,
                                        type="test",
                                        dir=data_dir,
                                        normalize=normalize,
                                        target_transform=target_transform,
                                        verbose=verbose)
            # generate labels-per-task
            labels_per_task = [
                list(
                    np.array(range(classes_per_task)) +
                    classes_per_task * task_id) for task_id in range(tasks)
            ]
            # split them up into sub-tasks
            train_datasets = []
            test_datasets = []
            for labels in labels_per_task:
                target_transform = None
                if not only_test:
                    if max_samples is None:
                        train_datasets.append(
                            SubDataset(cifar100_train,
                                       labels,
                                       target_transform=target_transform))
                    else:
                        train_datasets.append(
                            ReducedSubDataset(
                                cifar100_train,
                                labels,
                                target_transform=target_transform,
                                max=max_samples))
                test_datasets.append(
                    SubDataset(cifar100_test,
                               labels,
                               target_transform=target_transform))
    else:
        raise RuntimeError('Given undefined experiment: {}'.format(name))

    # If needed, update number of (total) classes in the config-dictionary
    config['classes'] = classes_per_task * tasks
    config['normalize'] = normalize if name == 'CIFAR100' else False
    if config['normalize']:
        config['denormalize'] = AVAILABLE_TRANSFORMS["cifar100_denorm"]

    # Return tuple of train- and test-datasets, config-dictionary and number of classes per task
    return config if only_config else ((train_datasets, test_datasets), config,
                                       classes_per_task)
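A minimal usage sketch for the function above (hypothetical call site; the argument values are illustrative and assume the module's own dependencies, e.g. numpy and torchvision.transforms, are available):

# Hypothetical usage sketch; values are illustrative.
(train_datasets, test_datasets), config, classes_per_task = get_multitask_experiment(
    name='splitMNIST', tasks=5, slot=None, shift=None, verbose=True)
# slot/shift are only used by the 'CIFAR100' branch
print(config['classes'])    # -> 10 (2 classes per task * 5 tasks)
print(len(train_datasets))  # -> 5 (one dataset per task)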
Example #2
def prepare_datasets(name,
                     n_labels,
                     classes=True,
                     classes_per_task=None,
                     dir="./store/datasets",
                     verbose=False,
                     only_config=False,
                     exception=False,
                     only_test=False):
    '''Prepare training- and test-datasets for a continual-learning experiment.

    Args:
        name (str; `splitMNIST`|`permMNIST`): name of the experiment
        n_labels (int): number of classes (if ``classes`` is ``True``) or number of tasks/domains (otherwise)
        classes (bool, optional): if ``True``, labels indicate classes, otherwise tasks/domains (default: ``True``)
        classes_per_task (int, optional): required if ``classes`` is ``False``
        dir (path, optional): where data is stored / should be downloaded
        verbose (bool, optional): if ``True``, print (additional) information to screen
        only_config (bool, optional): if ``True``, only return config-information (faster; data is not actually loaded)
        exception (bool, optional): if ``True``, do not shuffle labels
        only_test (bool, optional): if ``True``, only load and return test-set(s)

    Returns:
        tuple: ((train_datasets, test_datasets), config, labels_per_task);
            if ``only_config`` is ``True``, only the config-dictionary is returned
    '''

    if name == 'splitMNIST':
        # Configurations.
        config = DATASET_CONFIGS['mnist28']
        if not only_config:
            # Prepare permutation to shuffle label-ids (to create different class batches for each random seed).
            n_class = config['classes']
            permutation = np.array(list(
                range(n_class))) if exception else np.random.permutation(
                    list(range(n_class)))
            target_transform = transforms.Lambda(
                lambda y, p=permutation: int(p[y]))
            # Load train and test datasets with all classes.
            if not only_test:
                mnist_train = get_dataset('mnist28',
                                          type="train",
                                          dir=dir,
                                          target_transform=target_transform,
                                          verbose=verbose)
            mnist_test = get_dataset('mnist28',
                                     type="test",
                                     dir=dir,
                                     target_transform=target_transform,
                                     verbose=verbose)
            # Generate labels-per-task.
            labels_per_task = [[label] for label in range(n_labels)] if classes else [
                list(
                    np.array(range(classes_per_task)) +
                    classes_per_task * task_id) for task_id in range(n_labels)
            ]
            # Split them up into separate datasets for each task / class.
            train_datasets = []
            test_datasets = []
            for labels in labels_per_task:
                target_transform = None if classes else transforms.Lambda(
                    lambda y, x=labels[0]: y - x)
                if (not only_test):
                    train_datasets.append(
                        SubDataset(mnist_train,
                                   labels,
                                   target_transform=target_transform))
                test_datasets.append(
                    SubDataset(mnist_test,
                               labels,
                               target_transform=target_transform))
    elif name == 'permMNIST':
        if classes:
            raise NotImplementedError(
                "Permuted MNIST with Class-IL is not (yet) implemented.")
        # Configurations
        config = DATASET_CONFIGS['mnist']
        if not only_config:
            # Generate labels-per-task
            labels_per_task = [
                list(
                    np.array(range(classes_per_task)) +
                    classes_per_task * task_id) for task_id in range(n_labels)
            ]
            # Prepare datasets
            train_dataset = get_dataset('mnist',
                                        type="train",
                                        dir=dir,
                                        verbose=verbose)
            test_dataset = get_dataset('mnist',
                                       type="test",
                                       dir=dir,
                                       verbose=verbose)
            # Generate permutations
            if exception:
                permutations = [None] + [
                    np.random.permutation(config['size']**2)
                    for _ in range(n_labels - 1)
                ]
            else:
                permutations = [
                    np.random.permutation(config['size']**2)
                    for _ in range(n_labels)
                ]
            # Prepare datasets per task
            train_datasets = []
            test_datasets = []
            for task_id, perm in enumerate(permutations):
                train_datasets.append(
                    TransformedDataset(
                        train_dataset,
                        transform=transforms.Lambda(
                            lambda x, p=perm: permutate_image_pixels(x, p)),
                    ))
                test_datasets.append(
                    TransformedDataset(
                        test_dataset,
                        transform=transforms.Lambda(
                            lambda x, p=perm: permutate_image_pixels(x, p)),
                    ))
    else:
        raise ValueError('Given undefined experiment: {}'.format(name))

    return config if only_config else ((train_datasets, test_datasets), config,
                                       labels_per_task)
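Both functions above bind loop-dependent values such as perm, task_id and labels[0] as lambda default arguments. A standalone sketch of why this matters with Python's late-binding closures (the names below are illustrative only):

# Late-binding closures: without the default argument, every lambda sees the final task_id.
late = [lambda y: y + task_id for task_id in range(3)]
bound = [lambda y, t=task_id: y + t for task_id in range(3)]
print([f(0) for f in late])   # -> [2, 2, 2]
print([f(0) for f in bound])  # -> [0, 1, 2]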
Example #3
def train_cl(model,
             train_datasets,
             replay_mode="none",
             rnt=None,
             classes_per_task=None,
             iters=2000,
             batch_size=32,
             batch_size_replay=None,
             loss_cbs=list(),
             eval_cbs=list(),
             reinit=False,
             args=None,
             only_last=False,
             use_exemplars=False,
             metric_cbs=list()):
    '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode].

    [model]             <nn.Module> main model to optimize across all tasks
    [train_datasets]    <list> with for each task the training <DataSet>
    [replay_mode]       <str>, choice from "exact", "exemplars", "current", "offline" and "none"
    [rnt]               <float>, indicating relative importance of new task (if None, relative to # old tasks)
    [classes_per_task]  <int>, number of classes per task
    [iters]             <int>, number of optimization-steps (i.e., batches) per task
    [batch_size_replay] <int>, number of samples to replay per batch
    [reinit]            <bool>, if True, reinitialize the model's parameters before each task
    [only_last]         <bool>, only train on final task / episode
    [use_exemplars]     <bool>, if True, store exemplars for the classes of each task after training on it
    [*_cbs]             <list> of call-back functions to evaluate training-progress'''

    # Should convolutional layers be frozen?
    freeze_convE = (utils.checkattr(args, "freeze_convE")
                    and hasattr(args, "depth") and args.depth > 0)

    # Use cuda?
    device = model._device()
    cuda = model._is_on_cuda()

    # Set default-values if not specified
    batch_size_replay = batch_size if batch_size_replay is None else batch_size_replay

    # Initiate indicators for replay (no replay for 1st task)
    Exact = Current = Offline_TaskIL = False
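    # -[Exact]:          replay stored data from previous tasks (full datasets or exemplar sets)
    # -[Current]:        replay the current task's inputs, with targets produced by the previous model
    # -[Offline_TaskIL]: sample a separate batch from every task seen so far (one data-loader per task)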
    previous_model = None

    # Register starting param-values (needed for "intelligent synapses").
    if isinstance(model, ContinualLearner) and model.si_c > 0:
        for n, p in model.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')
                model.register_buffer('{}_SI_prev_task'.format(n),
                                      p.detach().clone())

    # Loop over all tasks.
    for task, train_dataset in enumerate(train_datasets, 1):

        # In offline replay-setting, all tasks so far should be visited separately (i.e., separate data-loader per task)
        if replay_mode == "offline":
            Offline_TaskIL = True
            data_loader = [None] * task

        # Initialize # iters left on data-loader(s)
        iters_left = 1 if (not Offline_TaskIL) else [1] * task
        if Exact:
            iters_left_previous = [1] * (task - 1)
            data_loader_previous = [None] * (task - 1)

        # Prepare <dicts> to store running importance estimates and parameter-values before update
        if isinstance(model, ContinualLearner) and model.si_c > 0:
            W = {}
            p_old = {}
            for n, p in model.named_parameters():
                if p.requires_grad:
                    n = n.replace('.', '__')
                    W[n] = p.data.clone().zero_()
                    p_old[n] = p.data.clone()

        # Find [active_classes] (= for each task so far, the list of classes in that task)
        active_classes = [
            list(range(classes_per_task * i, classes_per_task * (i + 1)))
            for i in range(task)
        ]

        # Reinitialize the model's parameters and the optimizer (if requested)
        if reinit:
            from define_models import init_params
            init_params(model, args)
            model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))

        # Define a tqdm progress-bar
        progress = tqdm.tqdm(range(1, iters + 1))

        # Loop over all iterations
        iters_to_use = iters
        # -if only the final task should be trained on:
        if only_last and not task == len(train_datasets):
            iters_to_use = 0
        for batch_index in range(1, iters_to_use + 1):

            # Update # iters left on current data-loader(s) and, if needed, create new one(s)
            if not Offline_TaskIL:
                iters_left -= 1
                if iters_left == 0:
                    data_loader = iter(
                        utils.get_data_loader(train_dataset,
                                              batch_size,
                                              cuda=cuda,
                                              drop_last=True))
                    # NOTE: [train_dataset] is the training-set of the current task
                    iters_left = len(data_loader)
            else:
                # -with "offline replay", there is a separate data-loader for each task
                batch_size_to_use = batch_size
                for task_id in range(task):
                    iters_left[task_id] -= 1
                    if iters_left[task_id] == 0:
                        data_loader[task_id] = iter(
                            utils.get_data_loader(train_datasets[task_id],
                                                  batch_size_to_use,
                                                  cuda=cuda,
                                                  drop_last=True))
                        iters_left[task_id] = len(data_loader[task_id])

            # Update # iters left on data-loader(s) of the previous task(s) and, if needed, create new one(s)
            if Exact:
                up_to_task = task - 1
                batch_size_replay_pt = int(
                    np.floor(
                        batch_size_replay /
                        up_to_task)) if (up_to_task > 1) else batch_size_replay
                # -need separate replay for each task
                for task_id in range(up_to_task):
                    batch_size_to_use = min(batch_size_replay_pt,
                                            len(previous_datasets[task_id]))
                    iters_left_previous[task_id] -= 1
                    if iters_left_previous[task_id] == 0:
                        data_loader_previous[task_id] = iter(
                            utils.get_data_loader(previous_datasets[task_id],
                                                  batch_size_to_use,
                                                  cuda=cuda,
                                                  drop_last=True))
                        iters_left_previous[task_id] = len(
                            data_loader_previous[task_id])

            #-----------------Collect data------------------#

            #####-----CURRENT BATCH-----#####
            if not Offline_TaskIL:
                x, y = next(
                    data_loader)  #--> sample training data of current task
                y = y - classes_per_task * (
                    task - 1)  #--> ITL: adjust y-targets to 'active range'
                x, y = x.to(device), y.to(
                    device)  #--> transfer them to correct device
                #y = y.expand(1) if len(y.size())==1 else y     #--> hack for if batch-size is 1
            else:
                x = y = task_used = None  #--> all tasks are "treated as replay"
                # -sample training data for all tasks so far, move to correct device and store in lists
                x_, y_ = list(), list()
                for task_id in range(task):
                    x_temp, y_temp = next(data_loader[task_id])
                    x_.append(x_temp.to(device))
                    y_temp = y_temp - (
                        classes_per_task * task_id
                    )  #--> adjust y-targets to 'active range'
                    if batch_size_to_use == 1:
                        y_temp = torch.tensor([
                            y_temp
                        ])  #--> correct dimensions if batch-size is 1
                    y_.append(y_temp.to(device))

            #####-----REPLAYED BATCH-----#####
            if not Offline_TaskIL and not Exact and not Current:
                x_ = y_ = scores_ = task_used = None  #-> if no replay

            #--------------------------------------------INPUTS----------------------------------------------------#

            ##-->> Exact Replay <<--##
            if Exact:
                task_used = None  #-> no task-IDs are passed along with this replay
                # Sample replayed training data, move to correct device and store in lists
                x_ = list()
                y_ = list()
                up_to_task = task - 1
                for task_id in range(up_to_task):
                    x_temp, y_temp = next(data_loader_previous[task_id])
                    x_.append(x_temp.to(device))
                    # -only keep [y_] if required (as otherwise unnecessary computations will be done)
                    if model.replay_targets == "hard":
                        y_temp = y_temp - (
                            classes_per_task * task_id
                        )  #-> adjust y-targets to 'active range'
                        y_.append(y_temp.to(device))
                    else:
                        y_.append(None)
                # If required, get target scores (i.e., [scores_]) -- using previous model, with no_grad()
                if (model.replay_targets == "soft") and (previous_model
                                                         is not None):
                    scores_ = list()
                    for task_id in range(up_to_task):
                        with torch.no_grad():
                            scores_temp = previous_model(x_[task_id])
                        scores_temp = scores_temp[:,
                                                  (classes_per_task *
                                                   task_id):(classes_per_task *
                                                             (task_id + 1))]
                        scores_.append(scores_temp)
                else:
                    scores_ = None

            ##-->> Current Replay <<--##
            if Current:
                x_ = x[:batch_size_replay]  #--> use current task inputs
                task_used = None

            #--------------------------------------------OUTPUTS----------------------------------------------------#

            if Current:
                # Get target scores & possibly labels (i.e., [scores_] / [y_]) -- use previous model, with no_grad()
                # -[x_] needs to be evaluated according to each previous task, so make list with entry per task
                scores_ = list()
                y_ = list()
                # -if no task-mask and no conditional generator, all scores can be calculated in one go
                if previous_model.mask_dict is None and not type(x_) == list:
                    with torch.no_grad():
                        all_scores_ = previous_model.classify(x_)
                for task_id in range(task - 1):
                    # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately
                    if previous_model.mask_dict is not None:
                        previous_model.apply_XdGmask(task=task_id + 1)
                    if previous_model.mask_dict is not None or type(
                            x_) == list:
                        with torch.no_grad():
                            all_scores_ = previous_model.classify(
                                x_[task_id] if type(x_) == list else x_)
                    temp_scores_ = all_scores_[:, (classes_per_task *
                                                   task_id):(classes_per_task *
                                                             (task_id + 1))]
                    scores_.append(temp_scores_)
                    # - also get hard target
                    _, temp_y_ = torch.max(temp_scores_, dim=1)
                    y_.append(temp_y_)
            # -only keep predicted y_/scores_ if required (as otherwise unnecessary computations will be done)
            y_ = y_ if (model.replay_targets == "hard") else None
            scores_ = scores_ if (model.replay_targets == "soft") else None

            #-----------------Train model------------------#

            # Train the main model with this batch
            loss_dict = model.train_a_batch(x,
                                            y=y,
                                            x_=x_,
                                            y_=y_,
                                            scores_=scores_,
                                            tasks_=task_used,
                                            active_classes=active_classes,
                                            task=task,
                                            rnt=(1. if task == 1 else 1. /
                                                 task) if rnt is None else rnt,
                                            freeze_convE=freeze_convE)

            # Update running parameter importance estimates in W
            if isinstance(model, ContinualLearner) and model.si_c > 0:
                for n, p in model.named_parameters():
                    if p.requires_grad:
                        n = n.replace('.', '__')
                        if p.grad is not None:
                            W[n].add_(-p.grad * (p.detach() - p_old[n]))
                        p_old[n] = p.detach().clone()

            # Fire callbacks (for visualization of training-progress / evaluating performance after each task)
            for loss_cb in loss_cbs:
                if loss_cb is not None:
                    loss_cb(progress, batch_index, loss_dict, task=task)
            for eval_cb in eval_cbs:
                if eval_cb is not None:
                    eval_cb(model, batch_index, task=task)

        # Close progress-bar
        progress.close()

        ##----------> UPON FINISHING EACH TASK...

        # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty
        if isinstance(model, ContinualLearner) and model.ewc_lambda > 0:
            # -find allowed classes
            allowed_classes = list(
                range(classes_per_task * (task - 1), classes_per_task * task))
            # -if needed, apply correct task-specific mask
            if model.mask_dict is not None:
                model.apply_XdGmask(task=task)
            # -estimate FI-matrix
            model.estimate_fisher(train_dataset,
                                  allowed_classes=allowed_classes)

        # SI: calculate and update the normalized path integral
        if isinstance(model, ContinualLearner) and model.si_c > 0:
            model.update_omega(W, model.epsilon)

        # EXEMPLARS: update exemplar sets
        if use_exemplars or replay_mode == "exemplars":
            exemplars_per_class = int(
                np.floor(model.memory_budget / (classes_per_task * task)))
            # reduce exemplar-sets
            model.reduce_exemplar_sets(exemplars_per_class)
            # for each new class trained on, construct exemplar-set
            new_classes = list(
                range(classes_per_task * (task - 1), classes_per_task * task))
            for class_id in new_classes:
                # create new dataset containing only all examples of this class
                class_dataset = SubDataset(original_dataset=train_dataset,
                                           sub_labels=[class_id])
                # based on this dataset, construct new exemplar-set for this class
                model.construct_exemplar_set(dataset=class_dataset,
                                             n=exemplars_per_class)
            model.compute_means = True

        # Calculate statistics required for metrics
        for metric_cb in metric_cbs:
            if metric_cb is not None:
                metric_cb(model, iters, task=task)

        # REPLAY: update source for replay
        previous_model = copy.deepcopy(model).eval()
        if replay_mode == 'current':
            Current = True
        elif replay_mode in ('exemplars', 'exact'):
            Exact = True
            if replay_mode == "exact":
                previous_datasets = train_datasets[:task]
            else:
                previous_datasets = []
                for task_id in range(task):
                    previous_datasets.append(
                        ExemplarDataset(
                            model.exemplar_sets[(classes_per_task *
                                                 task_id):(classes_per_task *
                                                           (task_id + 1))],
                            target_transform=lambda y, x=classes_per_task *
                            task_id: y + x))
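In train_cl above, the running importance estimate W adds, for every parameter, minus the gradient times the parameter change of each update step (the path-integral bookkeeping of Synaptic Intelligence). A toy standalone sketch of that bookkeeping, assuming plain SGD on a two-parameter model (illustrative only, not the ContinualLearner implementation):

# Toy sketch of the running importance estimate W used for Synaptic Intelligence above.
import torch

w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
W = torch.zeros_like(w)                     # running path-integral estimate
p_old = w.detach().clone()
for _ in range(3):
    loss = (w ** 2).sum()
    loss.backward()
    with torch.no_grad():
        w -= 0.1 * w.grad                   # one SGD step
    W.add_(-w.grad * (w.detach() - p_old))  # same update as in train_cl
    p_old = w.detach().clone()
    w.grad.zero_()
print(W)  # entries are positive when a parameter moved in the direction that lowered the loss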
Example #4
def get_experiment(name,
                   tasks=1,
                   data_dir="./store/datasets",
                   normalize=False,
                   augment=False,
                   verbose=False,
                   exception=False,
                   only_config=False,
                   per_class=False):
    '''Load, organize and return train- and test-dataset(s) for the requested experiment.'''

    # Define data-type
    if name == "MNIST":
        data_type = 'mnist'
    elif name == "CIFAR10":
        data_type = 'cifar10'
    elif name == "CIFAR100":
        data_type = 'cifar100'
    elif name == "CORe50":
        data_type = 'core50'
    else:
        raise ValueError('Given undefined experiment: {}'.format(name))

    # Get config-dict
    config = DATASET_CONFIGS[data_type].copy()
    config['normalize'] = normalize
    if normalize:
        config['denormalize'] = AVAILABLE_TRANSFORMS[data_type + "_denorm"]
    # check for number of tasks
    if tasks > config['classes']:
        raise ValueError(
            "Experiment '{}' cannot have more than {} tasks!".format(
                name, config['classes']))
    # -how many classes per task?
    if not per_class:
        classes_per_task = int(np.floor(config['classes'] / tasks))
        config['classes'] = classes_per_task * tasks
        config['classes_per_task'] = classes_per_task
    # -if only config-dict is needed, return it
    if only_config:
        return config

    # Prepare permutation to shuffle label-ids (to create different class batches for each random seed)
    classes = config['classes']
    permuted_class_list = np.array(list(
        range(classes))) if exception else np.random.permutation(
            list(range(classes)))

    # Load train and test datasets with all classes
    if name != "CORe50":
        target_transform = transforms.Lambda(
            lambda y, p=permuted_class_list: int(p[y]))
        trainset = get_dataset(data_type,
                               type="train",
                               dir=data_dir,
                               target_transform=target_transform,
                               normalize=normalize,
                               augment=augment,
                               verbose=verbose)
        testset = get_dataset(data_type,
                              type="test",
                              dir=data_dir,
                              target_transform=target_transform,
                              normalize=normalize,
                              augment=augment,
                              verbose=verbose)

    # Split the testset, and possible also the trainset, up into separate datasets for each task/class
    labels_per_task = [[label] for label in range(classes)] if per_class else [
        list(np.array(range(classes_per_task)) + classes_per_task * task_id)
        for task_id in range(tasks)
    ]
    train_datasets = []
    test_datasets = []
    for labels in labels_per_task:
        if name == "CORe50":
            # -training data
            class_datasets = []
            for label in labels:
                class_id = permuted_class_list[label]
                object_ids = list(
                    range(class_id * 5, (class_id + 1) *
                          5))  #-> for each category, there are 5 objects
                object_datasets = []
                for object_id in object_ids:
                    path_name = os.path.join(
                        data_dir, 'core50_features',
                        'train_object{}.pt'.format(object_id))
                    feature_tensor = torch.load(path_name)
                    object_datasets.append(
                        FeatureDataset(feature_tensor.view(-1, 512, 1, 1),
                                       label))
                class_datasets.append(ConcatDataset(object_datasets))
            task_dataset = ConcatDataset(class_datasets)
            train_datasets.append(task_dataset)
            # -test data
            class_datasets = []
            for label in labels:
                class_id = permuted_class_list[label]
                object_ids = list(
                    range(class_id * 5, (class_id + 1) *
                          5))  #-> for each category, there are 5 objects
                object_datasets = []
                for object_id in object_ids:
                    path_name = os.path.join(
                        data_dir, 'core50_features',
                        'test_object{}.pt'.format(object_id))
                    feature_tensor = torch.load(path_name)
                    object_datasets.append(
                        FeatureDataset(feature_tensor.view(-1, 512, 1, 1),
                                       label))
                class_datasets.append(ConcatDataset(object_datasets))
            task_dataset = ConcatDataset(class_datasets)
            test_datasets.append(task_dataset)
        else:
            train_datasets.append(
                SubDataset(trainset, labels, target_transform=None))
            test_datasets.append(
                SubDataset(testset, labels, target_transform=None))

    # Return tuple of data-sets and config-dictionary
    return ((train_datasets, test_datasets), config)
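A minimal usage sketch for the function above (hypothetical call; the dataset name, path and printed values are illustrative and assume the module's DATASET_CONFIGS and transforms are available):

# Hypothetical usage sketch; values are illustrative.
(train_datasets, test_datasets), config = get_experiment(
    name="CIFAR10", tasks=5, data_dir="./store/datasets", normalize=True, verbose=True)
print(config['classes_per_task'])               # -> 2 (10 classes split over 5 tasks)
print(len(train_datasets), len(test_datasets))  # -> 5 5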