def test_initialisation(self):
     module = SimpleMLP()
     module = module.to(self.device)
     old_classifier_weight = torch.clone(module.classifier.weight)
     old_classifier_bias = torch.clone(module.classifier.bias)
     module = as_multitask(module, "classifier")
     module = module.to(self.device)
     new_classifier_weight = torch.clone(
         module.classifier.classifiers["0"].classifier.weight)
     new_classifier_bias = torch.clone(
         module.classifier.classifiers["0"].classifier.bias)
     self.assertTrue(
         torch.equal(old_classifier_weight, new_classifier_weight))
     self.assertTrue(torch.equal(old_classifier_bias, new_classifier_bias))
    def _test_outputs(self, module, clf_name):
        test_input = torch.rand(10, 3, 32, 32)
        test_input = test_input.to(self.device)
        module_singletask = copy.deepcopy(module)
        module_multitask = as_multitask(module, clf_name)
        module_multitask = module_multitask.to(self.device)
        module_singletask = module_singletask.to(self.device)

        # Put in eval mode to deactivate dropouts
        module_singletask.eval()
        module_multitask.eval()

        out_single_task = module_singletask(test_input)
        out_multi_task = module_multitask(test_input, task_labels=0)
        self.assertTrue(torch.equal(out_single_task, out_multi_task))
Example #3
0
    def _test_integration(self, module, clf_name, plugins=[]):
        module = as_multitask(module, clf_name)
        module = module.to(self.device)
        optimizer = SGD(
            module.parameters(), lr=0.05, momentum=0.9, weight_decay=0.0002
        )

        strategy = Naive(
            module,
            optimizer,
            train_mb_size=32,
            eval_mb_size=32,
            device=self.device,
            plugins=plugins,
        )

        for t, experience in enumerate(self.benchmark.train_stream):
            strategy.train(experience)
            strategy.eval(self.benchmark.test_stream[: t + 1])
    def _test_modules(self, module, clf_name):
        old_param_total = sum([torch.numel(p) for p in module.parameters()])

        module = as_multitask(module, clf_name)
        module = module.to(self.device)
        self.assertIsInstance(module, MultiTaskModule)
        self.assertIsInstance(getattr(module, clf_name), MultiHeadClassifier)

        test_input = torch.ones(5, 3, 32, 32)
        task_labels = torch.zeros(5, dtype=torch.long)

        test_input = test_input.to(self.device)
        task_labels = task_labels.to(self.device)

        # One task label
        output = module(test_input, task_labels=0)

        # Several ones
        output = module(test_input, task_labels=task_labels)

        # Change attribute
        module.non_module_attribute = 10
        self.assertEqual(module.model.non_module_attribute, 10)

        module.non_module_attribute += 5

        # Extract params and state dict
        new_param_total = sum([torch.numel(p) for p in module.parameters()])
        self.assertEqual(new_param_total, old_param_total)
        state = module.state_dict()

        # Functions returning references
        module = module.train()
        self.assertIsInstance(module, MultiTaskModule)
        self.assertTrue(module.training)
        module = module.eval()
        self.assertIsInstance(module, MultiTaskModule)
        self.assertFalse(module.training)
        if self.device == "cuda":
            module = module.cuda()
            self.assertIsInstance(module, MultiTaskModule)