コード例 #1
0
ファイル: test_strategies.py プロジェクト: rrmina/avalanche
    def test_cumulative(self):
        """Run the Cumulative strategy on two scenarios with shared components.

        First on a scenario loaded without task labels ("SIT"), then on one
        loaded with ``use_task_labels=True`` ("MT"). The same model, optimizer
        and criterion objects are passed to both strategies, so the MT run
        presumably starts from whatever weights the SIT run left behind.
        """
        net = self.get_model(fast_test=self.fast_test)
        sgd = SGD(net.parameters(), lr=1e-3)
        loss_fn = CrossEntropyLoss()

        def build_strategy():
            # Both runs use identical hyper-parameters; only the scenario differs.
            return Cumulative(net, sgd, loss_fn,
                              train_mb_size=64,
                              device=self.device,
                              eval_mb_size=50,
                              train_epochs=2)

        # SIT scenario (no task labels): scenario is loaded before the strategy.
        sit_scenario = self.load_scenario(fast_test=self.fast_test)
        self.run_strategy(sit_scenario, build_strategy())

        # MT scenario (task labels): strategy is created before the scenario.
        mt_strategy = build_strategy()
        mt_scenario = self.load_scenario(fast_test=self.fast_test,
                                         use_task_labels=True)
        self.run_strategy(mt_scenario, mt_strategy)
コード例 #2
0
ファイル: test_strategies.py プロジェクト: pkraison/avalanche
    def test_cumulative(self):
        """Exercise Cumulative on the SIT benchmark and then on a task-labelled one.

        Model, optimizer and criterion come from ``self.init_sit()`` and are
        reused for the second run, which presumably continues from the weights
        produced by the first run.
        """
        model, optimizer, criterion, sit_benchmark = self.init_sit()

        def new_cumulative():
            # Identical hyper-parameters for both runs; only the benchmark varies.
            return Cumulative(
                model,
                optimizer,
                criterion,
                train_mb_size=64,
                device=self.device,
                eval_mb_size=50,
                train_epochs=2,
            )

        # SIT scenario
        self.run_strategy(sit_benchmark, new_cumulative())

        # MT scenario: the strategy is built before the benchmark is loaded.
        mt_strategy = new_cumulative()
        mt_benchmark = self.load_benchmark(use_task_labels=True)
        self.run_strategy(mt_benchmark, mt_strategy)
コード例 #3
0
    def test_early_stop(self):
        """Check that a plugin calling ``stop_training()`` cuts each train() short.

        The plugin requests a stop once ``strategy.mb_it`` equals 10. After every
        ``train`` call on the fast scenario the test expects ``mb_it`` to be 11.
        """
        class EarlyStopP(StrategyPlugin):
            def after_training_iteration(self, strategy: 'BaseStrategy',
                                         **kwargs):
                # Nothing to do until the minibatch index reaches 10.
                if strategy.mb_it != 10:
                    return
                strategy.stop_training()

        mlp = SimpleMLP(input_size=6, hidden_size=100)
        loss_fn = CrossEntropyLoss()
        sgd = SGD(mlp.parameters(), lr=1)

        cumulative = Cumulative(
            mlp, sgd, loss_fn,
            train_mb_size=1,
            device=get_device(),
            eval_mb_size=512,
            train_epochs=1,
            evaluator=None,
            plugins=[EarlyStopP()],
        )
        fast_scenario = get_fast_scenario()

        for experience in fast_scenario.train_stream:
            cumulative.train(experience)
            # NOTE(review): 11 rather than 10 presumably because the training
            # loop advances mb_it once more before honoring the stop flag —
            # confirm against the strategy's training loop.
            assert cumulative.mb_it == 11
コード例 #4
0
    def test_multihead_cumulative(self):
        """Check that a multi-head model trained with Cumulative learns well enough.

        Trains ``MHTestMLP`` on every experience of a task-labelled fast
        scenario, evaluates on the whole train stream, and asserts that the
        mean of the values in the mapping returned by the StreamAccuracy
        metric exceeds 0.7.

        Raises:
            AssertionError: if the metric produced no entries, or if the mean
                accuracy is not above 0.7.
        """
        # check that multi-head reaches high enough accuracy.
        # Ensure nothing weird is happening with the multiple heads.
        model = MHTestMLP(input_size=6, hidden_size=100)
        criterion = CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=1)

        main_metric = StreamAccuracy()
        exp_acc = ExperienceAccuracy()
        evalp = EvaluationPlugin(main_metric, exp_acc, loggers=None)
        strategy = Cumulative(model,
                              optimizer,
                              criterion,
                              train_mb_size=32,
                              device=get_device(),
                              eval_mb_size=512,
                              train_epochs=1,
                              evaluator=evalp)
        scenario = get_fast_scenario(use_task_labels=True)

        for train_batch_info in scenario.train_stream:
            strategy.train(train_batch_info)
        strategy.eval(scenario.train_stream[:])

        # Query the metric once and reuse the snapshot for printing and checks.
        stream_acc = main_metric.result()
        print("TRAIN STREAM ACC: ", stream_acc)
        # An empty result would otherwise surface as a ZeroDivisionError.
        assert stream_acc, "StreamAccuracy produced no results"
        mean_acc = sum(stream_acc.values()) / len(stream_acc)
        assert mean_acc > 0.7