def test_classifier_selection(self):
    """Check that TrainEvalModel activates the expected classifier head.

    ``adaptation()`` follows the current train/eval flag of the model,
    while ``eval_adaptation()`` / ``train_adaptation()`` select a head
    explicitly, independent of that flag.
    """
    cnn = SimpleCNN()
    features = cnn.features
    train_head = cnn.classifier
    eval_head = NCMClassifier()
    model = TrainEvalModel(
        features,
        train_classifier=train_head,
        eval_classifier=eval_head,
    )

    # Mode-driven selection: switch the mode, then let adaptation() pick.
    mode_steps = (
        (model.eval, eval_head),
        (model.train, train_head),
    )
    for switch_mode, expected_head in mode_steps:
        switch_mode()
        model.adaptation()
        assert model.classifier is expected_head

    # Explicit selection: the head is forced even though the model is
    # still in train mode from the loop above.
    explicit_steps = (
        (model.eval_adaptation, eval_head),
        (model.train_adaptation, train_head),
    )
    for adapt, expected_head in explicit_steps:
        adapt()
        assert model.classifier is expected_head
def __init__(self, feature_extractor: Module, classifier: Module,
             optimizer: Optimizer, memory_size, buffer_transform,
             fixed_memory, criterion=None,
             train_mb_size: int = 1, train_epochs: int = 1,
             eval_mb_size: int = None, device=None,
             plugins: Optional[List[StrategyPlugin]] = None,
             evaluator: EvaluationPlugin = default_logger, eval_every=-1):
    """ iCaRL Strategy.

    This strategy does not use task identities.

    :param feature_extractor: The feature extractor.
    :param classifier: The differentiable classifier that takes as input
        the output of the feature extractor. It is used during training;
        evaluation uses a separate ``NCMClassifier``.
    :param optimizer: The optimizer to use.
    :param memory_size: The number of patterns saved in the memory.
    :param buffer_transform: transform applied on buffer elements already
        modified by test_transform (if specified) before being used for
        replay
    :param fixed_memory: If True a memory of size memory_size is
        allocated and partitioned between samples from the observed
        experiences. If False every time a new class is observed
        memory_size samples of that class are added to the memory.
    :param criterion: The loss criterion. Defaults to None, in which case
        a new ``ICaRLLossPlugin`` is created for this strategy instance.
        If the criterion is a ``StrategyPlugin`` it is also registered as
        a plugin.
    :param train_mb_size: The train minibatch size. Defaults to 1.
    :param train_epochs: The number of training epochs. Defaults to 1.
    :param eval_mb_size: The eval minibatch size. Defaults to None.
    :param device: The device to use. Defaults to None (cpu).
    :param plugins: Plugins to be added. Defaults to None. The list passed
        in is copied, not modified.
    :param evaluator: (optional) instance of EvaluationPlugin for logging
        and metric computations.
    :param eval_every: the frequency of the calls to `eval` inside the
        training loop.
            if -1: no evaluation during training.
            if  0: calls `eval` after the final epoch of each training
                experience.
            if >0: calls `eval` every `eval_every` epochs and at the end
                of all the epochs for a single experience.
    """
    # A default built in the signature would be evaluated once and shared
    # by every strategy instance. It is also registered as a plugin below,
    # so any per-run state it may hold would leak between strategies.
    # Create a fresh loss per instance instead.
    if criterion is None:
        criterion = ICaRLLossPlugin()

    model = TrainEvalModel(
        feature_extractor,
        train_classifier=classifier,
        eval_classifier=NCMClassifier())

    icarl = _ICaRLPlugin(memory_size, buffer_transform, fixed_memory)

    # Copy rather than `+=` so the caller's list is left untouched.
    # Resulting order: caller plugins, icarl, then criterion (if a plugin).
    plugins = [] if plugins is None else list(plugins)
    plugins.append(icarl)
    if isinstance(criterion, StrategyPlugin):
        plugins.append(criterion)

    super().__init__(
        model, optimizer, criterion=criterion,
        train_mb_size=train_mb_size, train_epochs=train_epochs,
        eval_mb_size=eval_mb_size, device=device,
        plugins=plugins, evaluator=evaluator, eval_every=eval_every)
def test_ncm_classification(self):
    """NCMClassifier should assign each sample to its closest class mean.

    The class means are the one-hot vectors e_0..e_3. Each input row is a
    permutation of (4, 3, 2, 1), so all rows have the same norm. A row's
    closest one-hot mean is therefore the axis holding its largest entry,
    and those axes are 0, 1, 2, 3 in order.
    """
    class_means = torch.eye(4)
    mb_x = torch.tensor(
        [[4, 3, 2, 1],
         [3, 4, 2, 1],
         [3, 2, 4, 1],
         [3, 2, 1, 4]],
        dtype=torch.float)
    # Class labels are integer indices; keep them in the same dtype
    # (int64) that argmax returns so the comparison is exact.
    mb_y = torch.tensor([0, 1, 2, 3], dtype=torch.long)

    classifier = NCMClassifier(class_means)
    pred = classifier(mb_x)
    assert torch.equal(pred.argmax(dim=1), mb_y)