def test_prepare_outputs_training_epoch_end_manual(batch_outputs, expected):
    """Check output conversion for `training_epoch_end` under manual optimization.

    The loop is expected to turn the nested per-batch output lists into the
    structure that the `training_epoch_end` hook currently expects when
    optimization is manual (``automatic=False``).

    Args:
        batch_outputs: Nested lists of step outputs fed to the converter
            (presumably supplied by a parametrize decorator not shown here).
        expected: The structure the converter should produce for them.
    """
    # The optimizer count does not matter for manual optimization, so any
    # placeholder works; -1 makes the "unused" intent explicit.
    ignored_num_optimizers = -1
    actual = TrainingEpochLoop._prepare_outputs_training_epoch_end(
        batch_outputs,
        automatic=False,
        num_optimizers=ignored_num_optimizers,
    )
    assert actual == expected
def test_prepare_outputs_training_epoch_end_automatic(num_optimizers, batch_outputs, expected):
    """Check output conversion for `training_epoch_end` under automatic optimization.

    The loop is expected to turn the nested per-batch output lists into the
    structure that the `training_epoch_end` hook currently expects when
    optimization is automatic (``automatic=True``). Unlike the manual case,
    the optimizer count is forwarded to the converter.

    Args:
        num_optimizers: Optimizer count passed through to the converter.
        batch_outputs: Nested lists of step outputs fed to the converter
            (presumably supplied by a parametrize decorator not shown here).
        expected: The structure the converter should produce for them.
    """
    actual = TrainingEpochLoop._prepare_outputs_training_epoch_end(
        batch_outputs,
        automatic=True,
        num_optimizers=num_optimizers,
    )
    assert actual == expected