def test_shape_dataset_creation(shapes_image_size, shapes_nb_labels):
    """Check the state of a freshly instantiated Shapes dataset

    The dataset must expose the image size it was built with, report the
    expected number of labels and contain no image yet.
    """
    dataset = ShapeDataset(shapes_image_size)
    assert dataset.image_size == shapes_image_size
    assert dataset.get_nb_labels() == shapes_nb_labels
    # Nothing has been loaded or populated at this point.
    assert dataset.get_nb_images() == 0
def test_shape_dataset_loading(
    shapes_image_size, shapes_nb_images, shapes_nb_labels, shapes_sample_config
):
    """Check label and image counts after loading a Shapes dataset

    The dataset is fed with the ``shapes_sample_config`` fixture through
    ``load``; it must then report the expected numbers of labels and images.
    """
    dataset = ShapeDataset(shapes_image_size)
    dataset.load(shapes_sample_config)
    assert dataset.get_nb_labels() == shapes_nb_labels
    assert dataset.get_nb_images() == shapes_nb_images
def test_shape_dataset_population(
    shapes_image_size,
    shapes_nb_images,
    shapes_nb_labels,
    shapes_config,
    shapes_temp_dir,
):
    """Check a Shapes dataset after populating and saving it

    The dataset is populated into the temporary directory, then saved to the
    configuration path. The test verifies the label and image counts, that the
    configuration path is an existing file, and that the "images" and "labels"
    sub-folders each contain one entry per requested image.
    """
    output_dir = str(shapes_temp_dir)
    config_path = str(shapes_config)

    dataset = ShapeDataset(shapes_image_size)
    dataset.populate(output_dir, nb_images=shapes_nb_images)
    dataset.save(config_path)

    assert dataset.get_nb_labels() == shapes_nb_labels
    assert dataset.get_nb_images() == shapes_nb_images
    assert os.path.isfile(config_path)

    # Each output sub-folder must hold exactly one entry per image.
    for subfolder in ["images", "labels"]:
        entries = os.listdir(os.path.join(output_dir, subfolder))
        assert len(entries) == shapes_nb_images