Example 1
0
 def test_train_sanity_fit(self):
     """Smoke-test ``cnn.train_sanity_fit`` on every supported torchvision model.

     Each model is created with 10 output classes and ``pretrained=None``,
     then passed to ``cnn.train_sanity_fit`` together with ``train_loader``
     and a cross-entropy loss on ``"cpu"`` for ``num_batches=10``; the
     returned value must be truthy.
     """
     # NOTE(review): ``supported_tv_models`` and ``train_loader`` are not
     # defined in this method; presumably module-level fixtures — confirm.
     # No optimizer is built: ``train_sanity_fit`` is not given one here.
     for model_name in supported_tv_models:
         # subTest reports every failing architecture separately instead of
         # aborting the whole loop at the first failure.
         with self.subTest(model_name=model_name):
             model = cnn.create_vision_cnn(model_name, 10, pretrained=None)
             loss = nn.CrossEntropyLoss()
             res = cnn.train_sanity_fit(model,
                                        train_loader,
                                        loss,
                                        "cpu",
                                        num_batches=10)
             self.assertTrue(res)
Example 2
0
    def test_csv_single_label_dataset(self):
        """Check ``CSVSingleLabelDataset`` indexing and a sanity fit on it.

        Builds the dataset from ``df`` and ``data_dir`` with the ``"Image"`` /
        ``"Label"`` column names and ``"png"`` extension (argument meanings
        presumed from their names), checks that it is non-empty and that its
        first item is truthy, then runs ``cnn.train_sanity_fit`` with a
        2-class ``resnet18`` on ``"cpu"`` for a single batch.
        """
        # NOTE(review): ``df``, ``data_dir`` and ``tfms`` are not defined in
        # this method; presumably module-level fixtures — confirm.
        complete_dataset = CSVSingleLabelDataset(df, data_dir, "Image",
                                                 "Label", tfms, "png")
        # Fail with a clear assertion on an empty dataset instead of an
        # indexing error on the next line. DataLoader's default sampler
        # already needs ``__len__``, so this adds no new requirement.
        self.assertGreater(len(complete_dataset), 0)
        self.assertTrue(complete_dataset[0])

        train_loader = torch.utils.data.DataLoader(complete_dataset,
                                                   num_workers=1)
        model = cnn.create_cnn("resnet18", 2, pretrained=None)
        # No optimizer is built: ``train_sanity_fit`` is not given one here.
        loss = nn.CrossEntropyLoss()
        res = cnn.train_sanity_fit(model,
                                   train_loader,
                                   loss,
                                   "cpu",
                                   num_batches=1)
        self.assertTrue(res)