def load_dataloader(self) -> None:
    """Build the training and test data loaders and log the dataset sizes.

    When ``self.config.model`` is ``"snp"`` the loaders wrap
    ``npr.SequentialGPDataset`` (constructed with ``seq_len=20``); for every
    other model they wrap ``npr.GPDataset``. Dataset keyword arguments are
    taken from ``self.config.train_dataset_params`` and
    ``self.config.test_dataset_params``.

    The training loader shuffles and uses ``batch_size=16``; the test loader
    keeps order and uses ``batch_size=1``. The results are stored on
    ``self.train_loader`` and ``self.test_loader``.
    """
    self.logger.info("Load dataset")

    # Only the dataset class and its extra constructor arguments depend on
    # the model type; the loader settings are shared by both cases.
    if self.config.model == "snp":
        dataset_cls = npr.SequentialGPDataset
        extra_kwargs = {"seq_len": 20}
    else:
        dataset_cls = npr.GPDataset
        extra_kwargs = {}

    self.train_loader = torch.utils.data.DataLoader(
        dataset_cls(
            train=True, **extra_kwargs, **self.config.train_dataset_params),
        shuffle=True, batch_size=16)
    self.test_loader = torch.utils.data.DataLoader(
        dataset_cls(
            train=False, **extra_kwargs, **self.config.test_dataset_params),
        shuffle=False, batch_size=1)

    # len(DataLoader) is the number of batches, not the number of samples,
    # so report the length of the wrapped dataset to match the message.
    self.logger.info(f"Train dataset size: {len(self.train_loader.dataset)}")
    self.logger.info(f"Test dataset size: {len(self.test_loader.dataset)}")
def test_num_context(self):
    """Check ``num_context`` never exceeds ``num_context_max``.

    The bound is verified for a dataset built with ``train=True`` and then
    for one built with ``train=False``.
    """
    for train_flag in (True, False):
        gp_data = npr.GPDataset(train=train_flag, **self.params)
        context_cap = self.params["num_context_max"]
        self.assertLessEqual(gp_data.num_context, context_cap)
def _small_case_target(self, train):
    """Check the sizes of a three-item batch taken from a generated dataset.

    Builds ``npr.GPDataset`` with the given ``train`` flag, generates the
    data, indexes it with ``[0, 1, 2]`` and asserts, for each returned
    tensor, that:

    * dimension 0 equals the number of requested indices,
    * dimension 1 is at most ``num_context_max`` for the context tensors
      and exactly ``num_target_min`` for the target tensors,
    * dimension 2 equals ``x_dim`` for x tensors and ``y_dim`` for y tensors.

    Args:
        train: Forwarded to ``npr.GPDataset`` as its ``train`` argument.
    """
    batch_idx = [0, 1, 2]
    gp_data = npr.GPDataset(train=train, **self.params)

    context_cap = self.params["num_context_max"]
    target_len = self.params["num_target_min"]
    dim_x = self.params["x_dim"]
    dim_y = self.params["y_dim"]

    gp_data.generate_dataset()
    x_ctx, y_ctx, x_tgt, y_tgt = gp_data[batch_idx]
    n_items = len(batch_idx)

    # Context tensors: the middle dimension is only bounded from above.
    for tensor, feature_dim in ((x_ctx, dim_x), (y_ctx, dim_y)):
        self.assertEqual(tensor.size(0), n_items)
        self.assertLessEqual(tensor.size(1), context_cap)
        self.assertEqual(tensor.size(2), feature_dim)

    # Target tensors: the middle dimension must match exactly.
    for tensor, feature_dim in ((x_tgt, dim_x), (y_tgt, dim_y)):
        self.assertEqual(tensor.size(0), n_items)
        self.assertEqual(tensor.size(1), target_len)
        self.assertEqual(tensor.size(2), feature_dim)
def test_generate_with_resample_params(self):
    """Check batch sizes after generating with ``resample_params=True``.

    A training ``npr.GPDataset`` is generated with ``resample_params=True``
    and indexed with ``[0, 1, 2]``. For each returned tensor the test
    asserts that dimension 0 equals the number of indices, dimension 1 is at
    most ``num_context_max`` (context) or ``num_target_max`` (target), and
    dimension 2 equals ``x_dim`` or ``y_dim`` respectively.
    """
    gp_data = npr.GPDataset(train=True, **self.params)
    gp_data.generate_dataset(resample_params=True)

    batch_idx = [0, 1, 2]
    x_ctx, y_ctx, x_tgt, y_tgt = gp_data[batch_idx]

    context_cap = self.params["num_context_max"]
    target_cap = self.params["num_target_max"]
    dim_x = self.params["x_dim"]
    dim_y = self.params["y_dim"]
    n_items = len(batch_idx)

    # (tensor, upper bound on dimension 1, expected dimension 2), listed in
    # the order the checks should run.
    expectations = (
        (x_ctx, context_cap, dim_x),
        (y_ctx, context_cap, dim_y),
        (x_tgt, target_cap, dim_x),
        (y_tgt, target_cap, dim_y),
    )
    for tensor, length_cap, feature_dim in expectations:
        self.assertEqual(tensor.size(0), n_items)
        self.assertLessEqual(tensor.size(1), length_cap)
        self.assertEqual(tensor.size(2), feature_dim)
def test_len(self):
    """Check that ``len(dataset)`` equals the configured ``batch_size``."""
    gp_data = npr.GPDataset(train=True, **self.params)
    actual_len = len(gp_data)
    self.assertEqual(actual_len, self.params["batch_size"])