def test_device_of_output_head_is_correct():
    """Regression test for a bug where the output head would stay on CPU while
    the rest of the model was on GPU.
    """
    setting = ClassIncrementalSetting(dataset="mnist")
    method = BaselineMethod(max_epochs=1, no_wandb=True)
    results = setting.apply(method)
    assert 0.10 <= results.objective <= 0.30
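
# The test above only catches the device bug indirectly, through the final
# objective. A more direct check could compare devices explicitly. This is a
# minimal sketch, assuming the model exposes `encoder` and `output_head`
# attributes like the multi-headed models in the tests below:
def assert_output_head_device_matches_encoder(model: nn.Module) -> None:
    encoder_device = next(model.encoder.parameters()).device
    head_device = next(model.output_head.parameters()).device
    assert head_device == encoder_device, (
        f"Output head is on {head_device}, but the encoder is on {encoder_device}."
    )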
def test_multiple_tasks_within_same_batch(
    mixed_samples: Dict[int, Tuple[Tensor, Tensor, Tensor]],
    indices: slice,
    monkeypatch,
    config: Config,
):
    """Checks that when given a batch containing data from different tasks, a
    multi-headed model uses the right output head for each image.
    """
    setting = ClassIncrementalSetting()
    model = MultiHeadModel(
        setting=setting,
        hparams=MultiHeadModel.HParams(batch_size=30, multihead=True),
        config=config,
    )

    class MockEncoder(nn.Module):
        def forward(self, x: Tensor):
            # Constant features, so the output heads fully determine the predictions.
            return x.new_ones([x.shape[0], model.hidden_size])

    model.encoder = MockEncoder()

    # One mock output head per task: each head stamps its predictions with its
    # task id, which lets the assertions below check the routing.
    for i in range(5):
        model.output_heads[str(i)] = MockOutputHead(
            input_space=spaces.Box(0, 1, [model.hidden_size]),
            Actions=setting.Actions,
            action_space=spaces.Discrete(2),
            task_id=i,
        )
    model.output_head = model.output_heads["0"]

    xs, ys, ts = map(torch.cat, zip(*mixed_samples.values()))
    xs = xs[indices]
    ys = ys[indices]
    ts = ts[indices].int()
    obs = setting.Observations(x=xs, task_labels=ts)

    with torch.no_grad():
        forward_pass = model(obs)

    y_preds = forward_pass["y_pred"]
    assert y_preds.shape == ts.shape
    # Each prediction should come from the head matching that sample's task label.
    assert torch.all(y_preds == ts * xs.view([xs.shape[0], -1]).mean(1))
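
# `MockOutputHead` is defined elsewhere in the test suite. For reference, here
# is a minimal sketch that would be consistent with the assertion above: each
# head returns task_id * mean(x) for every sample, so the predictions reveal
# which head was used. The `OutputHead` base class and the forward signature
# here are assumptions, not the actual definition:
class MockOutputHead(OutputHead):
    def __init__(self, *args, task_id: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.task_id = task_id

    def forward(self, observations, representations):
        x: Tensor = observations.x
        # Constant per head (up to the input's mean), so checking y_pred against
        # the task labels shows which head handled each sample.
        y_pred = x.view([x.shape[0], -1]).mean(1) * self.task_id
        return self.Actions(y_pred=y_pred)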
def test_class_incremental_setting():
    method = BaselineMethod(no_wandb=True, max_epochs=1)
    setting = ClassIncrementalSetting()
    results = setting.apply(method)
    print(results.summary())

    # Number of test samples evaluated for each of the 5 tasks.
    assert results.final_performance_metrics[0].n_samples == 1984
    assert results.final_performance_metrics[1].n_samples == 2016
    assert results.final_performance_metrics[2].n_samples == 1984
    assert results.final_performance_metrics[3].n_samples == 2016
    assert results.final_performance_metrics[4].n_samples == 1984

    # Final accuracy per task: earlier tasks are partly forgotten by the end of
    # training, while the most recently learned task is nearly perfect.
    assert 0.48 <= results.final_performance_metrics[0].accuracy <= 0.55
    assert 0.48 <= results.final_performance_metrics[1].accuracy <= 0.55
    assert 0.60 <= results.final_performance_metrics[2].accuracy <= 0.95
    assert 0.75 <= results.final_performance_metrics[3].accuracy <= 0.98
    assert 0.99 <= results.final_performance_metrics[4].accuracy <= 1.00
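
# Sketch: the per-task metrics checked above can be reduced to a single
# weighted average; this only assumes the `accuracy` and `n_samples`
# attributes already used in the assertions:
def average_final_accuracy(metrics) -> float:
    total_samples = sum(m.n_samples for m in metrics)
    return sum(m.accuracy * m.n_samples for m in metrics) / total_samples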
def test_task_inference_sl(
    mixed_samples: Dict[int, Tuple[Tensor, Tensor, Tensor]],
    indices: slice,
    config: Config,
):
    """Checks that when given a batch with data from different tasks but *no*
    task labels, a multi-headed model still produces one prediction per sample
    by inferring the task itself.
    """
    # Get a mixed batch, and drop the task labels from the observations.
    xs, ys, ts = map(torch.cat, zip(*mixed_samples.values()))
    xs = xs[indices]
    ys = ys[indices]
    ts = ts[indices].int()
    obs = ClassIncrementalSetting.Observations(x=xs, task_labels=None)

    setting = ClassIncrementalSetting()
    model = MultiHeadModel(
        setting=setting,
        hparams=MultiHeadModel.HParams(batch_size=30, multihead=True),
        config=config,
    )

    class MockEncoder(nn.Module):
        def forward(self, x: Tensor):
            return x.new_ones([x.shape[0], model.hidden_size])

    model.encoder = MockEncoder()

    for i in range(5):
        model.output_heads[str(i)] = MockOutputHead(
            input_space=spaces.Box(0, 1, [model.hidden_size]),
            action_space=spaces.Discrete(setting.action_space.n),
            Actions=setting.Actions,
            task_id=i,
        )
    model.output_head = model.output_heads["0"]

    forward_pass = model(obs)
    y_preds = forward_pass.actions.y_pred
    assert y_preds.shape == ts.shape
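
# Since the observations above carry no task labels, the multi-headed model has
# to infer the task on its own. A common strategy (a sketch of what such a
# model *could* do internally, not taken from MultiHeadModel's source) is to
# run every output head and keep, for each sample, the prediction of the most
# confident head:
from typing import List

def predictions_from_most_confident_head(logits_per_head: List[Tensor]) -> Tensor:
    """`logits_per_head[i]` has shape [batch, n_classes]; returns [batch]."""
    # Per-head confidence on each sample: [n_heads, batch].
    confidences = torch.stack(
        [logits.softmax(-1).max(-1).values for logits in logits_per_head]
    )
    chosen_head = confidences.argmax(dim=0)  # [batch]
    preds_per_head = torch.stack([logits.argmax(-1) for logits in logits_per_head])
    return preds_per_head.gather(0, chosen_head.unsqueeze(0)).squeeze(0)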
def test_get_parents():
    assert IIDSetting in TaskIncrementalSetting.get_children()
    assert IIDSetting in DomainIncrementalSetting.get_children()
    assert IIDSetting not in ClassIncrementalSetting.get_children()

    assert TaskIncrementalSetting in IIDSetting.get_immediate_parents()
    assert DomainIncrementalSetting in IIDSetting.get_immediate_parents()
    assert ClassIncrementalSetting not in IIDSetting.get_immediate_parents()

    assert TaskIncrementalSetting in IIDSetting.get_parents()
    assert DomainIncrementalSetting in IIDSetting.get_parents()
    assert ClassIncrementalSetting in IIDSetting.get_parents()
    assert IIDSetting not in IIDSetting.get_parents()
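
# The slice of the Setting hierarchy exercised above, as implied by the
# assertions (arrows point from parent to child):
#
#   ClassIncrementalSetting            (in IIDSetting.get_parents(), but not
#      │                                an immediate parent)
#      ▼
#   TaskIncrementalSetting   DomainIncrementalSetting
#              └────────┬────────────┘
#                       ▼
#                  IIDSetting          (immediate parents: both settings above)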
def configure(self, setting: ClassIncrementalSetting):
    # Create the model.
    self.net = models.resnet18(pretrained=False)
    self.net.fc = nn.Linear(512, setting.action_space.n)
    if torch.cuda.is_available():
        self.net = self.net.to(device=self.device)

    # Set drop_last to True, to avoid getting a batch of size 1, which makes
    # batchnorm raise an error.
    setting.drop_last = True

    image_space: spaces.Box = setting.observation_space["x"]

    # Create the buffer.
    if self.buffer_capacity:
        self.buffer = Buffer(
            capacity=self.buffer_capacity,
            input_shape=image_space.shape,
            extra_buffers={"t": torch.LongTensor},
            rng=self.rng,
        ).to(device=self.device)

    # Create the optimizer.
    self.optim = torch.optim.Adam(
        self.net.parameters(),
        lr=self.learning_rate,
        weight_decay=self.weight_decay,
    )
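
# Sketch of how a replay buffer like the one configured above is typically used
# during training (experience replay). NOTE: the `add_reservoir` / `sample`
# method names and the dict layout are assumptions about the `Buffer` class,
# which is defined elsewhere; this also assumes `self.buffer` defaults to None
# when no capacity was given:
from torch.nn import functional as F

def training_step_with_replay(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
    loss = F.cross_entropy(self.net(x), y)
    if getattr(self, "buffer", None) is not None and len(self.buffer) > 0:
        # Rehearse on stored samples from previous tasks.
        replay = self.buffer.sample(x.shape[0])
        loss = loss + F.cross_entropy(self.net(replay["x"]), replay["y"])
    if getattr(self, "buffer", None) is not None:
        # Write the current batch into the buffer (reservoir sampling keeps it
        # representative of the whole stream seen so far).
        self.buffer.add_reservoir({"x": x, "y": y, "t": t})
    return loss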
# - "Easy": Domain-Incremental MNIST Setting:
# from sequoia.settings.sl.class_incremental.domain_incremental import DomainIncrementalSetting
# setting = DomainIncrementalSetting(
#     dataset="mnist", nb_tasks=5, monitor_training_performance=True
# )

# - "Medium": Class-Incremental MNIST Setting, useful for quick debugging:
# setting = ClassIncrementalSetting(
#     dataset="mnist",
#     nb_tasks=5,
#     monitor_training_performance=True,
#     known_task_boundaries_at_test_time=False,
#     batch_size=32,
#     num_workers=4,
# )

# - "Hard": Class-Incremental Synbols, more challenging.
# NOTE: This Setting is very similar to the one used for the SL track of the
# competition.
setting = ClassIncrementalSetting(
    dataset="synbols",
    nb_tasks=12,
    known_task_boundaries_at_test_time=False,
    monitor_training_performance=True,
    batch_size=32,
    num_workers=4,
)

# NOTE: You can also pass a `Config` object to `setting.apply`. This object has
# some configuration options like device, data_dir, etc.
results = setting.apply(method, config=Config(data_dir="data"))
print(results.summary())