Beispiel #1
0
    def testGeneratorSingleInputFn(self):
        """A single-key generator is batched into a column of values.

        The generator yields two rows for key 'a' (values 0 and 1). With
        ``batch_size=2`` and ``num_epochs=1`` the first fetch returns both
        rows, one further fetch is still expected to succeed, and the fetch
        after that must raise ``OutOfRangeError``.
        """
        def generator():
            for value in range(2):
                yield {'a': np.ones(1) * value}

        expected_a = np.asarray([0, 1]).reshape(-1, 1)

        with self.cached_session() as sess:
            features = generator_io.generator_input_fn(
                generator,
                target_key=None,
                batch_size=2,
                shuffle=False,
                num_epochs=1)()

            thread_coord = coordinator.Coordinator()
            runner_threads = queue_runner_impl.start_queue_runners(
                sess, coord=thread_coord)

            first_batch, = sess.run([features])
            self.assertAllEqual(first_batch['a'], expected_a)

            # The second fetch is only required not to fail; its contents are
            # not inspected. Only the third fetch should run past the input.
            sess.run([features])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run([features])

            thread_coord.request_stop()
            thread_coord.join(runner_threads)
Beispiel #2
0
    def testGeneratorInputFnWithBatchLargerthanData(self):
        """A batch larger than the data set repeats the generator's rows.

        The generator yields only two rows but ``batch_size=4``. The single
        batch fetched is expected to contain the two rows twice, and the next
        fetch must raise ``OutOfRangeError``.
        """
        def generator():
            for step in range(2):
                base = np.ones(1) * step
                yield {
                    'a': base,
                    'b': base + 32,
                    'label': base - 32
                }

        def as_column(values):
            # Shape expected values as an (N, 1) column to match the batch.
            return np.asarray(values).reshape(-1, 1)

        with self.cached_session() as sess:
            input_fn = generator_io.generator_input_fn(
                generator,
                target_key=None,
                batch_size=4,
                shuffle=False,
                num_epochs=1)
            features = input_fn()

            thread_coord = coordinator.Coordinator()
            runner_threads = queue_runner_impl.start_queue_runners(
                sess, coord=thread_coord)

            batch = sess.run(features)
            expected_columns = (
                ('a', as_column([0, 1, 0, 1])),
                ('b', as_column([32, 33, 32, 33])),
                ('label', as_column([-32, -31, -32, -31])),
            )
            for key, want in expected_columns:
                self.assertAllEqual(batch[key], want)

            with self.assertRaises(errors.OutOfRangeError):
                sess.run([features])

            thread_coord.request_stop()
            thread_coord.join(runner_threads)
Beispiel #3
0
    def testGeneratorInputFnWithMismatchinGeneratorKeys(self):
        """Rows with differing key sets make feeding fail.

        The first row has keys ('a', 'b', 'label') and the second has
        ('a', 'c', 'label'). Fetching a batch is expected to raise
        ``OutOfRangeError``. Stopping and joining the queue-runner threads is
        expected to raise a ``KeyError`` about the key mismatch.
        """
        def generator():
            # Only the middle key differs between the two rows.
            for index, middle_key in enumerate(('b', 'c')):
                yield {
                    'a': np.ones(1) * index,
                    middle_key: np.ones(1) * index + 32,
                    'label': np.ones(1) * index - 32
                }

        with self.cached_session() as sess:
            features = generator_io.generator_input_fn(
                generator,
                target_key=None,
                batch_size=2,
                shuffle=False,
                num_epochs=1)()

            thread_coord = coordinator.Coordinator()
            runner_threads = queue_runner_impl.start_queue_runners(
                sess, coord=thread_coord)

            with self.assertRaises(errors.OutOfRangeError):
                sess.run([features])

            # The KeyError is expected to surface from the coordinator's
            # stop/join, not from the fetch above.
            with self.assertRaisesRegex(
                    KeyError, 'key mismatch between dicts emitted'
                    ' by GenFunExpected'):
                thread_coord.request_stop()
                thread_coord.join(runner_threads)
Beispiel #4
0
 def testGeneratorInputFnWithXAsNonGeneratorFunction(self):
     """Passing a non-function (a NumPy array) as ``x`` raises ``TypeError``.

     The error message must match 'x must be generator function'.
     """
     x = np.arange(32, 36)
     with self.cached_session():
         # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
         # which was removed from unittest in Python 3.12.
         with self.assertRaisesRegex(TypeError,
                                     'x must be generator function'):
             failing_input_fn = generator_io.generator_input_fn(
                 x, batch_size=2, shuffle=False, num_epochs=1)
             failing_input_fn()
Beispiel #5
0
    def testGeneratorInputFnWithXAsNonGeneratorYieldingDicts(self):
        """A generator that yields a non-dict value raises ``TypeError``.

        The error message must match 'x() must yield dict'.
        """
        def generator():
            yield np.arange(32, 36)

        with self.cached_session():
            # assertRaisesRegex replaces the deprecated assertRaisesRegexp
            # alias, which was removed from unittest in Python 3.12.
            with self.assertRaisesRegex(TypeError, r'x\(\) must yield dict'):
                failing_input_fn = generator_io.generator_input_fn(
                    generator, batch_size=2, shuffle=False, num_epochs=1)
                failing_input_fn()
Beispiel #6
0
    def testGeneratorInputFNWithTargetLabelNotInDict(self):
        """A ``target_key`` naming a key the generator never yields fails.

        The generator yields keys 'a', 'b' and 'label'. The target list also
        names 'target', so a ``KeyError`` matching
        'target_key not in yielded dict' is expected.
        """
        def generator():
            for index in range(2):
                yield {
                    'a': np.ones((10, 10)) * index,
                    'b': np.ones((5, 5)) * index + 32,
                    'label': np.ones((3, 3)) * index - 32
                }

        y = ['label', 'target']
        with self.cached_session():
            # assertRaisesRegex replaces the deprecated assertRaisesRegexp
            # alias, which was removed from unittest in Python 3.12.
            with self.assertRaisesRegex(KeyError,
                                        'target_key not in yielded dict'):
                failing_input_fn = generator_io.generator_input_fn(
                    generator,
                    target_key=y,
                    batch_size=2,
                    shuffle=False,
                    num_epochs=1)
                failing_input_fn()
Beispiel #7
0
    def testGeneratorInputFNWithTargetLabelListNotString(self):
        """A ``target_key`` list with a non-string element raises ``TypeError``.

        The second element of the target list is a NumPy array. The error
        message must match 'target_key must be str or Container of str'.
        """
        def generator():
            for index in range(2):
                yield {
                    'a': np.ones((10, 10)) * index,
                    'b': np.ones((5, 5)) * index + 32,
                    'label': np.ones((3, 3)) * index - 32
                }

        y = ['label', np.arange(10)]
        with self.cached_session():
            # assertRaisesRegex replaces the deprecated assertRaisesRegexp
            # alias, which was removed from unittest in Python 3.12.
            with self.assertRaisesRegex(
                    TypeError, 'target_key must be str or'
                    ' Container of str'):
                failing_input_fn = generator_io.generator_input_fn(
                    generator,
                    target_key=y,
                    batch_size=2,
                    shuffle=False,
                    num_epochs=1)
                failing_input_fn()
Beispiel #8
0
    def testGeneratorInputFnWithDifferentDimensionsOfFeatures(self):
        """Features with different per-row shapes are batched independently.

        Each row carries a 10x10 'a', a 5x5 'b' and a 3x3 'label'. With
        ``target_key='label'`` the label comes back separately as the target.
        Only the first batch (rows with index 0 and 1) is checked.
        """
        def generator():
            for index in range(100):
                yield {
                    'a': np.ones((10, 10)) * index,
                    'b': np.ones((5, 5)) * index + 32,
                    'label': np.ones((3, 3)) * index - 32
                }

        def first_two_rows(side):
            # Rows 0 and 1 stacked: a (2, side, side) array, zeros then ones.
            return np.stack((np.zeros((side, side)), np.ones((side, side))))

        with self.cached_session() as sess:
            input_fn = generator_io.generator_input_fn(
                generator,
                target_key='label',
                batch_size=2,
                shuffle=False,
                num_epochs=1)
            features, target = input_fn()

            thread_coord = coordinator.Coordinator()
            runner_threads = queue_runner_impl.start_queue_runners(
                sess, coord=thread_coord)

            feature_values, target_values = sess.run([features, target])
            self.assertAllEqual(feature_values['a'], first_two_rows(10))
            self.assertAllEqual(feature_values['b'], first_two_rows(5) + 32)
            self.assertAllEqual(target_values, first_two_rows(3) - 32)

            thread_coord.request_stop()
            thread_coord.join(runner_threads)