コード例 #1
0
def test_connect_loops_recursive():
    """Check that a Trainer reference reaches every loop in a nested loop tree.

    Before the parent loop is attached, reading ``.trainer`` yields ``None``.
    Once the parent is installed as ``trainer.fit_loop``, both connected child
    loops must report that very same trainer object.
    """
    parent = NestedLoop()
    first_child, second_child = NestedLoop(), NestedLoop()
    parent.connect(first_child, second_child)

    # Nothing has been attached yet: neither the parent nor its first
    # connected child should know about a trainer.
    assert parent.trainer is None
    assert parent.child_loop0.trainer is None

    trainer = Trainer()
    trainer.fit_loop = parent

    # Attaching only the parent must be enough for the children to see it.
    for child in (first_child, second_child):
        assert child.trainer is trainer
コード例 #2
0
def test_connect_loops_recursive():
    """Check Trainer propagation through a nested loop tree.

    Reading ``.trainer`` on a loop that has not been attached must raise a
    ``RuntimeError``. After the parent loop is installed as
    ``trainer.fit_loop``, both connected child loops must resolve to that
    same trainer object.
    """
    parent = NestedLoop()
    first_child, second_child = NestedLoop(), NestedLoop()
    parent.connect(first_child, second_child)

    # Each read is wrapped in a zero-argument callable so that the attribute
    # lookup happens inside ``pytest.raises`` rather than before it.
    unattached_reads = (
        lambda: parent.trainer,
        lambda: parent.child_loop0.trainer,
    )
    for read_trainer in unattached_reads:
        with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"):
            read_trainer()

    trainer = Trainer()
    trainer.fit_loop = parent

    # Attaching only the parent must be enough for the children to see it.
    for child in (first_child, second_child):
        assert child.trainer is trainer
コード例 #3
0
        self.val_acc(logits, y)
        self.log("val_acc", self.val_acc)
        self.log("val_loss", loss)


#############################################################################################
#                           Step 5 / 5: Connect the KFoldLoop to the Trainer                #
# After creating the `KFoldDataModule` and the model, connect the `KFoldLoop` to the        #
# Trainer.                                                                                  #
# Finally, call `trainer.fit` to start the cross-validation training.                       #
#############################################################################################

if __name__ == "__main__":
    # Seed first so that model weight initialisation is reproducible.
    seed_everything(42)

    model = LitImageClassifier()
    datamodule = MNISTKFoldDataModule()

    trainer = Trainer(
        # Two DDP processes on whichever accelerator Lightning detects.
        accelerator="auto",
        devices=2,
        strategy="ddp",
        max_epochs=10,
        # Keep the demo run tiny: two batches per split and no sanity check.
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        num_sanity_val_steps=0,
    )

    # Remember the fit loop the Trainer built, install the KFoldLoop in its
    # place, then connect the original loop underneath the KFoldLoop.
    # NOTE(review): the first KFoldLoop argument (5) is presumably the number
    # of folds, and export_path presumably where per-fold outputs go — confirm
    # against the KFoldLoop definition.
    original_fit_loop = trainer.fit_loop
    trainer.fit_loop = KFoldLoop(5, export_path="./")
    trainer.fit_loop.connect(original_fit_loop)

    trainer.fit(model, datamodule)