示例#1
0
    def test_incremental_classifier_weight_update(self):
        """Adaptation must preserve the classifier's existing parameters.

        The weight and bias of the classifier are snapshotted, the model is
        adapted to the dataset of ``train_stream[4]``, and the test then
        checks that:

        * the leading entries of the new weight/bias equal the snapshot;
        * the new output size equals ``max(dataset.targets) + 1``.
        """
        model = IncrementalClassifier(in_features=10)
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = CrossEntropyLoss()
        scenario = self.scenario

        # The strategy is only constructed and wired to a stdout logger;
        # this test never calls ``strategy.train``.
        strategy = Naive(model,
                         optimizer,
                         criterion,
                         train_mb_size=100,
                         train_epochs=1,
                         eval_mb_size=100,
                         device='cpu')
        strategy.evaluator.loggers = [TextLogger(sys.stdout)]

        # Snapshot the parameters before adaptation. ``clone`` is required:
        # the comparison below must not alias the live parameters.
        w_old = model.classifier.weight.clone()
        b_old = model.classifier.bias.clone()

        # Adaptation: expected to grow the output layer to cover the classes
        # found in this dataset.
        dataset = scenario.train_stream[4].dataset
        model.adaptation(dataset)
        w_new = model.classifier.weight.clone()
        b_new = model.classifier.bias.clone()

        # Old parameters should be copied into the head of the new ones.
        # Each tensor is sliced by its own snapshot's length.
        assert torch.equal(w_old, w_new[:w_old.shape[0]])
        assert torch.equal(b_old, b_new[:b_old.shape[0]])

        # Output dimension should match the highest class id seen, plus one.
        expected_units = max(dataset.targets) + 1
        assert w_new.shape[0] == expected_units
        assert b_new.shape[0] == expected_units
示例#2
0
    def test_incremental_classifier(self):
        """Check that the optimizer tracks the classifier's live parameters.

        The test trains three times and inspects, via ``data_ptr()``, which
        tensors the optimizer holds in its ``param_groups``:

        1. after training on ``train_stream[0]`` the classifier weight and
           bias must be among the optimized tensors;
        2. after training again on the same experience the parameters must
           be unchanged and still optimized;
        3. after training on ``train_stream[4]`` the parameters must have
           changed, the old ones must no longer be optimized and the new
           ones must be.
        """
        model = SimpleMLP(input_size=6, hidden_size=10)
        model.classifier = IncrementalClassifier(in_features=10)
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = CrossEntropyLoss()
        scenario = self.scenario

        def optimized_ptrs():
            # Addresses of every tensor currently held by the optimizer.
            # Must be re-read after each ``train`` call to observe any
            # change made to the optimizer's param groups.
            return [
                param.data_ptr()
                for group in optimizer.param_groups
                for param in group['params']
            ]

        def classifier_ptrs():
            # Looked up on every call so that a replaced inner layer is
            # seen rather than a stale reference.
            head = model.classifier.classifier
            return head.weight.data_ptr(), head.bias.data_ptr()

        strategy = Naive(model,
                         optimizer,
                         criterion,
                         train_mb_size=100,
                         train_epochs=1,
                         eval_mb_size=100,
                         device='cpu')
        strategy.evaluator.loggers = [TextLogger(sys.stdout)]
        print("Current Classes: ",
              scenario.train_stream[0].classes_in_this_experience)
        print("Current Classes: ",
              scenario.train_stream[4].classes_in_this_experience)

        # train on first task
        strategy.train(scenario.train_stream[0])
        w_ptr, b_ptr = classifier_ptrs()
        opt_params_ptrs = optimized_ptrs()
        # classifier params should be optimized
        assert w_ptr in opt_params_ptrs
        assert b_ptr in opt_params_ptrs

        # train again on the same task.
        strategy.train(scenario.train_stream[0])
        # parameters should not change.
        same_w_ptr, same_b_ptr = classifier_ptrs()
        assert w_ptr == same_w_ptr
        assert b_ptr == same_b_ptr
        # The same classifier params should still be optimized. Refresh the
        # list first: checking the pre-training snapshot would only repeat
        # the assertion already made above.
        opt_params_ptrs = optimized_ptrs()
        assert w_ptr in opt_params_ptrs
        assert b_ptr in opt_params_ptrs

        # update classifier with new classes.
        old_w_ptr, old_b_ptr = w_ptr, b_ptr
        strategy.train(scenario.train_stream[4])
        opt_params_ptrs = optimized_ptrs()
        new_w_ptr, new_b_ptr = classifier_ptrs()
        # weights should change.
        assert old_w_ptr != new_w_ptr
        assert old_b_ptr != new_b_ptr
        # Old params should not be optimized. New params should be optimized.
        assert old_w_ptr not in opt_params_ptrs
        assert old_b_ptr not in opt_params_ptrs
        assert new_w_ptr in opt_params_ptrs
        assert new_b_ptr in opt_params_ptrs
    def test_periodic_eval(self):
        """Check how many evaluation points each ``eval_every`` setting emits.

        One model and optimizer are reused across four training runs on
        ``train_stream[0]`` (two epochs each). Every run gets a fresh
        strategy and evaluator, and the length of the recorded stream
        accuracy curve is compared with the expected number of evaluations.
        """
        model = SimpleMLP(input_size=6, hidden_size=10)
        model.classifier = IncrementalClassifier(model.classifier.in_features)
        benchmark = get_fast_benchmark()
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = CrossEntropyLoss()
        curve_key = "Top1_Acc_Stream/eval_phase/train_stream/Task000"

        def metrics_after_training(eval_every, **strategy_kwargs):
            # Stream accuracy is used because it emits a single value per
            # eval loop, so the curve length counts the evaluations.
            stream_acc = StreamAccuracy()
            plugin = EvaluationPlugin(stream_acc)
            naive = Naive(
                model,
                optimizer,
                criterion,
                train_epochs=2,
                eval_every=eval_every,
                evaluator=plugin,
                **strategy_kwargs,
            )
            naive.train(benchmark.train_stream[0])
            return naive.evaluator.get_all_metrics()

        # Case 1: eval_every=-1, no periodic eval -> nothing is recorded.
        recorded = metrics_after_training(-1)
        assert len(recorded) == 0

        # Case 2: eval_every=0, eval before training and at the end only
        # -> two points on the curve.
        points = metrics_after_training(0)[curve_key][1]
        assert len(points) == 2

        # Case 3: eval_every=1, eval before training and after every epoch
        # -> three points for two epochs.
        points = metrics_after_training(1)[curve_key][1]
        assert len(points) == 3

        # Case 4: iteration mode with eval_every=100 -> five points expected.
        points = metrics_after_training(
            100, peval_mode="iteration")[curve_key][1]
        assert len(points) == 5