Code example #1
    def __init__(self, model: Module, optimizer: Optimizer, criterion,
                 mem_size: int = 200,
                 train_mb_size: int = 1, train_epochs: int = 1,
                 eval_mb_size: int = None, device=None,
                 plugins: Optional[List[StrategyPlugin]] = None,
                 evaluator: EvaluationPlugin = default_logger):
        """ Experience replay strategy. See ReplayPlugin for more details.
        This strategy does not use task identities.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: replay buffer size.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None,
            in which case the train minibatch size is used.
        :param device: The device to use. Defaults to None (cpu).
        :param plugins: Plugins to be added. Defaults to None.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        """
        rp = ReplayPlugin(mem_size)
        if plugins is None:
            plugins = [rp]
        else:
            plugins.append(rp)
        super().__init__(
            model, optimizer, criterion,
            train_mb_size=train_mb_size, train_epochs=train_epochs,
            eval_mb_size=eval_mb_size, device=device, plugins=plugins,
            evaluator=evaluator)
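For context, here is a minimal usage sketch of the strategy this constructor belongs to (Avalanche's Replay wrapper). The benchmark, model, and import paths below are assumptions and differ between Avalanche versions:

# Hedged sketch: assumed import paths; they vary across Avalanche versions.
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.strategies import Replay

benchmark = SplitMNIST(n_experiences=5)
model = SimpleMLP(num_classes=benchmark.n_classes)
strategy = Replay(
    model,
    SGD(model.parameters(), lr=0.001),
    CrossEntropyLoss(),
    mem_size=200,   # replay buffer size, as documented above
    train_mb_size=32,
    train_epochs=1,
)

for experience in benchmark.train_stream:
    strategy.train(experience)            # ReplayPlugin updates the buffer here
    strategy.eval(benchmark.test_stream)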
Code example #2
    def assert_balancing(self, policy):
        benchmark = get_fast_benchmark(use_task_labels=True)
        replay = ReplayPlugin(mem_size=100, storage_policy=policy)
        model = SimpleMLP(num_classes=benchmark.n_classes)

        # CREATE THE STRATEGY INSTANCE (NAIVE)
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001),
            CrossEntropyLoss(),
            train_mb_size=100,
            train_epochs=0,  # no optimization steps; only the buffer-update callbacks run
            eval_mb_size=100,
            plugins=[replay],
            evaluator=None,
        )

        for exp in benchmark.train_stream:
            cl_strategy.train(exp)

            ext_mem = policy.buffer_groups
            ext_mem_data = policy.buffer_datasets
            print(list(ext_mem.keys()), [len(el) for el in ext_mem_data])

            # buffer size should equal policy.max_size once enough data has been seen
            len_tot = sum([len(el) for el in ext_mem_data])
            assert len_tot == policy.max_size
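The helper above is parameterized by a storage policy. A sketch of plausible call sites, assuming the balanced-buffer classes exposed by avalanche.training.storage_policy (class names and module path may vary by version):

from avalanche.training.storage_policy import (
    ClassBalancedBuffer,
    ExperienceBalancedBuffer,
)

    def test_experience_balanced_memory(self):
        self.assert_balancing(
            ExperienceBalancedBuffer(max_size=100, adaptive_size=True))

    def test_class_balanced_memory(self):
        self.assert_balancing(
            ClassBalancedBuffer(max_size=100, adaptive_size=True))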
Code example #3
def get_method_plugins(args: Namespace):
    if args.method_name == "naive":
        return []
    elif args.method_name == "replay_avalanche":
        return [ReplayPlugin()]
    elif args.method_name == "replay":
        memory = CILMemory(args.memory_size)
        return [CILMemoryPlugin(memory, HerdingMemoryStrategy()), CILReplayPlugin(memory)]
    elif args.method_name == "hybrid1":
        memory = CILMemory(args.memory_size)
        return [CILMemoryPlugin(memory, HerdingMemoryStrategy()), CILReplayPlugin(memory), LwFMCPlugin()]
    elif args.method_name == "iCaRL":
        return make_icarl_plugins(args.memory_size)
    elif args.method_name == "tricicl-P-ND":
        return make_tricicl_post_training_plugins(args.memory_size, distillation=False)
    elif args.method_name == "tricicl-P-D":
        return make_tricicl_post_training_plugins(args.memory_size, distillation=True)
    elif args.method_name == "tricicl-B-D":
        return make_tricicl_pre_training_plugins(args.memory_size, pre_distillation=True)
    elif args.method_name == "tricicl-A-D":
        return make_tricicl_alternate_training_plugins(args.memory_size, distillation=True)
    elif args.method_name == "tricicl-P-D-NME":
        return make_tricicl_post_training_plugins(args.memory_size, distillation=True, nme=True)
    elif args.method_name == "tricicl-B-D-NME":
        return make_tricicl_pre_training_plugins(args.memory_size, pre_distillation=True, nme=True)
    elif args.method_name == "tricicl-A-D-NME":
        return make_tricicl_alternate_training_plugins(args.memory_size, distillation=True, nme=True)
    # elif args.method_name == "tricicl":
    #     return make_tricicl_during_training_plugin(args.memory_size, use_replay=True, nme=True, distillation=False)

    raise ValueError(f"Method {args.method_name} not supported")
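A hypothetical call site for the factory above; the Namespace fields mirror the ones the function reads (method_name, memory_size):

from argparse import Namespace

# Hypothetical wiring: build the plugin list, then pass it to any
# Avalanche strategy via its `plugins` argument.
args = Namespace(method_name="replay_avalanche", memory_size=2000)
plugins = get_method_plugins(args)  # -> [ReplayPlugin()]
# strategy = Naive(model, optimizer, criterion, plugins=plugins)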
Code example #4
    def _test_replay_balanced_memory(self, storage_policy, mem_size):
        benchmark = get_fast_benchmark(use_task_labels=True)
        model = SimpleMLP(input_size=6, hidden_size=10)
        replayPlugin = ReplayPlugin(
            mem_size=mem_size, storage_policy=storage_policy
        )
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001),
            CrossEntropyLoss(),
            train_mb_size=32,
            train_epochs=1,
            eval_mb_size=100,
            plugins=[replayPlugin],
        )

        n_seen_data = 0
        for step in benchmark.train_stream:
            n_seen_data += len(step.dataset)
            mem_fill = min(mem_size, n_seen_data)
            cl_strategy.train(step)
            lengths = []
            for d in replayPlugin.storage_policy.buffer_datasets:
                lengths.append(len(d))
            self.assertEqual(sum(lengths), mem_fill)  # Always fully filled
Code example #5
    def test_replay_balanced_memory(self):
        scenario = self.create_scenario(task_labels=True)
        mem_size = 25
        model = SimpleMLP(input_size=6, hidden_size=10)
        replayPlugin = ReplayPlugin(mem_size=mem_size)
        cl_strategy = Naive(model,
                            SGD(model.parameters(),
                                lr=0.001,
                                momentum=0.9,
                                weight_decay=0.001),
                            CrossEntropyLoss(),
                            train_mb_size=32,
                            train_epochs=1,
                            eval_mb_size=100,
                            plugins=[replayPlugin])

        for step in scenario.train_stream:
            curr_mem_size = min(mem_size, len(step.dataset))
            cl_strategy.train(step)
            ext_mem = replayPlugin.ext_mem
            lengths = []
            for task_id in ext_mem.keys():
                lengths.append(len(ext_mem[task_id]))
            self.assertEqual(sum(lengths), curr_mem_size)
            difference = max(lengths) - min(lengths)
            self.assertLessEqual(difference, 1)
Code example #6
File: replay.py Project: ryanlindeborg/avalanche
def main(args):
    # --- CONFIG
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")
    n_batches = 5
    # ---------

    # --- TRANSFORMATIONS
    train_transform = transforms.Compose([
        RandomCrop(28, padding=4),
        ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_transform = transforms.Compose([
        ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # ---------

    # --- SCENARIO CREATION
    mnist_train = MNIST('./data/mnist', train=True,
                        download=True, transform=train_transform)
    mnist_test = MNIST('./data/mnist', train=False,
                       download=True, transform=test_transform)
    scenario = nc_scenario(
        mnist_train, mnist_test, n_batches, task_labels=False, seed=1234)
    # ---------

    # MODEL CREATION
    model = SimpleMLP(num_classes=scenario.n_classes)

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        ExperienceForgetting(),
        loggers=[interactive_logger])

    # CREATE THE STRATEGY INSTANCE (NAIVE)
    cl_strategy = Naive(model, torch.optim.Adam(model.parameters(), lr=0.001),
                        CrossEntropyLoss(),
                        train_mb_size=100, train_epochs=4, eval_mb_size=100, device=device,
                        plugins=[ReplayPlugin(mem_size=10000)],
                        evaluator=eval_plugin
                        )

    # TRAINING LOOP
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start of experience ", experience.current_experience)
        cl_strategy.train(experience)
        print('Training completed')

        print('Computing accuracy on the whole test set')
        results.append(cl_strategy.eval(scenario.test_stream))
Code example #7
def main(args):

    # Model getter: specify dataset and depth of the network.
    model = pytorchcv_wrapper.resnet('cifar10', depth=20, pretrained=False)

    # Or get a more specific model. E.g. wide resnet, with depth 40 and growth
    # factor 8 for Cifar 10.
    # model = pytorchcv_wrapper.get_model("wrn40_8_cifar10", pretrained=False)

    # --- CONFIG
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")

    device = "cpu"

    # --- TRANSFORMATIONS
    transform = transforms.Compose([
        ToTensor(),
        transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))
    ])

    # --- SCENARIO CREATION
    cifar_train = CIFAR10(root=expanduser("~") + "/.avalanche/data/cifar10/",
                          train=True, download=True, transform=transform)
    cifar_test = CIFAR10(root=expanduser("~") + "/.avalanche/data/cifar10/",
                         train=False, download=True, transform=transform)
    scenario = nc_benchmark(
        cifar_train, cifar_test, 5, task_labels=False, seed=1234,
        fixed_class_order=[i for i in range(10)])

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=[interactive_logger])

    # CREATE THE STRATEGY INSTANCE (Naive, with Replay)
    cl_strategy = Naive(model, torch.optim.SGD(model.parameters(), lr=0.01),
                        CrossEntropyLoss(),
                        train_mb_size=100, train_epochs=1, eval_mb_size=100,
                        device=device,
                        plugins=[ReplayPlugin(mem_size=1000)],
                        evaluator=eval_plugin
                        )

    # TRAINING LOOP
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start of experience ", experience.current_experience)
        cl_strategy.train(experience)
        print('Training completed')

        print('Computing accuracy on the whole test set')
        results.append(cl_strategy.eval(scenario.test_stream))
Code example #8
    def test_dataload_batch_balancing(self):
        scenario = get_fast_scenario()
        model = SimpleMLP(input_size=6, hidden_size=10)
        batch_size = 32
        replayPlugin = ReplayPlugin(mem_size=20)
        cl_strategy = Naive(model,
                            SGD(model.parameters(),
                                lr=0.001,
                                momentum=0.9,
                                weight_decay=0.001),
                            CrossEntropyLoss(),
                            train_mb_size=batch_size,
                            train_epochs=1,
                            eval_mb_size=100,
                            plugins=[replayPlugin])

        for step in scenario.train_stream:
            adapted_dataset = step.dataset
            dataloader = MultiTaskJoinedBatchDataLoader(
                adapted_dataset,
                AvalancheConcatDataset(replayPlugin.ext_mem.values()),
                oversample_small_tasks=True,
                num_workers=0,
                batch_size=batch_size,
                shuffle=True)

            for mini_batch in dataloader:
                lengths = []
                for task_id in mini_batch.keys():
                    lengths.append(len(mini_batch[task_id][1]))
                if sum(lengths) == batch_size:
                    difference = max(lengths) - min(lengths)
                    self.assertLessEqual(difference, 1)
                self.assertLessEqual(sum(lengths), batch_size)
            cl_strategy.train(step)
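Note that this test targets an older Avalanche interface (MultiTaskJoinedBatchDataLoader and the ext_mem dictionary); code example #16 below performs the same batch-balancing check against the newer ReplayDataLoader and storage_policy API.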
Code example #9
def main(args):
    # --- CONFIG
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")
    # ---------

    # --- TRANSFORMATIONS
    train_transform = transforms.Compose([
        RandomCrop(28, padding=4),
        ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_transform = transforms.Compose([
        ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # ---------

    # --- SCENARIO CREATION
    mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                        train=True, download=True, transform=train_transform)
    mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
                       train=False, download=True, transform=test_transform)
    scenario = nc_scenario(
        mnist_train, mnist_test, 5, task_labels=False, seed=1234)
    # ---------

    # MODEL CREATION
    model = SimpleMLP(num_classes=scenario.n_classes)

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(epoch=True, experience=True, stream=True),
        loss_metrics(epoch=True, experience=True, stream=True),
        # save image should be False to appropriately view
        # results in Interactive Logger.
        # a tensor will be printed
        StreamConfusionMatrix(save_image=False, normalize='all'),
        loggers=InteractiveLogger()
    )

    # CREATE THE STRATEGY INSTANCE (NAIVE)
    cl_strategy = Naive(
        model, SGD(model.parameters(), lr=0.001, momentum=0.9),
        CrossEntropyLoss(), train_mb_size=100, train_epochs=4, eval_mb_size=100,
        device=device, evaluator=eval_plugin, plugins=[ReplayPlugin(5000)])

    # TRAINING LOOP
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start of experience: ", experience.current_experience)
        print("Current Classes: ", experience.classes_in_this_experience)

        cl_strategy.train(experience)
        print('Training completed')

        print('Computing accuracy on the whole test set')
        results.append(cl_strategy.eval(scenario.test_stream))
Code example #10
    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        mem_size: int = 200,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: int = None,
        device=None,
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: EvaluationPlugin = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: replay buffer size.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None,
            in which case the train minibatch size is used.
        :param device: The device to use. Defaults to None (cpu).
        :param plugins: Plugins to be added. Defaults to None.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param **base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        rp = ReplayPlugin(mem_size)
        if plugins is None:
            plugins = [rp]
        else:
            plugins.append(rp)
        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
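To illustrate the eval_every semantics documented above, a hedged sketch (assuming the class is Avalanche's Replay strategy and that train accepts an eval_streams argument, as in recent releases; model, optimizer, criterion, and benchmark are placeholders):

# eval runs every 2 epochs and once more at the end of each experience
strategy = Replay(
    model, optimizer, criterion,
    mem_size=500,
    train_epochs=10,
    eval_every=2,
)
for experience in benchmark.train_stream:
    # the streams to evaluate periodically are passed to `train`
    strategy.train(experience, eval_streams=[benchmark.test_stream])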
Code example #11
    def test_warning_slda_lwf(self):
        model, _, criterion, my_nc_benchmark = self.init_sit()
        with self.assertWarns(Warning) as cm:
            StreamingLDA(
                model,
                criterion,
                input_size=10,
                output_layer_name="features",
                num_classes=10,
                plugins=[LwFPlugin(), ReplayPlugin()],
            )
Code example #12
    def test_warning_slda_lwf(self):
        model, _, criterion, my_nc_benchmark = self.init_sit()
        with self.assertLogs('avalanche.training.strategies', "WARNING") as cm:
            StreamingLDA(model, criterion, input_size=10,
                         output_layer_name='features', num_classes=10,
                         plugins=[LwFPlugin(), ReplayPlugin()])
        self.assertEqual(1, len(cm.output))
        self.assertIn(
            "LwFPlugin seems to use the callback before_backward"
            " which is disabled by StreamingLDA",
            cm.output[0]
        )
Code example #13
File: strategy_wrappers.py Project: gab709/avalanche
    def __init__(self,
                 model: Module,
                 optimizer: Optimizer,
                 criterion,
                 mem_size: int = 200,
                 train_mb_size: int = 1,
                 train_epochs: int = 1,
                 eval_mb_size: int = None,
                 device=None,
                 plugins: Optional[List[StrategyPlugin]] = None,
                 evaluator: EvaluationPlugin = default_logger,
                 eval_every=-1):
        """ Experience replay strategy. See ReplayPlugin for more details.
        This strategy does not use task identities.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: replay buffer size.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None,
            in which case the train minibatch size is used.
        :param device: The device to use. Defaults to None (cpu).
        :param plugins: Plugins to be added. Defaults to None.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop.
                if -1: no evaluation during training.
                if  0: calls `eval` after the final epoch of each training
                    experience.
                if >0: calls `eval` every `eval_every` epochs and at the end
                    of all the epochs for a single experience.
        """

        rp = ReplayPlugin(mem_size)
        if plugins is None:
            plugins = [rp]
        else:
            plugins.append(rp)
        super().__init__(model,
                         optimizer,
                         criterion,
                         train_mb_size=train_mb_size,
                         train_epochs=train_epochs,
                         eval_mb_size=eval_mb_size,
                         device=device,
                         plugins=plugins,
                         evaluator=evaluator,
                         eval_every=eval_every)
Code example #14
    def test_dataload_reinit(self):
        scenario = get_fast_scenario()
        model = SimpleMLP(input_size=6, hidden_size=10)

        replayPlugin = ReplayPlugin(mem_size=5)
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001),
            CrossEntropyLoss(), train_mb_size=16, train_epochs=1,
            eval_mb_size=16,
            plugins=[replayPlugin]
        )
        for step in scenario.train_stream[:2]:
            cl_strategy.train(step)
Code example #15
def main(cuda: int):
    # --- CONFIG
    device = torch.device(
        f"cuda:{cuda}" if torch.cuda.is_available() else "cpu"
    )
    # --- SCENARIO CREATION
    scenario = SplitCIFAR10(n_experiences=2, seed=42)
    # ---------

    # MODEL CREATION
    # input_size = 196608 // 64 = 3072 = 3 * 32 * 32, a flattened CIFAR-10 image
    model = SimpleMLP(num_classes=scenario.n_classes, input_size=196608 // 64)

    # choose some metrics and evaluation method
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        images_samples_metrics(
            on_train=True,
            on_eval=True,
            n_cols=10,
            n_rows=10,
        ),
        labels_repartition_metrics(
            # image_creator=repartition_bar_chart_image_creator,
            on_train=True,
            on_eval=True,
        ),
        loggers=[
            TensorboardLogger(f"tb_data/{datetime.now()}"),
            InteractiveLogger(),
        ],
    )

    # CREATE THE STRATEGY INSTANCE (NAIVE)
    cl_strategy = Naive(
        model,
        Adam(model.parameters()),
        # criterion omitted: Naive falls back to CrossEntropyLoss() by default
        train_mb_size=128,
        train_epochs=1,
        eval_mb_size=128,
        device=device,
        plugins=[ReplayPlugin(mem_size=1_000)],
        evaluator=eval_plugin,
    )

    # TRAINING LOOP
    for i, experience in enumerate(scenario.train_stream, 1):
        cl_strategy.train(experience)
        cl_strategy.eval(scenario.test_stream[:i])
Code example #16
    def test_dataload_batch_balancing(self):
        benchmark = get_fast_benchmark()
        batch_size = 32
        replayPlugin = ReplayPlugin(mem_size=20)

        model = SimpleMLP(input_size=6, hidden_size=10)
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9,
                weight_decay=0.001),
            CrossEntropyLoss(),
            train_mb_size=batch_size,
            train_epochs=1,
            eval_mb_size=100,
            plugins=[replayPlugin],
        )
        for step in benchmark.train_stream:
            adapted_dataset = step.dataset
            if len(replayPlugin.storage_policy.buffer) > 0:
                dataloader = ReplayDataLoader(
                    adapted_dataset,
                    replayPlugin.storage_policy.buffer,
                    oversample_small_tasks=True,
                    num_workers=0,
                    batch_size=batch_size,
                    shuffle=True,
                )
            else:
                dataloader = TaskBalancedDataLoader(adapted_dataset)

            for mini_batch in dataloader:
                mb_task_labels = mini_batch[-1]
                lengths = []
                for task_id in adapted_dataset.task_set:
                    len_task = (mb_task_labels == task_id).sum()
                    lengths.append(len_task)
                if sum(lengths) == batch_size:
                    difference = max(lengths) - min(lengths)
                    self.assertLessEqual(difference, 1)
                self.assertLessEqual(sum(lengths), batch_size)
            cl_strategy.train(step)
Code example #17
File: test_plugins.py Project: bbeatrix/avalanche
    def _test_replay_balanced_memory(self, storage_policy, mem_size):
        scenario = self.create_scenario(task_labels=True)
        model = SimpleMLP(input_size=6, hidden_size=10)
        replayPlugin = ReplayPlugin(mem_size=mem_size,
                                    storage_policy=storage_policy)
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001),
            CrossEntropyLoss(), train_mb_size=32, train_epochs=1,
            eval_mb_size=100, plugins=[replayPlugin]
        )

        n_seen_data = 0
        for step in scenario.train_stream:
            n_seen_data += len(step.dataset)
            mem_fill = min(mem_size, n_seen_data)
            cl_strategy.train(step)
            ext_mem = replayPlugin.ext_mem
            lengths = []
            for task_id in ext_mem.keys():
                lengths.append(len(ext_mem[task_id]))
            self.assertEqual(sum(lengths), mem_fill)  # Always fully filled
Code example #18
def main(args):
    # --- CONFIG
    device = torch.device(
        f"cuda:{args.cuda}"
        if torch.cuda.is_available() and args.cuda >= 0
        else "cpu"
    )

    # --- SCENARIO CREATION
    scenario = SplitCIFAR100(n_experiences=20, return_task_id=True)
    config = {"scenario": "SplitCIFAR100"}

    # MODEL CREATION
    model = MTSimpleCNN()

    # choose some metrics and evaluation method
    loggers = [InteractiveLogger()]
    if args.wandb_project != "":
        wandb_logger = WandBLogger(
            project_name=args.wandb_project,
            run_name="LaMAML_" + config["scenario"],
            config=config,
        )
        loggers.append(wandb_logger)

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True
        ),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=loggers,
    )

    # LAMAML STRATEGY
    rs_buffer = ReservoirSamplingBuffer(max_size=200)
    replay_plugin = ReplayPlugin(
        mem_size=200,
        batch_size=10,
        batch_size_mem=10,
        task_balanced_dataloader=False,
        storage_policy=rs_buffer,
    )

    cl_strategy = LaMAML(
        model,
        torch.optim.SGD(model.parameters(), lr=0.1),
        CrossEntropyLoss(),
        n_inner_updates=5,
        second_order=True,
        grad_clip_norm=1.0,
        learn_lr=True,
        lr_alpha=0.25,
        sync_update=False,
        train_mb_size=10,
        train_epochs=10,
        eval_mb_size=100,
        device=device,
        plugins=[replay_plugin],
        evaluator=eval_plugin,
    )

    # TRAINING LOOP
    print("Starting experiment...")
    results = []
    for experience in scenario.train_stream:
        print("Start of experience ", experience.current_experience)
        cl_strategy.train(experience)
        print("Training completed")

        print("Computing accuracy on the whole test set")
        results.append(cl_strategy.eval(scenario.test_stream))

    if args.wandb_project != "":
        wandb.finish()
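A hypothetical entry point matching the args fields this script reads (cuda, wandb_project); the flag names and defaults are illustrative:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", type=int, default=0,
                        help="GPU index; negative values force CPU")
    parser.add_argument("--wandb_project", type=str, default="",
                        help="empty string disables W&B logging")
    main(parser.parse_args())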