def predict_original_samples(self, batch, conv_type, output):
    """Upsample the network output back onto the original (raw) point clouds.

    Arguments:
        batch -- processed batch; must expose ``sampleid`` and the
            ``SaveOriginalPosId.KEY`` attribute for every sample
        conv_type -- type of convolution (DENSE, PARTIAL_DENSE, etc...);
            the DENSE check is case-insensitive
        output -- output predicted by the model

    Returns:
        dict mapping each sample's filename (looked up in the first test
        dataset) to a numpy array built by horizontally stacking the raw
        point positions with a column of predicted labels.
    """
    full_res_results = {}
    num_sample = BaseDataset.get_num_samples(batch, conv_type)
    # For dense models, reshape the output so that sample b can be indexed
    # by get_sample(..., b, ...). Compare case-insensitively so "dense" and
    # "DENSE" behave the same, as BaseDataset's helpers already do.
    if conv_type is not None and conv_type.lower() == "dense":
        output = output.reshape(num_sample, -1, output.shape[-1])  # [B,N,L]

    setattr(batch, "_pred", output)
    # Raw data and filenames are always resolved through the first test dataset.
    test_dataset = self.test_dataset[0]
    for b in range(num_sample):
        sampleid = batch.sampleid[b]
        sample_raw_pos = test_dataset.get_raw(sampleid).pos.to(output.device)
        predicted = BaseDataset.get_sample(batch, "_pred", b, conv_type)
        origindid = BaseDataset.get_sample(batch, SaveOriginalPosId.KEY, b, conv_type)
        # Interpolate from the points the model saw (their raw positions are
        # selected by the saved original ids) onto every raw point, k=3.
        full_prediction = knn_interpolate(
            predicted, sample_raw_pos[origindid], sample_raw_pos, k=3
        )
        # Index of the max along dim 1, kept as a column vector.
        labels = full_prediction.max(1)[1].unsqueeze(-1)
        full_res_results[test_dataset.get_filename(sampleid)] = np.hstack(
            (sample_raw_pos.cpu().numpy(), labels.cpu().numpy())
        )
    return full_res_results
def test_multiple_test_datasets(self):
    """A dataset with two test splits builds one test loader per split."""
    opt = Options()
    opt.dataset_name = os.path.join(os.getcwd(), "test")
    opt.dataroot = os.path.join(os.getcwd(), "test")

    class MultiTestDataset(BaseDataset):
        def __init__(self, dataset_opt):
            super().__init__(dataset_opt)
            # NOTE(review): CustomMockDataset args look like
            # (num_points, feature_dim, num_classes, length), which is
            # consistent with the assertions below; confirm against its definition.
            self.train_dataset = CustomMockDataset(10, 1, 3, 10)
            self.val_dataset = CustomMockDataset(10, 1, 3, 10)
            self.test_dataset = [
                CustomMockDataset(10, 1, 3, 10),
                CustomMockDataset(10, 1, 3, 20),
            ]

    dataset = MultiTestDataset(opt)
    model_config = MockModelConfig()
    model_config.conv_type = "dense"
    model = MockModel(model_config)
    # Second positional argument is the batch size (the repr below reports 5).
    dataset.create_dataloaders(model, 5, True, 0, False)

    loaders = dataset.test_dataloaders
    self.assertEqual(len(loaders), 2)
    self.assertEqual(len(loaders[0].dataset), 10)
    self.assertEqual(len(loaders[1].dataset), 20)
    self.assertEqual(dataset.num_classes, 3)
    self.assertEqual(dataset.is_hierarchical, False)
    self.assertEqual(dataset.has_fixed_points_transform, False)
    self.assertEqual(dataset.has_val_loader, True)
    self.assertIsNone(dataset.class_to_segments)
    self.assertEqual(dataset.feature_dimension, 1)

    batch = next(iter(loaders[0]))
    num_samples = BaseDataset.get_num_samples(batch, "dense")
    self.assertEqual(num_samples, 5)

    sample = BaseDataset.get_sample(batch, "pos", 1, "dense")
    self.assertEqual(sample.shape, (10, 3))
    sample = BaseDataset.get_sample(batch, "x", 1, "dense")
    self.assertEqual(sample.shape, (10, 1))

    # Batches per split at batch size 5: 10 -> 2, 20 -> 4.
    self.assertEqual(dataset.num_batches, {"train": 2, "val": 2, "test_0": 2, "test_1": 4})

    expected_repr = (
        "Dataset: MultiTestDataset \n"
        "\x1b[0;95mpre_transform \x1b[0m= None\n"
        "\x1b[0;95mtest_transform \x1b[0m= None\n"
        "\x1b[0;95mtrain_transform \x1b[0m= None\n"
        "\x1b[0;95mval_transform \x1b[0m= None\n"
        "\x1b[0;95minference_transform \x1b[0m= None\n"
        "Size of \x1b[0;95mtrain_dataset \x1b[0m= 10\n"
        "Size of \x1b[0;95mtest_dataset \x1b[0m= 10, 20\n"
        "Size of \x1b[0;95mval_dataset \x1b[0m= 10\n"
        "\x1b[0;95mBatch size =\x1b[0m 5"
    )
    self.assertEqual(repr(dataset), expected_repr)