def test_all_unpaired_dataset_providers_should_honor_excludes(
        dataset_provider_cls_name, patched_excluded):
    """Unpaired providers must never emit excluded labels; only 0, 1, 4 remain."""
    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(raw_data_provider,
                       DatasetType.TRAIN,
                       with_excludes=False,
                       encoding=False,
                       paired=False)
    provider = provider_cls(raw_data_provider)
    dataset = provider.supply_dataset(spec, batch_size=1).take(100)

    seen = set()
    for features, labels in dataset:
        # The fake images encode their label in pixel values: undo the
        # -0.5 shift and 1/10 scaling to recover the label.
        image_value = np.rint(
            (features[consts.FEATURES].numpy().flatten()[0] + 0.5) * 10)
        label_value = labels[consts.LABELS].numpy().flatten()[0]
        seen.add(image_value)
        seen.add(label_value)

    rounded = np.rint(list(seen))
    assert_that(rounded,
                only_contains(not_(is_in(list(patched_excluded.numpy())))))
    assert_that(rounded, only_contains(is_in([0, 1, 4])))
def test_all_paired_dataset_providers_should_provide_correct_labels(
        dataset_provider_cls_name):
    """Pair label must be 1 exactly when both images (and both labels) match."""
    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(raw_data_provider,
                       DatasetType.TEST,
                       with_excludes=False,
                       encoding=False)
    provider = provider_cls(raw_data_provider)
    dataset = provider.supply_dataset(spec, batch_size=1).take(100)

    def _decode(features, key):
        # Undo the -0.5 shift and 1/10 scaling the fake images carry.
        return np.rint((features[key].numpy().flatten()[0] + 0.5) * 10)

    for features, labels in dataset:
        left_img = _decode(features, consts.LEFT_FEATURE_IMAGE)
        right_img = _decode(features, consts.RIGHT_FEATURE_IMAGE)
        pair_label = labels[consts.TFRECORD_PAIR_LABEL].numpy().flatten()[0]
        left_label = labels[consts.TFRECORD_LEFT_LABEL].numpy().flatten()[0]
        right_label = labels[consts.TFRECORD_RIGHT_LABEL].numpy().flatten()[0]

        assert pair_label == (1 if left_img == right_img else 0)
        assert pair_label == (1 if left_label == right_label else 0)
def eval_with_excludes_input_fn(self) -> Dataset:
    """Evaluation input over the TEST split that keeps normally-excluded elements."""
    utils.log('Creating eval_input_fn with excluded elements')
    spec = DatasetSpec(raw_data_provider=self.raw_data_provider,
                       type=DatasetType.TEST,
                       with_excludes=True,
                       encoding=self.is_encoded())
    return self.supply_dataset(dataset_spec=spec,
                               batch_size=self.calculate_batch_size())
def eval_input_fn(self) -> Dataset:
    """Standard evaluation input over the TEST split (excludes honored)."""
    utils.log('Creating eval_input_fn')
    spec = DatasetSpec(raw_data_provider=self.raw_data_provider,
                       type=DatasetType.TEST,
                       with_excludes=False,
                       encoding=self.is_encoded())
    return self.supply_dataset(dataset_spec=spec,
                               batch_size=self.calculate_batch_size())
def train_input_fn(self) -> Dataset:
    """Training input: shuffled, batched, repeating TRAIN split (excludes honored)."""
    utils.log('Creating train_input_fn')
    spec = DatasetSpec(raw_data_provider=self.raw_data_provider,
                       type=DatasetType.TRAIN,
                       with_excludes=False,
                       encoding=self.is_encoded(),
                       paired=self.is_train_paired())
    return self.supply_dataset(
        dataset_spec=spec,
        shuffle_buffer_size=config[consts.SHUFFLE_BUFFER_SIZE],
        batch_size=config[consts.BATCH_SIZE],
        repeat=True)
def infer(self, take_num: int) -> Dataset:
    """Inference input: exactly `take_num` shuffled TEST examples, excludes kept.

    Batch size equals `take_num` so the whole sample arrives as one batch;
    the dataset does not repeat and prefetching is disabled.
    """
    utils.log('Creating infer_fn')
    spec = DatasetSpec(raw_data_provider=self.raw_data_provider,
                       type=DatasetType.TEST,
                       with_excludes=True,
                       encoding=self.is_encoded())
    return self.supply_dataset(
        dataset_spec=spec,
        batch_size=take_num,
        repeat=False,
        shuffle_buffer_size=config[consts.SHUFFLE_BUFFER_SIZE],
        prefetch=False,
        take_num=take_num)
def test_all_unpaired_dataset_providers_should_provide_correct_labels(
        dataset_provider_cls_name):
    """Each unpaired image must decode to exactly its own label."""
    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(raw_data_provider,
                       DatasetType.TRAIN,
                       with_excludes=False,
                       encoding=False,
                       paired=False)
    provider = provider_cls(raw_data_provider)
    dataset = provider.supply_dataset(spec, batch_size=1).take(100)

    for features, labels in dataset:
        # Undo the -0.5 shift and 1/10 scaling the fake images carry.
        image_value = np.rint(
            (features[consts.FEATURES].numpy().flatten()[0] + 0.5) * 10)
        label_value = labels[consts.LABELS].numpy().flatten()[0]
        assert image_value == label_value
def dataset_spec(
        description: DataDescription = None,
        raw_dataset_fragment: RawDatasetFragment = None,
        type=DatasetType.TRAIN,
        with_excludes=False,
        encoding=True,
        paired=True,
        repeating_pairs=True,
        identical_pairs=False,
):
    """Build a DatasetSpec backed by a curated FakeRawDataProvider."""
    # Imported lazily so test helpers stay out of production import paths.
    from testing_utils.testing_classes import FakeRawDataProvider
    provider = FakeRawDataProvider(description,
                                   raw_dataset_fragment,
                                   curated=True)
    return DatasetSpec(raw_data_provider=provider,
                       type=type,
                       with_excludes=with_excludes,
                       encoding=encoding,
                       paired=paired,
                       repeating_pairs=repeating_pairs,
                       identical_pairs=identical_pairs)
def test_train_input_fn_should_correct_configure_dataset(
        mocker, dataset_provider, paired):
    """train_input_fn must forward a TRAIN spec with shuffle/batch/repeat settings."""
    supply_mock = mocker.patch.object(dataset_provider,
                                      'supply_dataset',
                                      autospec=True)
    dataset_provider.train_input_fn()
    expected_spec = DatasetSpec(raw_data_provider,
                                DatasetType.TRAIN,
                                with_excludes=False,
                                encoding=True,
                                repeating_pairs=True,
                                identical_pairs=False,
                                paired=paired)
    supply_mock.assert_called_once_with(
        expected_spec,
        shuffle_buffer_size=config[consts.SHUFFLE_BUFFER_SIZE],
        batch_size=config[consts.BATCH_SIZE],
        repeat=True)
def test_test_eval_with_excludes_input_fn_should_correct_configure_dataset(
        mocker, dataset_provider):
    """eval_with_excludes_input_fn must request a paired TEST spec with excludes."""
    supply_mock = mocker.patch.object(dataset_provider,
                                      'supply_dataset',
                                      autospec=True)
    dataset_provider.eval_with_excludes_input_fn()
    # NOTE(review): presumably an unpaired-training provider halves the batch
    # because a pair consumes two images — confirm against implementation.
    expected_batch_size = (config[consts.BATCH_SIZE]
                           if dataset_provider.is_train_paired() else
                           config[consts.BATCH_SIZE] / 2)
    supply_mock.assert_called_once_with(
        DatasetSpec(raw_data_provider,
                    DatasetType.TEST,
                    with_excludes=True,
                    encoding=True,
                    repeating_pairs=True,
                    identical_pairs=False,
                    paired=True),
        batch_size=expected_batch_size)
def create_pair_summaries(run_data: RunData):
    """Write a TensorBoard image summary of sample left/right pairs.

    Builds a one-shot iterator over the model's TEST dataset (excludes
    honored), renders `batch_size` image pairs with their labels via
    `create_pair_summary`, and writes the merged summary into the run's
    `features` log directory.

    Args:
        run_data: run descriptor providing the model, its dataset provider
            and the log-directory location.
    """
    # The original bound this to `dataset_provider_cls`, but it is a provider
    # *instance*, not a class — renamed for accuracy.
    raw_data_provider = run_data.model.raw_data_provider
    tf.reset_default_graph()
    batch_size = 10
    utils.log('Creating {} sample features summaries'.format(batch_size))
    dataset: tf.data.Dataset = run_data.model.dataset_provider.supply_dataset(
        dataset_spec=DatasetSpec(
            raw_data_provider,
            DatasetType.TEST,
            with_excludes=False,
            encoding=run_data.model.dataset_provider.is_encoded()),
        shuffle_buffer_size=10000,
        batch_size=batch_size,
        prefetch=False)
    # One fetched batch: (features_dict, labels_dict) tensors.
    batch = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        left = batch[0][consts.LEFT_FEATURE_IMAGE]
        right = batch[0][consts.RIGHT_FEATURE_IMAGE]
        pair_labels = batch[1][consts.PAIR_LABEL]
        left_labels = batch[1][consts.LEFT_FEATURE_LABEL]
        right_labels = batch[1][consts.RIGHT_FEATURE_LABEL]
        pairs_imgs_summary = create_pair_summary(left, right, pair_labels,
                                                 left_labels, right_labels,
                                                 raw_data_provider.description)
        # Registered in the graph's summary collection; merge_all picks it up.
        image_summary = tf.summary.image('paired_images',
                                         pairs_imgs_summary,
                                         max_outputs=batch_size)
        all_summaries = tf.summary.merge_all()
        # Renamed from `dir`, which shadowed the builtin.
        log_dir = filenames.get_run_logs_data_dir(run_data) / 'features'
        log_dir.mkdir(exist_ok=True, parents=True)
        writer = tf.summary.FileWriter(str(log_dir), sess.graph)
        sess.run(tf.global_variables_initializer())
        summary = sess.run(all_summaries)
        writer.add_summary(summary)
        writer.flush()
def test_all_paired_dataset_providers_should_honor_excludes(
        dataset_provider_cls_name, patched_excluded):
    """Paired providers must never emit images whose decoded label is excluded."""
    provider_cls = from_class_name(dataset_provider_cls_name)
    raw_data_provider = FakeRawDataProvider(curated=True)
    spec = DatasetSpec(raw_data_provider,
                       DatasetType.TEST,
                       with_excludes=False,
                       encoding=False)
    provider = provider_cls(raw_data_provider)
    dataset = provider.supply_dataset(spec, batch_size=1).take(100)

    seen = set()
    for features, _ in dataset:
        # These are pixel values (shifted by +0.5), not labels yet; the
        # scaling back to label space happens below.
        left_value = features[consts.LEFT_FEATURE_IMAGE].numpy().flatten()[0] + 0.5
        right_value = features[consts.RIGHT_FEATURE_IMAGE].numpy().flatten()[0] + 0.5
        seen.update((left_value, right_value))

    decoded = np.rint(list(seen)) * 10
    assert_that(decoded,
                only_contains(not_(is_in(list(patched_excluded.numpy())))))
    assert_that(decoded, only_contains(is_in([0, 1, 4])))
def test_eval_with_excludes_input_fn_should_search_for_dataset_with_correct_spec(
        patched_dataset_building):
    """The excludes-ignoring eval input must ask for a TEST spec with excludes on."""
    default_provider.eval_with_excludes_input_fn()
    expected_spec = DatasetSpec(raw_data_provider,
                                DatasetType.TEST,
                                with_excludes=True)
    patched_dataset_building.assert_called_once_with(expected_spec)
def test_train_input_fn_should_search_for_dataset_with_correct_spec(
        patched_dataset_building):
    """The train input must ask for a TRAIN spec with excludes honored."""
    default_provider.train_input_fn()
    expected_spec = DatasetSpec(raw_data_provider,
                                DatasetType.TRAIN,
                                with_excludes=False)
    patched_dataset_building.assert_called_once_with(expected_spec)
@pytest.fixture
def number_translation_features_and_labels(number_translation_features_dict):
    """Flatten a {label: [features, ...]} dict into parallel (features, labels) lists."""
    feature_and_label_pairs = []
    for key, value in number_translation_features_dict.items():
        for elem in value:
            feature_and_label_pairs.append((key, elem))
    # Unzip the (label, feature) pairs; note features/labels swap order here.
    labels, features = zip(*feature_and_label_pairs)
    return list(features), list(labels)


class NumberTranslationRawDataProvider(AbstractRawDataProvider):
    # Raw data provider for the number-translation test dataset; train and
    # test splits return identical data.

    # noinspection PyTypeChecker
    @property
    def description(self) -> DataDescription:
        return DataDescription(TestDatasetVariant.NUMBERTRANSLATION, None, 3)

    def get_raw_train(self) -> Tuple[np.ndarray, np.ndarray]:
        # NOTE(review): calls pytest fixture functions directly as plain
        # callables rather than via fixture injection — confirm this works
        # under the project's pytest version.
        return number_translation_features_and_labels(
            number_translation_features_dict())

    def get_raw_test(self) -> Tuple[np.ndarray, np.ndarray]:
        return number_translation_features_and_labels(
            number_translation_features_dict())


# Canonical TRAIN spec for the number-translation dataset (excludes honored).
TRANSLATIONS_TRAIN_DATASET_SPEC = DatasetSpec(
    raw_data_provider=NumberTranslationRawDataProvider(),
    type=DatasetType.TRAIN,
    with_excludes=False)
# NOTE(review): the first two defs below appear to be the interior of a class
# whose header lies outside this chunk — verify indentation against the file.
def _dataset_provider_cls(self) -> Type[AbstractDatasetProvider]:
    # This model variant is fed from an in-memory generator.
    return FromGeneratorDatasetProvider

@property
def summary(self) -> str:
    # Suffix distinguishes generator-fed runs in logs and summaries.
    return self.name + '_generated_dataset'


class MnistSoftmaxModelWithTfRecordDataset(MnistSoftmaxModel):
    # Same softmax model, but fed from TFRecord files.

    @property
    def summary(self) -> str:
        return self.name + '_TFRecord_dataset'


# Pre-built MNIST / fake dataset specs shared across tests: each combination
# of split (TRAIN/TEST) and excludes handling gets its own constant.
MNIST_TRAIN_DATASET_SPEC_IGNORING_EXCLUDES = DatasetSpec(
    raw_data_provider=MnistRawDataProvider(),
    type=DatasetType.TRAIN,
    with_excludes=True)
MNIST_TRAIN_DATASET_SPEC = DatasetSpec(
    raw_data_provider=MnistRawDataProvider(),
    type=DatasetType.TRAIN,
    with_excludes=False)
MNIST_TEST_DATASET_SPEC = DatasetSpec(raw_data_provider=MnistRawDataProvider(),
                                      type=DatasetType.TEST,
                                      with_excludes=False)
MNIST_TEST_DATASET_SPEC_IGNORING_EXCLUDES = DatasetSpec(
    raw_data_provider=MnistRawDataProvider(),
    type=DatasetType.TEST,
    with_excludes=True)
FAKE_TRAIN_DATASET_SPEC = DatasetSpec(raw_data_provider=FakeRawDataProvider(),
                                      type=DatasetType.TRAIN,
                                      with_excludes=False)
def __init__(self, dataset_spec: DatasetSpec):
    """Initialize from `dataset_spec` and cache whether raw data needs resizing."""
    super().__init__(dataset_spec)
    # Whether raw images should be resized before use, per the spec.
    self.resizing = dataset_spec.should_resize_raw_data()