Exemplo n.º 1
0
def test_all_paired_dataset_providers_should_provide_correct_labels(
        dataset_provider_cls_name):
    """Pair labels must agree with both the decoded images and the side labels.

    For up to 100 single-element TEST batches, the first value of each
    feature image is decoded to ``rint((value + 0.5) * 10)``. The pair label
    must be 1 exactly when the two decoded values match, and 1 exactly when
    the left and right labels match.
    """
    def _first(tensor):
        # batch_size=1, so only the first flattened element is inspected.
        return tensor.numpy().flatten()[0]

    def _decode_image(tensor):
        return np.rint((_first(tensor) + 0.5) * 10)

    provider_cls = from_class_name(dataset_provider_cls_name)
    fake_raw_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(fake_raw_provider,
                       DatasetType.TEST,
                       with_excludes=False,
                       encoding=False)
    provider = provider_cls(fake_raw_provider)

    for batch in provider.supply_dataset(spec, batch_size=1).take(100):
        features, targets = batch[0], batch[1]
        decoded_left = _decode_image(features[consts.LEFT_FEATURE_IMAGE])
        decoded_right = _decode_image(features[consts.RIGHT_FEATURE_IMAGE])
        pair = _first(targets[consts.TFRECORD_PAIR_LABEL])
        left = _first(targets[consts.TFRECORD_LEFT_LABEL])
        right = _first(targets[consts.TFRECORD_RIGHT_LABEL])

        images_match = 1 if decoded_left == decoded_right else 0
        labels_match = 1 if left == right else 0
        assert pair == images_match
        assert pair == labels_match
Exemplo n.º 2
0
def test_all_unpaired_dataset_providers_should_honor_excludes(
        dataset_provider_cls_name, patched_excluded):
    """Unpaired TRAIN data must only contain allowed, non-excluded classes.

    Both the decoded first image value (``rint((value + 0.5) * 10)``) and
    the label of each of up to 100 batches are collected. The rounded set
    must avoid every value in ``patched_excluded`` and contain only 0, 1
    or 4.
    """
    provider_cls = from_class_name(dataset_provider_cls_name)
    fake_raw_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(fake_raw_provider,
                       DatasetType.TRAIN,
                       with_excludes=False,
                       encoding=False,
                       paired=False)

    provider = provider_cls(fake_raw_provider)

    seen = set()
    for batch in provider.supply_dataset(spec, batch_size=1).take(100):
        features, targets = batch[0], batch[1]
        first_pixel = features[consts.FEATURES].numpy().flatten()[0]
        decoded_image = np.rint((first_pixel + 0.5) * 10)
        label = targets[consts.LABELS].numpy().flatten()[0]

        seen.add(decoded_image)
        seen.add(label)

    observed = np.rint(list(seen))
    excluded = list(patched_excluded.numpy())
    assert_that(observed, only_contains(not_(is_in(excluded))))
    assert_that(observed, only_contains(is_in([0, 1, 4])))
Exemplo n.º 3
0
def test_all_unpaired_dataset_providers_should_provide_correct_labels(
        dataset_provider_cls_name):
    """Each unpaired image must decode to its own label.

    For up to 100 single-element TRAIN batches, the first image value is
    decoded to ``rint((value + 0.5) * 10)`` and compared with the first
    label value.
    """
    fake_raw_provider = FakeRawDataProvider(curated=True)
    provider_cls = from_class_name(dataset_provider_cls_name)
    spec = DatasetSpec(fake_raw_provider,
                       DatasetType.TRAIN,
                       with_excludes=False,
                       encoding=False,
                       paired=False)
    provider = provider_cls(fake_raw_provider)

    batches = provider.supply_dataset(spec, batch_size=1).take(100)
    for batch in batches:
        raw_pixel = batch[0][consts.FEATURES].numpy().flatten()[0]
        expected_label = np.rint((raw_pixel + 0.5) * 10)
        actual_label = batch[1][consts.LABELS].numpy().flatten()[0]

        assert expected_label == actual_label
Exemplo n.º 4
0
def dataset_spec(
    description: DataDescription = None,
    raw_dataset_fragment: RawDatasetFragment = None,
    type=DatasetType.TRAIN,
    with_excludes=False,
    encoding=True,
    paired=True,
    repeating_pairs=True,
    identical_pairs=False,
):
    """Build a ``DatasetSpec`` backed by a curated ``FakeRawDataProvider``.

    ``description`` and ``raw_dataset_fragment`` are passed positionally to
    ``FakeRawDataProvider``. All remaining arguments are forwarded by keyword
    to ``DatasetSpec``.
    """
    # Imported inside the function rather than at module level; the reason
    # (e.g. an import cycle) is not visible here.
    from testing_utils.testing_classes import FakeRawDataProvider
    fake_provider = FakeRawDataProvider(description,
                                        raw_dataset_fragment,
                                        curated=True)
    return DatasetSpec(
        raw_data_provider=fake_provider,
        type=type,
        with_excludes=with_excludes,
        encoding=encoding,
        paired=paired,
        repeating_pairs=repeating_pairs,
        identical_pairs=identical_pairs,
    )
Exemplo n.º 5
0
def test_all_unpaired_dataset_providers_should_get_features_from_raw_data_provider(
        description, dataset_provider_cls_name):
    """The first unpaired TEST batch must match the raw provider's image dims.

    Images must have shape ``(batch_size, *image_dimensions)`` and labels
    must have shape ``(batch_size,)``.
    """
    # NOTE(review): unlike the sibling tests, this parameter is called
    # directly instead of being resolved with from_class_name -- confirm the
    # fixture feeding this test supplies a class, not a class-name string.
    fake_raw_provider = FakeRawDataProvider(curated=True,
                                            description=description)
    provider = dataset_provider_cls_name(fake_raw_provider)

    expected_image_dims = provider.raw_data_provider.description.image_dimensions
    batch_size = 12
    spec = gen.dataset_spec(description=description,
                            type=DatasetType.TEST,
                            with_excludes=False,
                            encoding=False,
                            paired=False)
    limited = provider.supply_dataset(spec, batch_size=batch_size).take(100)
    first_batch = tf_helpers.unpack_first_batch(limited)
    images, labels = first_batch

    assert images.shape == (batch_size, *expected_image_dims)

    assert labels.shape == (batch_size,)
Exemplo n.º 6
0
def test_all_paired_dataset_providers_should_honor_excludes(
        dataset_provider_cls_name, patched_excluded):
    """Paired TEST data must only contain images of allowed, non-excluded classes.

    The first value of every left/right feature image in up to 100 batches is
    decoded to ``rint((value + 0.5) * 10)``, the same decoding used by the
    other provider tests. The set of decoded values must avoid every value
    in ``patched_excluded`` and contain only 0, 1 or 4.
    """
    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    dataset_spec = DatasetSpec(raw_data_provider,
                               DatasetType.TEST,
                               with_excludes=False,
                               encoding=False)
    provider = provider_cls(raw_data_provider)

    dataset = provider.supply_dataset(dataset_spec, batch_size=1).take(100)
    encountered_labels = set()
    for batch in dataset:
        # Scale by 10 BEFORE rounding. The previous code rounded
        # (value + 0.5) first, which -- assuming images encode a class id as
        # id / 10 - 0.5, as the sibling tests' decoding implies -- collapsed
        # ids 0..4 to 0 and made the assertions below nearly vacuous.
        left_label = np.rint(
            (batch[0][consts.LEFT_FEATURE_IMAGE].numpy().flatten()[0] + 0.5) *
            10)
        right_label = np.rint(
            (batch[0][consts.RIGHT_FEATURE_IMAGE].numpy().flatten()[0] + 0.5) *
            10)
        encountered_labels.update((left_label, right_label))
    decoded_labels = list(encountered_labels)
    assert_that(decoded_labels,
                only_contains(not_(is_in(list(patched_excluded.numpy())))))
    assert_that(decoded_labels, only_contains(is_in([0, 1, 4])))
Exemplo n.º 7
0
from testing_utils import tf_helpers, gen
from testing_utils.testing_classes import FakeRawDataProvider, FAKE_DATA_DESCRIPTION, FAKE_MNIST_DESCRIPTION
from testing_utils.tf_helpers import run_eagerly


class FakeDatasetProvider(AbstractDatasetProvider):
    """Test double for ``AbstractDatasetProvider`` with a stubbed builder."""

    def build_dataset(self, dataset_spec: DatasetSpec) -> tf.data.Dataset:
        """Stub: builds nothing and returns ``None``."""
        return None


class FakeTrainUnpairedDatasetProvider(TFRecordTrainUnpairedDatasetProvider):
    """Test double for ``TFRecordTrainUnpairedDatasetProvider`` with a stubbed builder."""

    def build_dataset(self, dataset_spec: DatasetSpec) -> tf.data.Dataset:
        """Stub: builds nothing and returns ``None``."""
        return None


# A single fake raw-data source shared by both module-level provider
# instances below.
raw_data_provider = FakeRawDataProvider()
default_provider, train_unpaired_dataset_provider = (
    FakeDatasetProvider(raw_data_provider),
    FakeTrainUnpairedDatasetProvider(raw_data_provider),
)


@pytest.fixture()
def patched_dataset_building(mocker):
    """Replace ``default_provider.build_dataset`` with an autospecced mock.

    Returns the mock so tests can inspect how ``build_dataset`` was called.
    """
    build_dataset_mock = mocker.patch.object(
        default_provider, 'build_dataset', autospec=True)
    return build_dataset_mock


def test_train_input_fn_should_search_for_dataset_with_correct_spec(
        patched_dataset_building):
    """Call ``default_provider.train_input_fn()`` with ``build_dataset`` mocked.

    NOTE(review): as visible here, the test never inspects
    ``patched_dataset_building``, so it only verifies that
    ``train_input_fn()`` does not raise. The snippet may be truncated --
    confirm that an assertion on the ``DatasetSpec`` passed to
    ``build_dataset`` exists.
    """
    default_provider.train_input_fn()
Exemplo n.º 8
0
 def raw_data_provider(self) -> AbstractRawDataProvider:
     """Return a new ``FakeRawDataProvider`` built from ``self.description``."""
     fake_provider = FakeRawDataProvider(description=self.description)
     return fake_provider