Example 1
 def test_synthetic_random_data(self):
     # A synthetic random dataset with a single 32x32 sample across 5
     # classes should come back as a tf.data.Dataset of size 1.
     ds = DataGenerator.get_random_dataset(height=32,
                                           width=32,
                                           num_classes=5,
                                           data_type=tf.float32)
     assert DataGenerator.evaluate_size_dataset(ds) == 1
     assert isinstance(ds, tf.data.Dataset)
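
For context, a generator with the behavior this test expects could look roughly like the sketch below. This is only an illustration under assumed shapes (one random RGB image plus one integer label); it is not the project's actual DataGenerator implementation:

 import tensorflow as tf

 def get_random_dataset(height, width, num_classes, data_type):
     # One random image and one random class label, wrapped so that the
     # resulting tf.data.Dataset yields exactly one element.
     image = tf.random.uniform((1, height, width, 3), dtype=data_type)
     label = tf.random.uniform((1,), maxval=num_classes, dtype=tf.int32)
     return tf.data.Dataset.from_tensor_slices((image, label))
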
Example 2
 def test_check_size(self):
     # A dataset built from three tensor slices should report a size of 3.
     ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
     assert DataGenerator.evaluate_size_dataset(ds) == 3
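
Both tests exercise evaluate_size_dataset. A minimal sketch of a helper with that contract, assuming eager execution, might look like the following (the real helper may differ):

 import tensorflow as tf

 def evaluate_size_dataset(ds):
     # Use the statically known cardinality when TensorFlow can infer it;
     # a negative value means UNKNOWN (-2) or INFINITE (-1) cardinality,
     # in which case fall back to counting elements by iteration.
     size = int(tf.data.experimental.cardinality(ds).numpy())
     if size >= 0:
         return size
     return sum(1 for _ in ds)
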
Example 3
    # `parser` is an argparse.ArgumentParser built earlier in the script.
    args = parser.parse_args()
    logging.info(f'args = {args}')

    dataset_name = args.dataset
    dataset_path = args.dataset_path

    # Load the raw train split (no preprocessing) and check that the size
    # reported by the factory matches the number of elements actually
    # yielded.
    ds_train, _, ds_size, _, _ = DatasetFactory.get_dataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        split='train',
        img_datatype=float,
        micro_batch_size=1,
        accelerator_side_preprocess=False,
        apply_preprocessing=False)

    train_split_match = DataGenerator.evaluate_size_dataset(
        ds_train) == ds_size

    # Repeat the same size check for the test split.
    ds_valid, _, ds_valid_size, _, _ = DatasetFactory.get_dataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        split='test',
        img_datatype=float,
        micro_batch_size=1,
        accelerator_side_preprocess=False,
        apply_preprocessing=False)

    test_split_match = DataGenerator.evaluate_size_dataset(
        ds_valid) == ds_valid_size

    if not train_split_match:
        logging.warning(
            'Train split size does not match the size reported by '
            'DatasetFactory.get_dataset.')
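
The excerpt computes test_split_match the same way, so the full script presumably logs an analogous warning when the test split sizes disagree.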