def test_connect_loops_recursive():
    """Test Trainer references in a nested loop assigned to a Trainer.

    Before the root loop is assigned to a Trainer, reading ``trainer`` on the
    root or on a connected child raises ``RuntimeError``. After the root is
    assigned as ``trainer.fit_loop``, both children report that Trainer.

    NOTE(review): another ``def test_connect_loops_recursive`` appears later in
    this file and rebinds the name, so pytest only collects that later
    definition. This copy previously asserted ``trainer is None``, which
    contradicts the ``RuntimeError`` contract used there; it has been aligned
    with it. Consider deleting the duplicate.
    """
    main_loop = NestedLoop()
    child0 = NestedLoop()
    child1 = NestedLoop()
    main_loop.connect(child0, child1)

    # Unattached loops raise instead of returning None.
    with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"):
        _ = main_loop.trainer
    with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"):
        _ = main_loop.child_loop0.trainer

    trainer = Trainer()
    trainer.fit_loop = main_loop

    # Assigning the root loop must propagate the reference down the tree.
    assert child0.trainer is trainer
    assert child1.trainer is trainer
def test_connect_loops_recursive():
    """Test Trainer references in a nested loop assigned to a Trainer.

    A loop tree that is not yet assigned to a Trainer must raise
    ``RuntimeError`` when ``trainer`` is read, both at the root and at each
    connected child. Assigning the root as ``trainer.fit_loop`` must then make
    the root and both children report that same Trainer instance.
    """
    main_loop = NestedLoop()
    child0 = NestedLoop()
    child1 = NestedLoop()
    main_loop.connect(child0, child1)

    # Nothing is attached yet: every loop in the tree should refuse access.
    with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"):
        _ = main_loop.trainer
    with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"):
        _ = main_loop.child_loop0.trainer
    # The second child was previously unchecked before attachment.
    with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"):
        _ = child1.trainer

    trainer = Trainer()
    trainer.fit_loop = main_loop

    # The root itself must hold the reference, not only its children.
    assert main_loop.trainer is trainer
    assert child0.trainer is trainer
    assert child1.trainer is trainer
self.val_acc(logits, y) self.log("val_acc", self.val_acc) self.log("val_loss", loss) ############################################################################################# # Step 5 / 5: Connect the KFoldLoop to the Trainer # # After creating the `KFoldDataModule` and our model, the `KFoldLoop` is being connected to # # the Trainer. # # Finally, use `trainer.fit` to start the cross validation training. # ############################################################################################# if __name__ == "__main__": seed_everything(42) model = LitImageClassifier() datamodule = MNISTKFoldDataModule() trainer = Trainer( max_epochs=10, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, num_sanity_val_steps=0, devices=2, accelerator="auto", strategy="ddp", ) internal_fit_loop = trainer.fit_loop trainer.fit_loop = KFoldLoop(5, export_path="./") trainer.fit_loop.connect(internal_fit_loop) trainer.fit(model, datamodule)