Exemplo n.º 1
0
 def get_loaders(self, stage: str):
     """Build the data loaders for a stage.

     Args:
         stage: Name of the stage being run. It is not consulted here;
             the same single validation loader is returned for every stage.

     Returns:
         A dict mapping ``"valid"`` to a ``DataLoader`` over a freshly
         constructed ``TwoBlobsDataset``.
     """
     dataset = TwoBlobsDataset()
     # Shuffling is delegated to the sampler. DataLoader's ``shuffle``
     # option is mutually exclusive with ``sampler``, so it is left unset.
     sampler = CustomDistributedSampler(dataset=dataset, shuffle=True)
     loader = DataLoader(
         dataset,
         batch_size=_BATCH_SIZE,
         num_workers=_WORKERS,
         sampler=sampler,
     )
     return {"valid": loader}
Exemplo n.º 2
0
class MyConfigRunner(SupervisedConfigRunner):
    """Config-driven supervised runner that validates on ``TwoBlobsDataset``."""

    # Built once when the class body executes and shared by every instance,
    # so each call to ``get_datasets`` hands back the very same object.
    _dataset = TwoBlobsDataset()

    def get_datasets(self, *args, **kwargs):
        """Return the datasets keyed by loader name.

        All positional and keyword arguments are accepted and ignored.
        The only entry is ``"valid"``, which maps to the class-level shared
        dataset.
        """
        datasets = {"valid": self._dataset}
        return datasets

    def handle_batch(self, batch):
        """Run the model on one batch and publish the results on ``self.batch``.

        ``batch`` is unpacked as a two-element ``(features, targets)`` pair.
        The targets are flattened with ``.view(-1)``; presumably they are
        torch tensors (TODO confirm against the dataset).
        """
        features, targets = batch
        model_output = self.model(features)
        flat_targets = targets.view(-1)
        self.batch = {
            "features": features,
            "targets": flat_targets,
            "logits": model_output,
        }