def test_pnn(self):
        """Check that PNN reaches a reasonable accuracy on the train stream.

        Trains a single-layer PNN with ``PNNStrategy`` on every experience of
        a task-labelled fast benchmark, then evaluates on the whole train
        stream. The mean of the ``StreamAccuracy`` results must exceed 0.5,
        which also guards against anything going wrong with the multiple heads.
        """
        # Presumably seeds all RNGs so the accuracy threshold is reproducible.
        set_deterministic_run(seed=0)

        main_metric = StreamAccuracy()
        exp_acc = ExperienceAccuracy()
        evalp = EvaluationPlugin(main_metric, exp_acc, loggers=None)
        model = PNN(num_layers=1, in_features=6, hidden_features_per_column=50)
        optimizer = SGD(model.parameters(), lr=0.1)
        strategy = PNNStrategy(
            model,
            optimizer,
            train_mb_size=32,
            device=get_device(),
            eval_mb_size=512,
            train_epochs=1,
            evaluator=evalp,
        )
        # PNN is only used with task labels here (multi-task scenario).
        benchmark = get_fast_benchmark(use_task_labels=True)

        for train_batch_info in benchmark.train_stream:
            strategy.train(train_batch_info)

        strategy.eval(benchmark.train_stream[:])

        # result() is dict-like (values are accuracies). Query it once so the
        # printed and the asserted numbers are the same snapshot.
        results = main_metric.result()
        print("TRAIN STREAM ACC: ", results)
        # Fail with a clear message rather than a ZeroDivisionError when the
        # metric recorded nothing.
        assert results, "StreamAccuracy produced no results"
        mean_acc = sum(results.values()) / len(results)
        assert mean_acc > 0.5, f"mean train-stream accuracy too low: {mean_acc}"
Пример #2
0
    def test_pnn(self):
        """Smoke-test ``PNNStrategy``: train on each experience, then evaluate.

        PNN is restricted to multi-task scenarios, so the benchmark is loaded
        with task labels. Evaluation happens only once, after all training,
        because evaluating on future (not yet trained) tasks is not allowed.
        No assertion is made; the test passes if no exception is raised.
        """
        pnn_model = PNN(num_layers=3, in_features=6, hidden_features_per_column=10)
        sgd = torch.optim.SGD(pnn_model.parameters(), lr=0.1)
        pnn_strategy = PNNStrategy(
            pnn_model,
            sgd,
            train_mb_size=10,
            train_epochs=2,
            eval_mb_size=50,
            device=self.device,
        )

        # Sequential training over the stream, followed by a single eval pass.
        scenario = self.load_benchmark(use_task_labels=True)
        for experience in scenario.train_stream:
            pnn_strategy.train(experience)
        pnn_strategy.eval(scenario.test_stream)