示例#1
0
    def test_dataload_batch_balancing(self):
        """Check that joined mini-batches are split evenly across their keys.

        For every experience of the training stream a
        ``MultiTaskJoinedBatchDataLoader`` is built over the experience
        dataset and the concatenation of the replay plugin's ``ext_mem``
        values. Every mini-batch it yields must contain at most
        ``batch_size`` samples overall; when it contains exactly
        ``batch_size`` samples, the per-key sample counts may differ by at
        most one. The strategy is then trained on the experience
        (presumably this is what fills the replay memory for the next
        iteration -- confirm against ``ReplayPlugin``).
        """
        scenario = get_fast_scenario()
        model = SimpleMLP(input_size=6, hidden_size=10)
        batch_size = 32
        replay_plugin = ReplayPlugin(mem_size=20)

        optimizer = SGD(model.parameters(),
                        lr=0.001,
                        momentum=0.9,
                        weight_decay=0.001)
        criterion = CrossEntropyLoss()
        strategy = Naive(model,
                         optimizer,
                         criterion,
                         train_mb_size=batch_size,
                         train_epochs=1,
                         eval_mb_size=100,
                         plugins=[replay_plugin])

        for experience in scenario.train_stream:
            current_data = experience.dataset
            memory_data = AvalancheConcatDataset(
                replay_plugin.ext_mem.values())
            loader = MultiTaskJoinedBatchDataLoader(
                current_data,
                memory_data,
                oversample_small_tasks=True,
                num_workers=0,
                batch_size=batch_size,
                shuffle=True)

            for mini_batch in loader:
                # Element [1] of each per-key entry is measured; presumably
                # the targets tensor -- confirm against the loader.
                sizes = [len(mini_batch[key][1])
                         for key in mini_batch.keys()]
                total = sum(sizes)
                if total == batch_size:
                    # Only full batches are expected to be balanced.
                    spread = max(sizes) - min(sizes)
                    self.assertLessEqual(spread, 1)
                self.assertLessEqual(total, batch_size)
            strategy.train(experience)
示例#2
0
    def before_training_exp(self, strategy: BaseStrategy, num_workers=0,
                            shuffle=True, **kwargs):
        """Install a dataloader joining the current data with the memory.

        Nothing happens while ``self.memory`` is falsy. Otherwise
        ``strategy.dataloader`` is replaced by a
        ``MultiTaskJoinedBatchDataLoader`` built from
        ``strategy.adapted_dataset`` and ``self.memory.dataset``, with
        ``strategy.train_mb_size`` as batch size and
        ``oversample_small_tasks=True``.

        :param strategy: strategy whose ``dataloader`` attribute is set.
        :param num_workers: forwarded to the dataloader.
        :param shuffle: forwarded to the dataloader.
        :param kwargs: accepted but not used by this method.
        """
        if self.memory:
            current_data = strategy.adapted_dataset
            replay_data = self.memory.dataset
            strategy.dataloader = MultiTaskJoinedBatchDataLoader(
                current_data,
                replay_data,
                batch_size=strategy.train_mb_size,
                shuffle=shuffle,
                num_workers=num_workers,
                oversample_small_tasks=True,
            )
示例#3
0
 def before_training_exp(self, strategy, num_workers=0, shuffle=True,
                         **kwargs):
     """
     Replace ``strategy.current_dataloader`` with a
     ``MultiTaskJoinedBatchDataLoader`` over ``strategy.adapted_dataset``
     and ``self.ext_mem``, so that batches contain examples from both the
     memory and the training dataset.

     The batch size is ``strategy.train_mb_size``; ``num_workers`` and
     ``shuffle`` are forwarded to the loader and ``**kwargs`` is not used.
     No emptiness check is done on ``self.ext_mem`` here.
     """
     current_data = strategy.adapted_dataset
     joined_loader = MultiTaskJoinedBatchDataLoader(
         current_data,
         self.ext_mem,
         oversample_small_tasks=True,
         num_workers=num_workers,
         batch_size=strategy.train_mb_size,
         shuffle=shuffle)
     strategy.current_dataloader = joined_loader
示例#4
0
 def before_training_exp(self, strategy, num_workers=0, shuffle=True,
                         **kwargs):
     """
     Set ``strategy.dataloader`` to a loader whose batches contain examples
     from both the external memory and the training dataset.

     When ``self.ext_mem`` is empty the strategy is left untouched.
     Otherwise the values of ``self.ext_mem`` are concatenated into one
     ``AvalancheConcatDataset`` and joined with
     ``strategy.adapted_dataset``; the batch size is
     ``strategy.train_mb_size``. ``**kwargs`` is not used.
     """
     if not len(self.ext_mem):
         # Nothing stored yet: keep whatever dataloader the strategy has.
         return

     current_data = strategy.adapted_dataset
     memory_data = AvalancheConcatDataset(self.ext_mem.values())
     strategy.dataloader = MultiTaskJoinedBatchDataLoader(
         current_data,
         memory_data,
         oversample_small_tasks=True,
         num_workers=num_workers,
         batch_size=strategy.train_mb_size,
         shuffle=shuffle)
示例#5
0
 def before_training_exp(self, strategy, num_workers=0, shuffle=True,
                         **kwargs):
     """
     Random retrieval from a class-balanced memory.

     While ``self.replay_mem`` is empty nothing is changed. Once it holds
     entries, ``strategy.dataloader`` becomes a
     ``MultiTaskJoinedBatchDataLoader`` that builds batches from both the
     concatenated memory values and ``strategy.adapted_dataset``.
     The loader's batch size is twice ``strategy.train_mb_size``.
     ``**kwargs`` is not used.
     """
     if len(self.replay_mem) > 0:
         current_data = strategy.adapted_dataset
         memory_data = AvalancheConcatDataset(self.replay_mem.values())
         # Batch size is doubled relative to the strategy's mini-batch size.
         joined_batch_size = strategy.train_mb_size * 2
         strategy.dataloader = MultiTaskJoinedBatchDataLoader(
             current_data,
             memory_data,
             oversample_small_tasks=True,
             num_workers=num_workers,
             batch_size=joined_batch_size,
             shuffle=shuffle)