def test_trainer_can_run(self):
    """Train for two epochs and sanity-check the returned metrics dict,
    once with the default (decreasing) validation metric and once with
    an increasing one (``+loss``)."""
    float_metric_keys = ('best_validation_loss',
                         'best_validation_accuracy',
                         'best_validation_accuracy3')

    trainer = CallbackTrainer(model=self.model,
                              optimizer=self.optimizer,
                              callbacks=self.default_callbacks(serialization_dir=None),
                              num_epochs=2)
    metrics = trainer.train()
    for key in float_metric_keys:
        assert key in metrics
        assert isinstance(metrics[key], float)
    assert 'best_epoch' in metrics
    assert isinstance(metrics['best_epoch'], int)
    assert 'peak_cpu_memory_MB' in metrics

    # Making sure that both increasing and decreasing validation metrics work.
    trainer = CallbackTrainer(model=self.model,
                              optimizer=self.optimizer,
                              callbacks=self.default_callbacks(validation_metric="+loss",
                                                               serialization_dir=None),
                              num_epochs=2)
    metrics = trainer.train()
    for key in float_metric_keys:
        assert key in metrics
        assert isinstance(metrics[key], float)
    assert 'best_epoch' in metrics
    assert isinstance(metrics['best_epoch'], int)
    assert 'peak_cpu_memory_MB' in metrics
    assert isinstance(metrics['peak_cpu_memory_MB'], float)
    assert metrics['peak_cpu_memory_MB'] > 0
def test_trainer_can_run_cuda(self):
    """Smoke test: two epochs of training complete on a single GPU."""
    self.model.cuda()
    gpu_trainer = CallbackTrainer(self.model,
                                  self.optimizer,
                                  num_epochs=2,
                                  callbacks=self.default_callbacks(),
                                  cuda_device=0)
    gpu_trainer.train()
def test_trainer_can_resume_with_lr_scheduler(self):
    """Resuming from a checkpoint restores both the epoch counter and the
    learning-rate scheduler's internal state."""
    first_scheduler = LearningRateScheduler.from_params(
            self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
    trainer = CallbackTrainer(
            model=self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks() + [UpdateLearningRate(first_scheduler)],
            num_epochs=2,
            serialization_dir=self.TEST_DIR)
    trainer.train()

    second_scheduler = LearningRateScheduler.from_params(
            self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
    resumed_trainer = CallbackTrainer(
            model=self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks() + [UpdateLearningRate(second_scheduler)],
            num_epochs=4,
            serialization_dir=self.TEST_DIR)
    # Firing TRAINING_START triggers the checkpoint restore.
    resumed_trainer.handler.fire_event(Events.TRAINING_START)
    assert resumed_trainer.epoch_number == 2
    assert second_scheduler.lr_scheduler.last_epoch == 1
    resumed_trainer.train()
def test_restored_training_returns_best_epoch_metrics_even_if_no_better_epoch_is_found_after_restoring(self):
    """With ``+loss`` the restored second epoch is "worse", so the best-epoch
    metrics reported after restoring must still come from epoch 0."""
    # Instead of -loss, use +loss to assure 2nd epoch is considered worse.
    # Run 1 epoch of original training.
    first_trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks(validation_metric="+loss"),
            num_epochs=1,
            serialization_dir=self.TEST_DIR)
    first_metrics = first_trainer.train()

    # Run 1 epoch of restored training.
    second_trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks(validation_metric="+loss"),
            num_epochs=2,
            serialization_dir=self.TEST_DIR)
    second_metrics = second_trainer.train()

    for key in ("best_validation_loss",
                "best_validation_accuracy",
                "best_validation_accuracy3",
                "best_epoch"):
        assert key in second_metrics

    # The best epoch (and its loss) is unchanged by restoring, while the
    # actual validation loss kept decreasing across the restored epoch.
    assert first_metrics["best_validation_loss"] == second_metrics["best_validation_loss"]
    assert first_metrics["best_epoch"] == 0
    assert first_metrics["validation_loss"] > second_metrics["validation_loss"]
def test_trainer_can_resume_training_for_exponential_moving_average(self):
    """Resume training with an ``UpdateMovingAverage`` callback and verify
    that the RESTORED trainer carries the checkpointed epoch number and
    metric-tracker state.

    Fix: the original asserted on ``trainer.metric_tracker`` (the first
    trainer) rather than ``new_trainer.metric_tracker``; the sibling resume
    tests in this file all inspect the restored trainer, which is what this
    test is meant to exercise.
    """
    moving_average = ExponentialMovingAverage(self.model.named_parameters())
    callbacks = self.default_callbacks() + [UpdateMovingAverage(moving_average)]

    trainer = CallbackTrainer(self.model,
                              training_data=self.instances,
                              iterator=self.iterator,
                              optimizer=self.optimizer,
                              num_epochs=1,
                              serialization_dir=self.TEST_DIR,
                              callbacks=callbacks)
    trainer.train()

    new_moving_average = ExponentialMovingAverage(self.model.named_parameters())
    new_callbacks = self.default_callbacks() + [UpdateMovingAverage(new_moving_average)]

    new_trainer = CallbackTrainer(self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  num_epochs=3,
                                  serialization_dir=self.TEST_DIR,
                                  callbacks=new_callbacks)
    # TRAINING_START restores the checkpoint written by the first trainer.
    new_trainer.handler.fire_event(Events.TRAINING_START)  # pylint: disable=protected-access
    assert new_trainer.epoch_number == 1

    # Inspect the restored trainer's tracker, consistent with the other
    # resume tests in this file.
    tracker = new_trainer.metric_tracker  # pylint: disable=protected-access
    assert tracker.is_best_so_far()
    assert tracker._best_so_far is not None  # pylint: disable=protected-access

    new_trainer.train()
def test_trainer_can_run_and_resume_with_momentum_scheduler(self):
    """Momentum-scheduler state (last_epoch) must survive a checkpoint restore."""
    first_scheduler = MomentumScheduler.from_params(
            self.optimizer,
            Params({"type": "inverted_triangular", "cool_down": 2, "warm_up": 2}))
    trainer = CallbackTrainer(
            model=self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            num_epochs=4,
            callbacks=self.default_callbacks() + [UpdateMomentum(first_scheduler)],
            serialization_dir=self.TEST_DIR)
    trainer.train()

    second_scheduler = MomentumScheduler.from_params(
            self.optimizer,
            Params({"type": "inverted_triangular", "cool_down": 2, "warm_up": 2}))
    resumed_trainer = CallbackTrainer(
            model=self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            num_epochs=6,
            callbacks=self.default_callbacks() + [UpdateMomentum(second_scheduler)],
            serialization_dir=self.TEST_DIR)
    # TRAINING_START triggers the checkpoint restore.
    resumed_trainer.handler.fire_event(Events.TRAINING_START)
    assert resumed_trainer.epoch_number == 4
    assert second_scheduler.last_epoch == 3
    resumed_trainer.train()
def test_production_rule_field_with_multiple_gpus(self):
    """ProductionRule fields must survive the multi-GPU batch split."""
    wikitables_dir = "allennlp/tests/fixtures/data/wikitables/"
    offline_lf_directory = wikitables_dir + "action_space_walker_output/"
    reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir,
            offline_logical_forms_directory=offline_lf_directory)
    train_instances = reader.read(wikitables_dir + "sample_data.examples")

    archive_path = (self.FIXTURES_ROOT / "semantic_parsing" / "wikitables" /
                    "serialization" / "model.tar.gz")
    parser_model = load_archive(archive_path).model
    parser_model.cuda()

    iterator = BasicIterator(batch_size=4)
    iterator.index_with(parser_model.vocab)

    CallbackTrainer(
            parser_model,
            train_instances,
            iterator,
            self.optimizer,
            num_epochs=2,
            cuda_device=[0, 1],
            callbacks=[GradientNormAndClip()],
    ).train()
def test_trainer_saves_and_loads_best_validation_metrics_correctly_2(self):
    """With ``+loss`` as the validation metric the restored second epoch
    cannot beat the first, so the best-epoch metrics loaded from the
    checkpoint must be identical before and after restoring.

    Fix: the leading comment claimed "-loss" while the code passes
    ``validation_metric="+loss"``; the comment now matches the code.
    """
    # Use +loss and run 1 epoch of original-training, and one of restored-training.
    # Run 1 epoch of original training.
    trainer = CallbackTrainer(
            self.model,
            self.optimizer,
            callbacks=self.default_callbacks(validation_metric="+loss"),
            num_epochs=1,
            serialization_dir=self.TEST_DIR)
    trainer.train()
    _ = trainer.handler.fire_event(Events.RESTORE_CHECKPOINT)
    best_epoch_1 = trainer.metric_tracker.best_epoch
    best_validation_metrics_epoch_1 = trainer.metric_tracker.best_epoch_metrics
    # best_validation_metrics_epoch_1: {'accuracy': 0.75, 'accuracy3': 1.0, 'loss': 0.6243013441562653}
    assert isinstance(best_validation_metrics_epoch_1, dict)
    assert "loss" in best_validation_metrics_epoch_1

    # Run 1 more epoch of restored training.
    restore_trainer = CallbackTrainer(
            self.model,
            self.optimizer,
            callbacks=self.default_callbacks(validation_metric="+loss"),
            num_epochs=2,
            serialization_dir=self.TEST_DIR)
    restore_trainer.train()
    _ = restore_trainer.handler.fire_event(Events.RESTORE_CHECKPOINT)
    best_epoch_2 = restore_trainer.metric_tracker.best_epoch
    best_validation_metrics_epoch_2 = restore_trainer.metric_tracker.best_epoch_metrics

    # Because of using +loss, 2nd epoch won't be better than 1st. So best val metrics should be same.
    assert best_epoch_1 == best_epoch_2 == 0
    assert best_validation_metrics_epoch_2 == best_validation_metrics_epoch_1
def test_trainer_can_resume_training(self):
    """After training one epoch, a fresh trainer pointed at the same
    serialization dir picks up at epoch 1 with a restored metric tracker."""
    CallbackTrainer(self.model,
                    self.optimizer,
                    callbacks=self.default_callbacks(),
                    num_epochs=1,
                    serialization_dir=self.TEST_DIR).train()

    resumed = CallbackTrainer(self.model,
                              self.optimizer,
                              callbacks=self.default_callbacks(),
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR)
    # RESTORE_CHECKPOINT loads the state saved by the first trainer.
    resumed.handler.fire_event(Events.RESTORE_CHECKPOINT)
    assert resumed.epoch_number == 1

    restored_tracker = resumed.metric_tracker
    assert restored_tracker is not None
    assert restored_tracker.is_best_so_far()
    assert restored_tracker._best_so_far is not None

    resumed.train()
def test_trainer_can_resume_with_lr_scheduler(self):
    """LR-scheduler state and epoch counter are restored on RESTORE_CHECKPOINT."""
    first_scheduler = LearningRateScheduler.from_params(
            self.optimizer,
            Params({"type": "exponential", "gamma": 0.5}))
    trainer = CallbackTrainer(
            model=self.model,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks() + [LrsCallback(first_scheduler)],
            num_epochs=2,
            serialization_dir=self.TEST_DIR)
    trainer.train()

    second_scheduler = LearningRateScheduler.from_params(
            self.optimizer,
            Params({"type": "exponential", "gamma": 0.5}))
    resumed_trainer = CallbackTrainer(
            model=self.model,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks() + [LrsCallback(second_scheduler)],
            num_epochs=4,
            serialization_dir=self.TEST_DIR)
    # Restore from the checkpoint left by the first trainer.
    resumed_trainer.handler.fire_event(Events.RESTORE_CHECKPOINT)
    assert resumed_trainer.epoch_number == 2
    assert second_scheduler.lr_scheduler.last_epoch == 1
    resumed_trainer.train()
def test_trainer_saves_and_loads_best_validation_metrics_correctly_1(self):
    """With the default (-loss) validation metric the restored second epoch
    IS better, so the best-epoch metrics must change after resuming."""
    # Use -loss and run 1 epoch of original-training, and one of restored-training.
    # Run 1 epoch of original training.
    first_trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks(),
            num_epochs=1,
            serialization_dir=self.TEST_DIR)
    first_trainer.train()
    first_trainer.handler.fire_event(Events.TRAINING_START)
    best_epoch_1 = first_trainer.metric_tracker.best_epoch
    best_metrics_1 = first_trainer.metric_tracker.best_epoch_metrics
    # e.g. {'accuracy': 0.75, 'accuracy3': 1.0, 'loss': 0.6243013441562653}
    assert isinstance(best_metrics_1, dict)
    assert "loss" in best_metrics_1

    # Run 1 epoch of restored training.
    second_trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks(),
            num_epochs=2,
            serialization_dir=self.TEST_DIR)
    second_trainer.train()
    second_trainer.handler.fire_event(Events.TRAINING_START)
    best_epoch_2 = second_trainer.metric_tracker.best_epoch
    best_metrics_2 = second_trainer.metric_tracker.best_epoch_metrics

    # Because of using -loss, 2nd epoch would be better than 1st. So best val metrics should not be same.
    assert best_epoch_1 == 0 and best_epoch_2 == 1
    assert best_metrics_2 != best_metrics_1
def test_trainer_can_resume_training(self):
    """A second trainer sharing the serialization dir restores epoch number
    and metric-tracker state on TRAINING_START, then keeps training."""
    CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks(),
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
    ).train()

    resumed = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks(),
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
    )
    # TRAINING_START restores the checkpoint from the first run.
    resumed.handler.fire_event(Events.TRAINING_START)
    assert resumed.epoch_number == 1

    restored_tracker = resumed.metric_tracker
    assert restored_tracker is not None
    assert restored_tracker.is_best_so_far()
    assert restored_tracker._best_so_far is not None

    resumed.train()
def test_validation_metrics_consistent_with_and_without_tracking(self):
    """The TrackMetrics callback must not change the validation metrics
    themselves — both runs (deep-copied model/optimizer) should agree."""
    all_callbacks = self.default_callbacks(serialization_dir=None)
    untracked_callbacks = [cb for cb in all_callbacks
                           if not isinstance(cb, TrackMetrics)]

    def run(callbacks):
        # Deep-copy model and optimizer so the two runs are independent
        # but start from identical weights.
        trainer = CallbackTrainer(
                copy.deepcopy(self.model),
                training_data=self.instances,
                iterator=self.iterator,
                optimizer=copy.deepcopy(self.optimizer),
                callbacks=callbacks,
                num_epochs=1,
                serialization_dir=None,
        )
        trainer.train()
        return trainer.val_metrics

    untracked_metrics = run(untracked_callbacks)
    tracked_metrics = run(all_callbacks)

    assert untracked_metrics.keys() == tracked_metrics.keys()
    for key in ["accuracy", "accuracy3", "loss"]:
        np.testing.assert_almost_equal(untracked_metrics[key], tracked_metrics[key])
def test_training_metrics_consistent_with_and_without_validation(self):
    """Running validation must not perturb the training metrics — runs with
    and without the Validate callback should produce identical numbers."""
    all_callbacks = self.default_callbacks(serialization_dir=None)
    no_validation_callbacks = [cb for cb in all_callbacks
                               if not isinstance(cb, Validate)]

    def run(callbacks):
        # Independent deep copies keep both runs starting from the same weights.
        trainer = CallbackTrainer(copy.deepcopy(self.model),
                                  copy.deepcopy(self.optimizer),
                                  callbacks=callbacks,
                                  num_epochs=1,
                                  serialization_dir=None)
        trainer.train()
        return trainer.train_metrics

    metrics_without = run(no_validation_callbacks)
    metrics_with = run(all_callbacks)

    assert metrics_without.keys() == metrics_with.keys()
    for key in ['accuracy', 'accuracy3', 'loss']:
        np.testing.assert_almost_equal(metrics_without[key], metrics_with[key])
def test_production_rule_field_with_multiple_gpus(self):
    """ProductionRule fields must survive the multi-GPU batch split."""
    wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
    offline_lf_directory = wikitables_dir + 'action_space_walker_output/'
    reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir,
            offline_logical_forms_directory=offline_lf_directory)
    train_instances = reader.read(wikitables_dir + 'sample_data.examples')

    archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
    parser_model = load_archive(archive_path).model
    parser_model.cuda()

    iterator = BasicIterator(batch_size=4)
    iterator.index_with(parser_model.vocab)

    CallbackTrainer(parser_model,
                    self.optimizer,
                    num_epochs=2,
                    cuda_device=[0, 1],
                    callbacks=[GenerateTrainingBatches(train_instances, iterator),
                               TrainSupervised()]).train()
def test_trainer_can_run_exponential_moving_average(self):
    """Smoke test: training runs with an exponential-moving-average callback."""
    ema = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999)
    trainer = CallbackTrainer(
            model=self.model,
            optimizer=self.optimizer,
            num_epochs=2,
            callbacks=self.default_callbacks() + [ComputeMovingAverage(ema)])
    trainer.train()
def test_trainer_raises_on_model_with_no_loss_key(self):
    """A model whose forward() omits the 'loss' key makes training fail."""
    class NoLossModel(Model):
        def forward(self, **kwargs):  # pylint: disable=arguments-differ,unused-argument
            # Deliberately return no 'loss' entry.
            return {}

    with pytest.raises(RuntimeError):
        CallbackTrainer(NoLossModel(None),
                        self.optimizer,
                        callbacks=self.default_callbacks(),
                        num_epochs=2,
                        serialization_dir=self.TEST_DIR).train()
def test_trainer_can_run_ema_from_params(self):
    """An UpdateMovingAverage callback built via from_params works in training."""
    ema_callback = UpdateMovingAverage.from_params(
            Params({"moving_average": {"decay": 0.9999}}), self.model)
    trainer = CallbackTrainer(model=self.model,
                              training_data=self.instances,
                              iterator=self.iterator,
                              optimizer=self.optimizer,
                              num_epochs=2,
                              callbacks=self.default_callbacks() + [ema_callback])
    trainer.train()
def test_trainer_can_run_cuda(self):
    """Smoke test: two epochs complete on a single GPU with explicit data/iterator."""
    self.model.cuda()
    gpu_trainer = CallbackTrainer(self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  num_epochs=2,
                                  callbacks=self.default_callbacks(),
                                  cuda_device=0)
    gpu_trainer.train()
def test_trainer_can_run_with_lr_scheduler(self):
    """Smoke test: training runs with a reduce-on-plateau LR scheduler callback."""
    scheduler = LearningRateScheduler.from_params(
            self.optimizer, Params({"type": "reduce_on_plateau"}))
    trainer = CallbackTrainer(
            model=self.model,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks() + [UpdateLearningRate(scheduler)],
            num_epochs=2)
    trainer.train()
def test_trainer_saves_models_at_specified_interval(self):
    """With a tiny model_save_interval we get timestamped mid-epoch checkpoints
    in addition to end-of-epoch ones, and can restore from the timestamped
    files after deleting the end-of-epoch checkpoints."""
    iterator = BasicIterator(batch_size=4)
    iterator.index_with(self.vocab)

    CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=iterator,
            optimizer=self.optimizer,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
            callbacks=self.default_callbacks(model_save_interval=0.0001),
    ).train()

    # Now check the serialized files for models saved during the epoch.
    file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, "model_state_epoch_*")))
    epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
              for fname in file_names]
    # We should have checkpoints at the end of each epoch and during each, e.g.
    # [0.timestamp, 0, 1.timestamp, 1]
    assert len(epochs) == 4
    assert epochs[3] == "1"
    assert "." in epochs[0]

    # Now make certain we can restore from timestamped checkpoint.
    # To do so, remove the checkpoint from the end of epoch 1&2, so
    # that we are forced to restore from the timestamped checkpoints.
    for epoch in (0, 1):
        os.remove(os.path.join(self.TEST_DIR, "model_state_epoch_{}.th".format(epoch)))
        os.remove(os.path.join(self.TEST_DIR, "training_state_epoch_{}.th".format(epoch)))
    os.remove(os.path.join(self.TEST_DIR, "best.th"))

    restore_trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=iterator,
            optimizer=self.optimizer,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
            callbacks=self.default_callbacks(model_save_interval=0.0001),
    )
    restore_trainer.handler.fire_event(Events.TRAINING_START)
    assert restore_trainer.epoch_number == 2
    # One batch per epoch.
    assert restore_trainer.batch_num_total == 2
def test_trainer_can_run_exponential_moving_average(self):
    """Smoke test: training runs with an UpdateMovingAverage callback."""
    ema = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999)
    CallbackTrainer(
            model=self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            num_epochs=2,
            callbacks=self.default_callbacks() + [UpdateMovingAverage(ema)],
    ).train()
def test_trainer_can_log_learning_rates_tensorboard(self):
    """Smoke test: LR logging to tensorboard doesn't break training."""
    other_callbacks = [cb for cb in self.default_callbacks()
                       if not isinstance(cb, LogToTensorboard)]
    # The lambda: None is unfortunate, but it will get replaced by the callback.
    writer = TensorboardWriter(lambda: None,
                               should_log_learning_rate=True,
                               summary_interval=2)
    trainer = CallbackTrainer(self.model,
                              self.optimizer,
                              num_epochs=2,
                              serialization_dir=self.TEST_DIR,
                              callbacks=other_callbacks + [LogToTensorboard(writer)])
    trainer.train()
def test_trainer_respects_num_serialized_models_to_keep(self):
    """Only the newest max_checkpoints checkpoints survive on disk."""
    CallbackTrainer(self.model,
                    self.optimizer,
                    num_epochs=5,
                    serialization_dir=self.TEST_DIR,
                    callbacks=self.default_callbacks(max_checkpoints=3)).train()

    # Only epochs 2-4 should remain serialized for each state file kind.
    for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
        file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
        found_epochs = sorted(int(re.search(r"_([0-9])\.th", fname).group(1))
                              for fname in file_names)
        assert found_epochs == [2, 3, 4]
def test_trainer_posts_to_url(self):
    """The PostToUrl callback fires exactly one POST with the given message."""
    url = 'http://slack.com?webhook=ewifjweoiwjef'
    responses.add(responses.POST, url)

    trainer = CallbackTrainer(
            model=self.model,
            optimizer=self.optimizer,
            num_epochs=2,
            callbacks=self.default_callbacks() + [PostToUrl(url, message="only a test")])
    trainer.train()

    assert len(responses.calls) == 1
    assert responses.calls[0].response.request.body == b'{"text": "only a test"}'
def test_trainer_saves_metrics_every_epoch(self):
    """Each epoch writes a metrics_epoch_N.json file containing validation
    and best-validation losses and the epoch index."""
    num_epochs = 5
    CallbackTrainer(model=self.model,
                    optimizer=self.optimizer,
                    num_epochs=num_epochs,
                    serialization_dir=self.TEST_DIR,
                    callbacks=self.default_callbacks(max_checkpoints=3)).train()

    for epoch in range(num_epochs):
        epoch_file = self.TEST_DIR / f'metrics_epoch_{epoch}.json'
        assert epoch_file.exists()
        with open(epoch_file) as fh:
            epoch_metrics = json.load(fh)
        assert "validation_loss" in epoch_metrics
        assert "best_validation_loss" in epoch_metrics
        assert epoch_metrics.get("epoch") == epoch
def test_trainer_can_log_histograms(self):
    """Smoke test: histogram/activation logging to tensorboard doesn't break training."""
    # enable activation logging
    for module in self.model.modules():
        module.should_log_activations = True

    other_callbacks = [cb for cb in self.default_callbacks()
                       if not isinstance(cb, LogToTensorboard)]
    # The lambda: None is unfortunate, but it will get replaced by the callback.
    writer = TensorboardWriter(lambda: None, histogram_interval=2)

    trainer = CallbackTrainer(self.model,
                              self.optimizer,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR,
                              callbacks=other_callbacks + [LogToTensorboard(writer)])
    trainer.train()
def test_handle_errors(self):
    """An exception raised by a callback propagates out of train(), fires the
    ERROR event (exposing trainer.exception), and skips TRAINING_END."""
    class ErrorTest(Callback):
        """
        A callback with three triggers
        * at BATCH_START, it raises a RuntimeError
        * at TRAINING_END, it sets a finished flag to True
        * at ERROR, it captures `trainer.exception`
        """
        def __init__(self) -> None:
            self.exc: Optional[Exception] = None
            self.finished_training = None

        @handle_event(Events.BATCH_START)
        def raise_exception(self, trainer):
            raise RuntimeError("problem starting batch")

        @handle_event(Events.TRAINING_END)
        def finish_training(self, trainer):
            self.finished_training = True

        @handle_event(Events.ERROR)
        def capture_error(self, trainer):
            self.exc = trainer.exception

    error_callback = ErrorTest()
    trainer = CallbackTrainer(
            self.model,
            self.instances,
            self.iterator,
            self.optimizer,
            callbacks=self.default_callbacks() + [error_callback],
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
    )

    with pytest.raises(RuntimeError):
        trainer.train()

    # The callback should have captured the exception.
    assert error_callback.exc is not None
    assert error_callback.exc.args == ("problem starting batch",)

    # The "finished" flag should never have been set to True.
    assert not error_callback.finished_training
def test_trainer_raises_on_model_with_no_loss_key(self):
    """Training a model whose forward() returns no 'loss' key raises."""
    class NoLossModel(Model):
        def forward(self, **kwargs):
            # Deliberately omit the 'loss' entry.
            return {}

    with pytest.raises(RuntimeError):
        CallbackTrainer(
                NoLossModel(None),
                training_data=self.instances,
                iterator=self.iterator,
                optimizer=self.optimizer,
                callbacks=self.default_callbacks(),
                num_epochs=2,
                serialization_dir=self.TEST_DIR,
        ).train()
def test_trainer_can_run_multiple_gpu(self):
    """Training on two GPUs splits metadata across the batch dimension and
    reports CPU/GPU peak-memory metrics."""
    self.model.cuda()

    class MetaDataCheckWrapper(Model):
        """
        Checks that the metadata field has been correctly split across the batch dimension
        when running on multiple gpus.
        """
        def __init__(self, model):
            super().__init__(model.vocab)
            self.model = model

        def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore # pylint: disable=arguments-differ
            assert 'metadata' in kwargs and 'tags' in kwargs, \
                f'tokens and metadata must be provided. Got {kwargs.keys()} instead.'
            batch_size = kwargs['tokens']['tokens'].size()[0]
            assert len(kwargs['metadata']) == batch_size, \
                f'metadata must be split appropriately. Expected {batch_size} elements, ' \
                f"got {len(kwargs['metadata'])} elements."
            return self.model.forward(**kwargs)

    iterator = BasicIterator(batch_size=4)
    iterator.index_with(self.vocab)

    wrapped_model = MetaDataCheckWrapper(self.model)
    trainer = CallbackTrainer(wrapped_model,
                              self.optimizer,
                              num_epochs=2,
                              callbacks=self.default_callbacks(iterator=iterator),
                              cuda_device=[0, 1])
    metrics = trainer.train()

    assert 'peak_cpu_memory_MB' in metrics
    assert isinstance(metrics['peak_cpu_memory_MB'], float)
    assert metrics['peak_cpu_memory_MB'] > 0
    for gpu_key in ('peak_gpu_0_memory_MB', 'peak_gpu_1_memory_MB'):
        assert gpu_key in metrics
        assert isinstance(metrics[gpu_key], int)