Ejemplo n.º 1
0
    def testDatasetSizeNoneValidation(self, split):
        """Checks batch sizes when `validation_percent` is set to zero.

        For the VALIDATION split, constructing the dataset builder is expected
        to raise `ValueError`. For every other split, the first loaded batch
        must have the requested batch size as the leading dimension of each
        listed feature.

        Args:
          split: The split under test; compared against `tfds.Split` members.
        """
        builder_kwargs = dict(
            validation_percent=0,
            test_percent=0.2,
            shuffle_buffer_size=20)

        if split == tfds.Split.VALIDATION:
            # The constructor itself is the call expected to fail here.
            with self.assertRaises(ValueError):
                movielens.MovieLensDataset(split, **builder_kwargs)
            return

        # TRAIN uses a batch of 9; any other (non-validation) split uses 5.
        expected_size = 9 if split == tfds.Split.TRAIN else 5

        builder = movielens.MovieLensDataset(split, **builder_kwargs)
        single_batch = builder.load(batch_size=expected_size).take(1)
        first_batch = next(iter(single_batch))

        feature_names = (
            'timestamp',
            'movie_id',
            'movie_title',
            'user_id',
            'user_gender',
            'bucketized_user_age',
            'user_occupation_label',
            'user_occupation_text',
            'user_zip_code',
            'labels',
        )
        for feature_name in feature_names:
            tensor = first_batch[feature_name]
            self.assertEqual(tensor.shape[0], expected_size)
    def testDatasetSizeNoneValidation(self, split):
        """Checks batch sizes when `validation_percent` is set to zero.

        The builder is created with `validation_percent=0`. Building the VAL
        split is expected to raise `ValueError`. For other splits, the first
        batch must have `batch_size` rows for TRAIN and `eval_batch_size` rows
        otherwise, checked on `movie_id`, `user_id` and `labels`.

        Args:
          split: The split under test; compared against `base.Split` members.
        """
        train_batch_size, eval_batch_size = 9, 5

        builder = movielens.MovieLensDataset(
            batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            validation_percent=0,
            test_percent=0.2,
            shuffle_buffer_size=20)

        if split == base.Split.VAL:
            # Both the build and the take stay inside the raising context,
            # matching where the failure is allowed to originate.
            with self.assertRaises(ValueError):
                builder.build(split).take(1)
            return

        first_batch = next(iter(builder.build(split).take(1)))

        # Look up every feature before asserting so that a missing key
        # surfaces ahead of any size mismatch.
        checked_features = [
            first_batch[key] for key in ('movie_id', 'user_id', 'labels')
        ]

        if split == base.Split.TRAIN:
            expected_size = train_batch_size
        else:
            expected_size = eval_batch_size

        for tensor in checked_features:
            self.assertEqual(tensor.shape[0], expected_size)
    def testDatasetSize(self, split):
        """Checks the leading dimension of every feature in the first batch.

        The builder is created with a 10% validation and 20% test fraction.
        The first batch of the requested split must have `batch_size` rows
        for TRAIN and `eval_batch_size` rows for any other split.

        Args:
          split: The split under test; compared against `base.Split.TRAIN`.
        """
        train_batch_size, eval_batch_size = 9, 5

        builder = movielens.MovieLensDataset(
            batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            validation_percent=0.1,
            test_percent=0.2,
            shuffle_buffer_size=20)
        first_batch = next(iter(builder.build(split).take(1)))

        # The expected size depends only on the split, so compute it once.
        if split == base.Split.TRAIN:
            expected_size = train_batch_size
        else:
            expected_size = eval_batch_size

        feature_names = (
            'timestamp',
            'movie_id',
            'movie_title',
            'user_id',
            'user_gender',
            'bucketized_user_age',
            'user_occupation_label',
            'user_occupation_text',
            'user_zip_code',
            'labels',
        )
        for feature_name in feature_names:
            tensor = first_batch[feature_name]
            self.assertEqual(tensor.shape[0], expected_size)