Exemplo n.º 1
0
    def load_dataloader(self) -> None:
        """Load data loaders for training and test.

        Uses ``npr.SequentialGPDataset`` (with ``seq_len=20``) when
        ``self.config.model == "snp"`` and ``npr.GPDataset`` otherwise.
        The train loader is built from ``self.config.train_dataset_params``
        (shuffled, batch size 16); the test loader from
        ``self.config.test_dataset_params`` (not shuffled, batch size 1).
        Results are stored on ``self.train_loader`` / ``self.test_loader``.
        """

        self.logger.info("Load dataset")

        if self.config.model == "snp":
            dataset_cls = npr.SequentialGPDataset
            # The sequential variant additionally takes a fixed sequence length.
            extra_kwargs = {"seq_len": 20}
        else:
            dataset_cls = npr.GPDataset
            extra_kwargs = {}

        self.train_loader = torch.utils.data.DataLoader(
            dataset_cls(train=True,
                        **extra_kwargs,
                        **self.config.train_dataset_params),
            shuffle=True,
            batch_size=16)

        self.test_loader = torch.utils.data.DataLoader(
            dataset_cls(train=False,
                        **extra_kwargs,
                        **self.config.test_dataset_params),
            shuffle=False,
            batch_size=1)

        # len(DataLoader) is the number of batches, not the number of samples;
        # the dataset size is len(loader.dataset).
        self.logger.info(
            f"Train dataset size: {len(self.train_loader.dataset)}")
        self.logger.info(
            f"Test dataset size: {len(self.test_loader.dataset)}")
Exemplo n.º 2
0
    def test_num_context(self):
        """Context size of train and test datasets is within ``num_context_max``."""
        for is_train in (True, False):
            ds = npr.GPDataset(train=is_train, **self.params)
            self.assertLessEqual(ds.num_context,
                                 self.params["num_context_max"])
Exemplo n.º 3
0
    def _small_case_target(self, train):
        """Check tensor shapes returned for a small batch of indices.

        After ``generate_dataset()``, indexing with three indices must give
        context tensors with at most ``num_context_max`` points and target
        tensors with exactly ``num_target_min`` points. The last axis must
        equal ``x_dim`` for inputs and ``y_dim`` for outputs.
        """
        batch = [0, 1, 2]
        dataset = npr.GPDataset(train=train, **self.params)

        p = self.params
        ctx_cap, tgt_len = p["num_context_max"], p["num_target_min"]
        x_dim, y_dim = p["x_dim"], p["y_dim"]

        dataset.generate_dataset()
        x_ctx, y_ctx, x_tgt, y_tgt = dataset[batch]

        n = len(batch)
        # (tensor, comparison for the point axis, bound, feature dim)
        checks = (
            (x_ctx, self.assertLessEqual, ctx_cap, x_dim),
            (y_ctx, self.assertLessEqual, ctx_cap, y_dim),
            (x_tgt, self.assertEqual, tgt_len, x_dim),
            (y_tgt, self.assertEqual, tgt_len, y_dim),
        )
        for tensor, check_points, points, feat_dim in checks:
            self.assertEqual(tensor.size(0), n)
            check_points(tensor.size(1), points)
            self.assertEqual(tensor.size(2), feat_dim)
Exemplo n.º 4
0
    def test_generate_with_resample_params(self):
        """Shapes stay within configured bounds after ``resample_params=True``."""
        dataset = npr.GPDataset(train=True, **self.params)
        dataset.generate_dataset(resample_params=True)

        batch = [0, 1, 2]
        x_ctx, y_ctx, x_tgt, y_tgt = dataset[batch]

        p = self.params
        ctx_cap, tgt_cap = p["num_context_max"], p["num_target_max"]
        x_dim, y_dim = p["x_dim"], p["y_dim"]

        # (tensor, max points on axis 1, feature dim on axis 2)
        expected = (
            (x_ctx, ctx_cap, x_dim),
            (y_ctx, ctx_cap, y_dim),
            (x_tgt, tgt_cap, x_dim),
            (y_tgt, tgt_cap, y_dim),
        )
        for tensor, max_points, feat_dim in expected:
            self.assertEqual(tensor.size(0), len(batch))
            self.assertLessEqual(tensor.size(1), max_points)
            self.assertEqual(tensor.size(2), feat_dim)
Exemplo n.º 5
0
 def test_len(self):
     """``len`` of a training GPDataset equals the configured ``batch_size``."""
     n_items = len(npr.GPDataset(train=True, **self.params))
     self.assertEqual(n_items, self.params["batch_size"])