コード例 #1
0
def test_all_unpaired_dataset_providers_should_honor_excludes(
        dataset_provider_cls_name, patched_excluded):
    """Unpaired providers must never serve elements with excluded labels."""
    provider_cls = from_class_name(dataset_provider_cls_name)
    data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(data_provider,
                       DatasetType.TRAIN,
                       with_excludes=False,
                       encoding=False,
                       paired=False)

    provider = provider_cls(data_provider)

    dataset = provider.supply_dataset(spec, batch_size=1).take(100)
    seen_values = set()
    for features, labels in dataset:
        # Undo the [-0.5, 0.5] normalization so the pixel value maps back
        # onto the label it encodes in the curated fake data.
        image_value = np.rint(
            (features[consts.FEATURES].numpy().flatten()[0] + 0.5) * 10)
        label_value = labels[consts.LABELS].numpy().flatten()[0]
        seen_values.add(image_value)
        seen_values.add(label_value)

    rounded = np.rint(list(seen_values))
    assert_that(rounded,
                only_contains(not_(is_in(list(patched_excluded.numpy())))))
    assert_that(rounded, only_contains(is_in([0, 1, 4])))
コード例 #2
0
def test_all_paired_dataset_providers_should_provide_correct_labels(
        dataset_provider_cls_name):
    """Pair label must be 1 exactly when both images (and labels) match."""
    provider_cls = from_class_name(dataset_provider_cls_name)
    data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(data_provider,
                       DatasetType.TEST,
                       with_excludes=False,
                       encoding=False)
    provider = provider_cls(data_provider)

    dataset = provider.supply_dataset(spec, batch_size=1).take(100)
    for features, labels in dataset:
        # Undo the [-0.5, 0.5] normalization to recover label-valued pixels.
        def first_pixel(key):
            return np.rint((features[key].numpy().flatten()[0] + 0.5) * 10)

        left_img = first_pixel(consts.LEFT_FEATURE_IMAGE)
        right_img = first_pixel(consts.RIGHT_FEATURE_IMAGE)
        pair_label = labels[consts.TFRECORD_PAIR_LABEL].numpy().flatten()[0]
        left_label = labels[consts.TFRECORD_LEFT_LABEL].numpy().flatten()[0]
        right_label = labels[consts.TFRECORD_RIGHT_LABEL].numpy().flatten()[0]

        assert pair_label == (1 if left_img == right_img else 0)
        assert pair_label == (1 if left_label == right_label else 0)
コード例 #3
0
 def eval_with_excludes_input_fn(self) -> Dataset:
     """Build the TEST dataset while keeping normally-excluded elements."""
     utils.log('Creating eval_input_fn with excluded elements')
     spec = DatasetSpec(raw_data_provider=self.raw_data_provider,
                        type=DatasetType.TEST,
                        with_excludes=True,
                        encoding=self.is_encoded())
     return self.supply_dataset(dataset_spec=spec,
                                batch_size=self.calculate_batch_size())
コード例 #4
0
 def eval_input_fn(self) -> Dataset:
     """Build the standard TEST dataset (excluded elements filtered out)."""
     utils.log('Creating eval_input_fn')
     spec = DatasetSpec(raw_data_provider=self.raw_data_provider,
                        type=DatasetType.TEST,
                        with_excludes=False,
                        encoding=self.is_encoded())
     return self.supply_dataset(dataset_spec=spec,
                                batch_size=self.calculate_batch_size())
コード例 #5
0
 def train_input_fn(self) -> Dataset:
     """Build the TRAIN dataset, shuffled, repeated, optionally paired."""
     utils.log('Creating train_input_fn')
     spec = DatasetSpec(raw_data_provider=self.raw_data_provider,
                        type=DatasetType.TRAIN,
                        with_excludes=False,
                        encoding=self.is_encoded(),
                        paired=self.is_train_paired())
     # Repeat indefinitely: the estimator decides how many steps to train.
     return self.supply_dataset(
         dataset_spec=spec,
         shuffle_buffer_size=config[consts.SHUFFLE_BUFFER_SIZE],
         batch_size=config[consts.BATCH_SIZE],
         repeat=True)
コード例 #6
0
 def infer(self, take_num: int) -> Dataset:
     """Build a single-batch inference dataset of `take_num` TEST elements.

     Excluded elements are kept, and the batch size equals `take_num` so
     one batch covers the whole take.
     """
     utils.log('Creating infer_fn')
     spec = DatasetSpec(raw_data_provider=self.raw_data_provider,
                        type=DatasetType.TEST,
                        with_excludes=True,
                        encoding=self.is_encoded())
     return self.supply_dataset(
         dataset_spec=spec,
         batch_size=take_num,
         repeat=False,
         shuffle_buffer_size=config[consts.SHUFFLE_BUFFER_SIZE],
         prefetch=False,
         take_num=take_num)
コード例 #7
0
def test_all_unpaired_dataset_providers_should_provide_correct_labels(
        dataset_provider_cls_name):
    """In the curated fake data every image's pixel value encodes its label."""
    provider_cls = from_class_name(dataset_provider_cls_name)
    data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(data_provider,
                       DatasetType.TRAIN,
                       with_excludes=False,
                       encoding=False,
                       paired=False)
    provider = provider_cls(data_provider)

    dataset = provider.supply_dataset(spec, batch_size=1).take(100)
    for features, labels in dataset:
        # Undo the [-0.5, 0.5] normalization to recover the label-valued pixel.
        image_value = np.rint(
            (features[consts.FEATURES].numpy().flatten()[0] + 0.5) * 10)
        label_value = labels[consts.LABELS].numpy().flatten()[0]

        assert image_value == label_value
コード例 #8
0
def dataset_spec(
    description: DataDescription = None,
    raw_dataset_fragment: RawDatasetFragment = None,
    type=DatasetType.TRAIN,
    with_excludes=False,
    encoding=True,
    paired=True,
    repeating_pairs=True,
    identical_pairs=False,
):
    """Build a DatasetSpec backed by a curated FakeRawDataProvider."""
    # Imported lazily to keep test helpers out of production import paths.
    from testing_utils.testing_classes import FakeRawDataProvider
    fake_provider = FakeRawDataProvider(description,
                                        raw_dataset_fragment,
                                        curated=True)
    return DatasetSpec(raw_data_provider=fake_provider,
                       type=type,
                       with_excludes=with_excludes,
                       encoding=encoding,
                       paired=paired,
                       repeating_pairs=repeating_pairs,
                       identical_pairs=identical_pairs)
コード例 #9
0
def test_train_input_fn_should_correct_configure_dataset(
        mocker, dataset_provider, paired):
    """train_input_fn must forward the expected spec and batching options."""
    supply_mock = mocker.patch.object(dataset_provider,
                                      'supply_dataset',
                                      autospec=True)

    dataset_provider.train_input_fn()

    expected_spec = DatasetSpec(raw_data_provider,
                                DatasetType.TRAIN,
                                with_excludes=False,
                                encoding=True,
                                repeating_pairs=True,
                                identical_pairs=False,
                                paired=paired)
    supply_mock.assert_called_once_with(
        expected_spec,
        shuffle_buffer_size=config[consts.SHUFFLE_BUFFER_SIZE],
        batch_size=config[consts.BATCH_SIZE],
        repeat=True)
コード例 #10
0
def test_test_eval_with_excludes_input_fn_should_correct_configure_dataset(
        mocker, dataset_provider):
    """eval_with_excludes_input_fn must keep excludes and adapt batch size."""
    supply_mock = mocker.patch.object(dataset_provider,
                                      'supply_dataset',
                                      autospec=True)

    dataset_provider.eval_with_excludes_input_fn()

    # NOTE(review): '/' produces a float batch size on the unpaired path —
    # presumably supply_dataset tolerates it; confirm it shouldn't be '//'.
    if dataset_provider.is_train_paired():
        expected_batch_size = config[consts.BATCH_SIZE]
    else:
        expected_batch_size = config[consts.BATCH_SIZE] / 2
    expected_spec = DatasetSpec(raw_data_provider,
                                DatasetType.TEST,
                                with_excludes=True,
                                encoding=True,
                                repeating_pairs=True,
                                identical_pairs=False,
                                paired=True)
    supply_mock.assert_called_once_with(expected_spec,
                                        batch_size=expected_batch_size)
コード例 #11
0
File: image_summaries.py  Project: arozans/idenface
def create_pair_summaries(run_data: RunData):
    """Write a TensorBoard image summary of sample TEST-set image pairs.

    Builds a fresh TF1 graph, draws one shuffled batch of 10 pairs from the
    model's dataset provider, renders them via `create_pair_summary`, and
    writes the merged summary under the run's `features` log directory.

    Fixes over the original: the local previously named `dataset_provider_cls`
    actually holds a raw-data-provider *instance* (renamed accordingly);
    `dir` shadowed the builtin; the `iterator` name was reused for the
    `get_next()` tensor; an unused `image_summary` binding is dropped; and
    the FileWriter is now closed.
    """
    raw_data_provider = run_data.model.raw_data_provider
    tf.reset_default_graph()
    batch_size = 10
    utils.log('Creating {} sample features summaries'.format(batch_size))
    dataset: tf.data.Dataset = run_data.model.dataset_provider.supply_dataset(
        dataset_spec=DatasetSpec(
            raw_data_provider,
            DatasetType.TEST,
            with_excludes=False,
            encoding=run_data.model.dataset_provider.is_encoded()),
        shuffle_buffer_size=10000,
        batch_size=batch_size,
        prefetch=False)
    next_batch = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        left = next_batch[0][consts.LEFT_FEATURE_IMAGE]
        right = next_batch[0][consts.RIGHT_FEATURE_IMAGE]
        pair_labels = next_batch[1][consts.PAIR_LABEL]
        left_labels = next_batch[1][consts.LEFT_FEATURE_LABEL]
        right_labels = next_batch[1][consts.RIGHT_FEATURE_LABEL]
        pairs_imgs_summary = create_pair_summary(
            left, right, pair_labels, left_labels, right_labels,
            raw_data_provider.description)

        # tf.summary.image registers the op in the default graph collection;
        # the returned tensor itself is not needed (merge_all picks it up).
        tf.summary.image('paired_images',
                         pairs_imgs_summary,
                         max_outputs=batch_size)
        all_summaries = tf.summary.merge_all()

        summary_dir = filenames.get_run_logs_data_dir(run_data) / 'features'
        summary_dir.mkdir(exist_ok=True, parents=True)
        writer = tf.summary.FileWriter(str(summary_dir), sess.graph)

        sess.run(tf.global_variables_initializer())

        summary = sess.run(all_summaries)
        writer.add_summary(summary)
        writer.flush()
        writer.close()
コード例 #12
0
def test_all_paired_dataset_providers_should_honor_excludes(
        dataset_provider_cls_name, patched_excluded):
    """Paired providers must never serve images whose labels are excluded."""
    provider_cls = from_class_name(dataset_provider_cls_name)
    data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(data_provider,
                       DatasetType.TEST,
                       with_excludes=False,
                       encoding=False)
    provider = provider_cls(data_provider)

    dataset = provider.supply_dataset(spec, batch_size=1).take(100)
    seen_pixels = set()
    for batch in dataset:
        # Shift pixels out of the [-0.5, 0.5] normalization range; the
        # curated fake data encodes each image's label in its pixel value.
        left_pixel = batch[0][
            consts.LEFT_FEATURE_IMAGE].numpy().flatten()[0] + 0.5
        right_pixel = batch[0][
            consts.RIGHT_FEATURE_IMAGE].numpy().flatten()[0] + 0.5
        seen_pixels.add(left_pixel)
        seen_pixels.add(right_pixel)

    recovered_labels = np.rint(list(seen_pixels)) * 10
    assert_that(recovered_labels,
                only_contains(not_(is_in(list(patched_excluded.numpy())))))
    assert_that(recovered_labels, only_contains(is_in([0, 1, 4])))
コード例 #13
0
def test_eval_with_excludes_input_fn_should_search_for_dataset_with_correct_spec(
        patched_dataset_building):
    """Must request the TEST split with excluded elements kept."""
    default_provider.eval_with_excludes_input_fn()

    expected_spec = DatasetSpec(raw_data_provider,
                                DatasetType.TEST,
                                with_excludes=True)
    patched_dataset_building.assert_called_once_with(expected_spec)
コード例 #14
0
def test_train_input_fn_should_search_for_dataset_with_correct_spec(
        patched_dataset_building):
    """Must request the TRAIN split with excluded elements filtered out."""
    default_provider.train_input_fn()

    expected_spec = DatasetSpec(raw_data_provider,
                                DatasetType.TRAIN,
                                with_excludes=False)
    patched_dataset_building.assert_called_once_with(expected_spec)
コード例 #15
0
File: conftest.py  Project: arozans/idenface
@pytest.fixture
def number_translation_features_and_labels(number_translation_features_dict):
    """Flatten a {label: [features...]} dict into parallel feature/label lists."""
    pairs = [(label, feature)
             for label, feature_group in number_translation_features_dict.items()
             for feature in feature_group]
    labels, features = zip(*pairs)
    return list(features), list(labels)


class NumberTranslationRawDataProvider(AbstractRawDataProvider):
    """Raw data provider serving the generated number-translation dataset."""

    # noinspection PyTypeChecker
    @property
    def description(self) -> DataDescription:
        return DataDescription(TestDatasetVariant.NUMBERTRANSLATION, None, 3)

    def get_raw_train(self) -> Tuple[np.ndarray, np.ndarray]:
        # Train and test share the same generated translation data.
        raw_dict = number_translation_features_dict()
        return number_translation_features_and_labels(raw_dict)

    def get_raw_test(self) -> Tuple[np.ndarray, np.ndarray]:
        raw_dict = number_translation_features_dict()
        return number_translation_features_and_labels(raw_dict)


# Default TRAIN spec for the number-translation dataset; with_excludes=False
# means excluded elements are filtered out.
TRANSLATIONS_TRAIN_DATASET_SPEC = DatasetSpec(
    raw_data_provider=NumberTranslationRawDataProvider(),
    type=DatasetType.TRAIN,
    with_excludes=False)
コード例 #16
0
    def _dataset_provider_cls(self) -> Type[AbstractDatasetProvider]:
        # This model variant builds its dataset from an in-memory generator.
        return FromGeneratorDatasetProvider

    @property
    def summary(self) -> str:
        """Model name suffixed to mark the generator-backed dataset variant."""
        return self.name + '_generated_dataset'


class MnistSoftmaxModelWithTfRecordDataset(MnistSoftmaxModel):
    """MNIST softmax model variant that reads its data from TFRecord files."""

    @property
    def summary(self) -> str:
        # Suffix distinguishes this run's artifacts from the generator variant.
        return self.name + '_TFRecord_dataset'


# Canonical DatasetSpec fixtures. Naming convention: *_IGNORING_EXCLUDES
# sets with_excludes=True (excluded elements kept); the plain variants set
# with_excludes=False (excluded elements filtered out).
MNIST_TRAIN_DATASET_SPEC_IGNORING_EXCLUDES = DatasetSpec(
    raw_data_provider=MnistRawDataProvider(),
    type=DatasetType.TRAIN,
    with_excludes=True)
MNIST_TRAIN_DATASET_SPEC = DatasetSpec(
    raw_data_provider=MnistRawDataProvider(),
    type=DatasetType.TRAIN,
    with_excludes=False)
MNIST_TEST_DATASET_SPEC = DatasetSpec(raw_data_provider=MnistRawDataProvider(),
                                      type=DatasetType.TEST,
                                      with_excludes=False)
MNIST_TEST_DATASET_SPEC_IGNORING_EXCLUDES = DatasetSpec(
    raw_data_provider=MnistRawDataProvider(),
    type=DatasetType.TEST,
    with_excludes=True)
# Fake-data counterpart used by tests that don't need real MNIST.
FAKE_TRAIN_DATASET_SPEC = DatasetSpec(raw_data_provider=FakeRawDataProvider(),
                                      type=DatasetType.TRAIN,
                                      with_excludes=False)
コード例 #17
0
 def __init__(self, dataset_spec: DatasetSpec):
     """Initialize from the spec and cache whether raw data needs resizing."""
     super().__init__(dataset_spec)
     # Decided once up front; presumably consulted when building the input
     # pipeline — confirm against the consuming code.
     self.resizing = dataset_spec.should_resize_raw_data()