def test_all_paired_dataset_providers_should_provide_correct_labels(
        dataset_provider_cls_name):
    """Check that each pair label agrees with the images and the side labels.

    The first flattened element of each feature image is decoded to an
    integer with ``rint((value + 0.5) * 10)``. For up to 100 single-example
    batches, the pair label must equal 1 exactly when the decoded left and
    right images are equal, and 1 exactly when the left and right labels are
    equal.
    """
    def _first(tensor):
        # Materialize the tensor and take its first element after flattening.
        return tensor.numpy().flatten()[0]

    def _decode(tensor):
        return np.rint((_first(tensor) + 0.5) * 10)

    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(raw_data_provider, DatasetType.TEST,
                       with_excludes=False, encoding=False)
    dataset = provider_cls(raw_data_provider).supply_dataset(
        spec, batch_size=1).take(100)

    for batch in dataset:
        features, targets = batch[0], batch[1]
        decoded_left = _decode(features[consts.LEFT_FEATURE_IMAGE])
        decoded_right = _decode(features[consts.RIGHT_FEATURE_IMAGE])
        pair = _first(targets[consts.TFRECORD_PAIR_LABEL])
        left = _first(targets[consts.TFRECORD_LEFT_LABEL])
        right = _first(targets[consts.TFRECORD_RIGHT_LABEL])
        assert pair == int(decoded_left == decoded_right)
        assert pair == int(left == right)
def test_all_unpaired_dataset_providers_should_honor_excludes(
        dataset_provider_cls_name, patched_excluded):
    """Check that unpaired providers never yield an excluded label.

    Up to 100 single-example TRAIN batches are read. Both the decoded image
    value ``rint((value + 0.5) * 10)`` and the label are collected. After
    rounding, none of them may appear in ``patched_excluded``, and all of them
    must lie in ``{0, 1, 4}``.
    """
    def _first(tensor):
        return tensor.numpy().flatten()[0]

    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(raw_data_provider, DatasetType.TRAIN,
                       with_excludes=False, encoding=False, paired=False)
    dataset = provider_cls(raw_data_provider).supply_dataset(
        spec, batch_size=1).take(100)

    seen = set()
    for batch in dataset:
        features, targets = batch[0], batch[1]
        seen.add(np.rint((_first(features[consts.FEATURES]) + 0.5) * 10))
        seen.add(_first(targets[consts.LABELS]))

    rounded = np.rint(list(seen))
    excluded = list(patched_excluded.numpy())
    assert_that(rounded, only_contains(not_(is_in(excluded))))
    assert_that(rounded, only_contains(is_in([0, 1, 4])))
def test_all_unpaired_dataset_providers_should_provide_correct_labels(
        dataset_provider_cls_name):
    """Check that each unpaired image decodes to its own label.

    For up to 100 single-example TRAIN batches, the first flattened image
    value decoded with ``rint((value + 0.5) * 10)`` must equal the first
    flattened label value.
    """
    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(raw_data_provider, DatasetType.TRAIN,
                       with_excludes=False, encoding=False, paired=False)
    provider = provider_cls(raw_data_provider)

    for features, targets in (
            (batch[0], batch[1])
            for batch in provider.supply_dataset(spec, batch_size=1).take(100)):
        raw_pixel = features[consts.FEATURES].numpy().flatten()[0]
        expected_label = targets[consts.LABELS].numpy().flatten()[0]
        assert np.rint((raw_pixel + 0.5) * 10) == expected_label
def dataset_spec(
        description: DataDescription = None,
        raw_dataset_fragment: RawDatasetFragment = None,
        type=DatasetType.TRAIN,
        with_excludes=False,
        encoding=True,
        paired=True,
        repeating_pairs=True,
        identical_pairs=False,
):
    """Build a ``DatasetSpec`` backed by a curated ``FakeRawDataProvider``.

    ``description`` and ``raw_dataset_fragment`` are forwarded positionally to
    ``FakeRawDataProvider`` (with ``curated=True``). The remaining arguments
    are forwarded unchanged to ``DatasetSpec`` as keyword arguments.
    """
    # Imported inside the function; presumably to avoid an import cycle
    # with testing_utils.testing_classes -- TODO confirm.
    from testing_utils.testing_classes import FakeRawDataProvider

    fake_provider = FakeRawDataProvider(description, raw_dataset_fragment,
                                        curated=True)
    spec_options = dict(
        type=type,
        with_excludes=with_excludes,
        encoding=encoding,
        paired=paired,
        repeating_pairs=repeating_pairs,
        identical_pairs=identical_pairs,
    )
    return DatasetSpec(raw_data_provider=fake_provider, **spec_options)
def test_all_unpaired_dataset_providers_should_get_features_from_raw_data_provider(
        description, dataset_provider_cls_name):
    """Check the shapes of the first unpaired TEST batch.

    Images must have shape ``(batch_size, *image_dimensions)`` taken from the
    provider's raw-data description. Labels must be a 1-D vector of length
    ``batch_size``.
    """
    batch_size = 12
    # NOTE(review): unlike sibling tests, this parameter is called directly as
    # a class instead of being resolved via ``from_class_name``. Presumably
    # this test is parametrized with classes; verify against its
    # parametrization.
    provider = dataset_provider_cls_name(
        FakeRawDataProvider(curated=True, description=description))
    image_dims = provider.raw_data_provider.description.image_dimensions

    spec = gen.dataset_spec(description=description, type=DatasetType.TEST,
                            with_excludes=False, encoding=False, paired=False)
    first_batch_source = provider.supply_dataset(
        spec, batch_size=batch_size).take(100)
    images, labels = tf_helpers.unpack_first_batch(first_batch_source)

    expected_image_shape = (batch_size, *image_dims)
    assert images.shape == expected_image_shape
    assert labels.shape == (batch_size,)
def test_all_paired_dataset_providers_should_honor_excludes(
        dataset_provider_cls_name, patched_excluded):
    """Check that paired providers never yield an image of an excluded label.

    Up to 100 single-example TEST batches are read. The first flattened
    element of both the left and the right feature image is decoded to an
    integer label with ``rint((value + 0.5) * 10)``. None of the decoded
    labels may appear in ``patched_excluded``, and all of them must lie in
    ``{0, 1, 4}``.
    """
    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    dataset_spec = DatasetSpec(raw_data_provider, DatasetType.TEST,
                               with_excludes=False, encoding=False)
    provider = provider_cls(raw_data_provider)
    dataset = provider.supply_dataset(dataset_spec, batch_size=1).take(100)

    encountered_labels = set()
    for batch in dataset:
        features = batch[0]
        for key in (consts.LEFT_FEATURE_IMAGE, consts.RIGHT_FEATURE_IMAGE):
            first_value = features[key].numpy().flatten()[0]
            # Scale *before* rounding. The previous rint(value + 0.5) * 10
            # rounded every value <= 0.5 to 0, which made both assertions
            # below vacuous.
            encountered_labels.add(np.rint((first_value + 0.5) * 10))

    decoded_labels = list(encountered_labels)
    assert_that(decoded_labels,
                only_contains(not_(is_in(list(patched_excluded.numpy())))))
    assert_that(decoded_labels, only_contains(is_in([0, 1, 4])))
from testing_utils import tf_helpers, gen
from testing_utils.testing_classes import FakeRawDataProvider, FAKE_DATA_DESCRIPTION, FAKE_MNIST_DESCRIPTION
from testing_utils.tf_helpers import run_eagerly


class FakeDatasetProvider(AbstractDatasetProvider):
    """Minimal concrete ``AbstractDatasetProvider`` with a no-op ``build_dataset``.

    The ``patched_dataset_building`` fixture below replaces ``build_dataset``
    on the module-level ``default_provider`` instance with a mock.
    """

    def build_dataset(self, dataset_spec: DatasetSpec) -> tf.data.Dataset:
        # Stub implementation: does nothing and returns None.
        pass


class FakeTrainUnpairedDatasetProvider(TFRecordTrainUnpairedDatasetProvider):
    """``TFRecordTrainUnpairedDatasetProvider`` with a no-op ``build_dataset``."""

    def build_dataset(self, dataset_spec: DatasetSpec) -> tf.data.Dataset:
        # Stub implementation: does nothing and returns None.
        pass


# Module-level instances, created at import time and shared by the tests in
# this module.
raw_data_provider = FakeRawDataProvider()
default_provider = FakeDatasetProvider(raw_data_provider)
train_unpaired_dataset_provider = FakeTrainUnpairedDatasetProvider(
    raw_data_provider)


@pytest.fixture()
def patched_dataset_building(mocker):
    """Patch ``default_provider.build_dataset`` with an autospec'd mock and return it."""
    return mocker.patch.object(default_provider, 'build_dataset', autospec=True)


def test_train_input_fn_should_search_for_dataset_with_correct_spec(
        patched_dataset_building):
    # NOTE(review): only this call is visible here; any assertions on
    # ``patched_dataset_building`` may follow outside this fragment.
    default_provider.train_input_fn()
def raw_data_provider(self) -> AbstractRawDataProvider:
    """Return a newly constructed ``FakeRawDataProvider`` for ``self.description``."""
    description = self.description
    fake_provider = FakeRawDataProvider(description=description)
    return fake_provider