Example #1
0
 def test_train_step(self):
     """cnn.train_step returns a metrics mapping with loss/top-1/top-5 keys.

     Runs a short (10-batch, gradient-penalised) CPU training step for every
     supported torchvision model and checks the returned metrics dict.
     """
     for model_name in supported_tv_models:
         model = cnn.create_vision_cnn(model_name, 10, pretrained=None)
         opt = torch.optim.Adam(model.parameters(), lr=1e-3)
         loss = nn.CrossEntropyLoss()
         train_metrics = cnn.train_step(model,
                                        train_loader,
                                        loss,
                                        "cpu",
                                        opt,
                                        num_batches=10,
                                        grad_penalty=True)
         # isinstance() against typing.Dict is deprecated (and rejected for
         # subscripted generics); check against the concrete builtin type.
         self.assertIsInstance(train_metrics, dict)
         for exp_k in ("loss", "top1", "top5"):
             # assertIn yields a clearer failure message than
             # assertTrue(key in d.keys()).
             self.assertIn(exp_k, train_metrics)
    else:
        print("Model Created. Training on CPU only")
        device = "cpu"

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    criterion = (nn.CrossEntropyLoss()
                 )  # All classification problems we need Cross entropy loss
    early_stopper = utils.EarlyStopping(patience=7,
                                        verbose=True,
                                        path=SAVE_PATH)

    for epoch in tqdm(range(EPOCHS)):
        print()
        print(f"Training Epoch = {epoch}")
        train_metrics = cnn.train_step(model, train_loader, criterion, device,
                                       optimizer)
        print()

        print(f"Validating Epoch = {epoch}")
        valid_metrics = cnn.val_step(model, valid_loader, criterion, device)

        validation_loss = valid_metrics["loss"]
        early_stopper(validation_loss, model=model)

        if early_stopper.early_stop:
            print("Saving Model and Early Stopping")
            print("Early Stopping. Ran out of Patience for validation loss")
            break

        print("Done Training, Model Saved to Disk")