Example #1
    def test_early_stop(self):
        # Plugin that requests an early stop once the epoch's
        # iteration counter reaches 10.
        class EarlyStopP(SupervisedPlugin):
            def after_training_iteration(
                self, strategy: "SupervisedTemplate", **kwargs
            ):
                if strategy.clock.train_epoch_iterations == 10:
                    strategy.stop_training()

        model = SimpleMLP(input_size=6, hidden_size=100)
        criterion = CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=1)

        strategy = Cumulative(
            model,
            optimizer,
            criterion,
            train_mb_size=1,
            device=get_device(),
            eval_mb_size=512,
            train_epochs=1,
            evaluator=None,
            plugins=[EarlyStopP()],
        )
        benchmark = get_fast_benchmark()

        for train_batch_info in benchmark.train_stream:
            strategy.train(train_batch_info)
            # The stop request takes effect before the next iteration,
            # so the epoch's iteration counter reads 11 after train().
            assert strategy.clock.train_epoch_iterations == 11
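
This is a test method excerpted from a larger suite, so its imports are omitted. A minimal sketch of what a standalone run would need, assuming Avalanche's package layout (get_device and get_fast_benchmark are assumed to be the suite's own test helpers, not library API):

    # Hedged sketch: import paths assumed from the Avalanche library;
    # the test-helper import below is an assumption about the suite.
    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.core import SupervisedPlugin
    from avalanche.models import SimpleMLP
    from avalanche.training.supervised import Cumulative
    from avalanche.training.templates import SupervisedTemplate

    from tests.unit_tests_utils import get_device, get_fast_benchmark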
Example #2
    def test_cumulative(self):
        # SIT (single-incremental-task) scenario: no task labels
        model, optimizer, criterion, my_nc_benchmark = self.init_sit()
        strategy = Cumulative(
            model,
            optimizer,
            criterion,
            train_mb_size=64,
            device=self.device,
            eval_mb_size=50,
            train_epochs=2,
        )
        self.run_strategy(my_nc_benchmark, strategy)

        # MT (multi-task) scenario: experiences carry task labels
        strategy = Cumulative(
            model,
            optimizer,
            criterion,
            train_mb_size=64,
            device=self.device,
            eval_mb_size=50,
            train_epochs=2,
        )
        benchmark = self.load_benchmark(use_task_labels=True)
        self.run_strategy(benchmark, strategy)
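
Both runs above use the same Cumulative strategy; only the benchmark changes. The strategy's core idea is to retrain on the union of all experiences seen so far. A minimal conceptual sketch of that idea (hypothetical illustration, not Avalanche's actual implementation):

    from torch.utils.data import ConcatDataset

    class CumulativeSketch:
        # Hypothetical illustration of the cumulative idea: keep every
        # dataset seen so far and train on their concatenation.
        def __init__(self):
            self.seen = []

        def adaptation_dataset(self, new_experience_dataset):
            # Append the new experience, then train on everything so far.
            self.seen.append(new_experience_dataset)
            return ConcatDataset(self.seen)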

    def test_multihead_cumulative(self):
        # Check that the multi-head model reaches high enough accuracy,
        # i.e., that nothing is wrong with the handling of the multiple heads.

        set_deterministic_run(seed=0)

        model = MHTestMLP(input_size=6, hidden_size=100)
        criterion = CrossEntropyLoss()
        optimizer = SGD(model.parameters(), lr=1)

        main_metric = StreamAccuracy()
        exp_acc = ExperienceAccuracy()
        evalp = EvaluationPlugin(main_metric, exp_acc, loggers=None)
        strategy = Cumulative(
            model,
            optimizer,
            criterion,
            train_mb_size=64,
            device=get_device(),
            eval_mb_size=512,
            train_epochs=6,
            evaluator=evalp,
        )
        benchmark = get_fast_benchmark(use_task_labels=True)

        for train_batch_info in benchmark.train_stream:
            strategy.train(train_batch_info)
        strategy.eval(benchmark.train_stream[:])
        print("TRAIN STREAM ACC: ", main_metric.result())
        assert (sum(main_metric.result().values()) /
                float(len(main_metric.result().keys())) > 0.7)
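
The final assertion averages the per-task accuracies returned by main_metric.result() (a dict keyed by task label) and requires the mean to exceed 0.7. An equivalent, slightly more readable form of the same check, assuming that dict shape:

    from statistics import mean

    per_task_acc = main_metric.result()       # e.g. {0: 0.93, 1: 0.88}
    assert mean(per_task_acc.values()) > 0.7  # mean accuracy across tasks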