Esempio n. 1
0
    def test_classifier_selection(self):
        """Check which classifier ``TrainEvalModel`` exposes in each mode.

        The model is built from the feature extractor of a ``SimpleCNN``,
        with that network's own classifier as the training classifier and an
        ``NCMClassifier`` as the evaluation classifier. ``model.classifier``
        is inspected after every mode switch.
        """
        backbone = SimpleCNN()
        extractor = backbone.features
        train_head = backbone.classifier
        eval_head = NCMClassifier()

        model = TrainEvalModel(extractor,
                               train_classifier=train_head,
                               eval_classifier=eval_head)

        # Switching through the generic eval()/train() + adaptation() calls.
        model.eval()
        model.adaptation()
        assert model.classifier is eval_head

        model.train()
        model.adaptation()
        assert model.classifier is train_head

        # Switching through the explicit, mode-specific adaptation calls.
        model.eval_adaptation()
        assert model.classifier is eval_head

        model.train_adaptation()
        assert model.classifier is train_head
Esempio n. 2
0
    def __init__(self, feature_extractor: Module, classifier: Module,
                 optimizer: Optimizer, memory_size, buffer_transform,
                 fixed_memory, criterion=None,
                 train_mb_size: int = 1, train_epochs: int = 1,
                 eval_mb_size: Optional[int] = None, device=None,
                 plugins: Optional[List[StrategyPlugin]] = None,
                 evaluator: EvaluationPlugin = default_logger, eval_every=-1):
        """ iCaRL Strategy.
        This strategy does not use task identities.

        The model handed to the base strategy is a ``TrainEvalModel`` that
        uses ``classifier`` for training and a new ``NCMClassifier`` for
        evaluation, both on top of ``feature_extractor``.

        :param feature_extractor: The feature extractor.
        :param classifier: The differentiable classifier that takes as input
            the output of the feature extractor.
        :param optimizer: The optimizer to use.
        :param memory_size: The number of patterns saved in the memory.
        :param buffer_transform: transform applied on buffer elements already
            modified by test_transform (if specified) before being used for
            replay
        :param fixed_memory: If True a memory of size memory_size is
            allocated and partitioned between samples from the observed
            experiences. If False every time a new class is observed
            memory_size samples of that class are added to the memory.
        :param criterion: The loss criterion forwarded to the base strategy.
            Defaults to None, in which case a new ``ICaRLLossPlugin`` is
            created for this strategy instance. If the criterion is a
            ``StrategyPlugin`` it is also registered as a plugin.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None,
            which is forwarded unchanged to the base strategy.
        :param device: The device to use. Defaults to None (cpu).
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop.
                if -1: no evaluation during training.
                if  0: calls `eval` after the final epoch of each training
                    experience.
                if >0: calls `eval` every `eval_every` epochs and at the end
                    of all the epochs for a single experience.
        """
        model = TrainEvalModel(feature_extractor,
                               train_classifier=classifier,
                               eval_classifier=NCMClassifier())

        # Build the default criterion per instance: a default evaluated in
        # the signature would be one object shared (and registered as a
        # plugin) by every strategy created without an explicit criterion.
        if criterion is None:
            criterion = ICaRLLossPlugin()

        icarl = _ICaRLPlugin(memory_size, buffer_transform, fixed_memory)

        # Work on a copy so the caller's list is not extended behind their
        # back (which would also accumulate plugins if the list is reused).
        plugins = [] if plugins is None else list(plugins)
        plugins.append(icarl)

        if isinstance(criterion, StrategyPlugin):
            plugins.append(criterion)

        super().__init__(
            model, optimizer, criterion=criterion,
            train_mb_size=train_mb_size, train_epochs=train_epochs,
            eval_mb_size=eval_mb_size, device=device, plugins=plugins,
            evaluator=evaluator, eval_every=eval_every)
Esempio n. 3
0
    def test_ncm_classification(self):
        """Check ``NCMClassifier`` predictions against known class means.

        The four class means are the one-hot rows of a 4x4 identity matrix.
        Sample ``i`` has its largest component at index ``i``, and the test
        expects the highest-scoring class for sample ``i`` to be ``i``.
        """
        means = torch.tensor(
            [[1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]],
            dtype=torch.float)

        samples = torch.tensor(
            [[4, 3, 2, 1],
             [3, 4, 2, 1],
             [3, 2, 4, 1],
             [3, 2, 1, 4]],
            dtype=torch.float)

        expected = torch.tensor([0, 1, 2, 3], dtype=torch.float)

        ncm = NCMClassifier(means)
        scores = ncm(samples)

        # torch.max over dim 1 yields (values, indices); only the indices,
        # i.e. the predicted class per sample, are compared.
        _, predicted = torch.max(scores, 1)
        assert torch.all(predicted == expected)