Example 1
0
    def setUp(self):
        """Create the fixtures shared by the data adapter tests.

        Most fixtures describe 50 samples of 10 features (targets of
        length 50). Each is stored in a different container:
        NumPy arrays, tf tensors, ``DummyArrayLike`` wrappers, a batched
        and shuffled ``tf.data.Dataset``, a plain Python generator, a
        generator wrapped by ``data_utils.threadsafe_generator``, and a
        ``TestSequence``. Small nested text/bytes inputs and a one-layer
        ``Sequential`` model are built as well.

        The generators yield batches forever, so consumers must bound how
        many batches they draw.
        """
        super(DataAdapterTestBase, self).setUp()
        num_samples, num_features = 50, 10
        self.batch_size = 5

        # In-memory arrays: NumPy, tf tensors, and array-like wrappers.
        self.numpy_input = np.zeros((num_samples, num_features))
        self.numpy_target = np.ones(num_samples)
        self.tensor_input = tf.constant(2.0,
                                        shape=(num_samples, num_features))
        self.tensor_target = tf.ones((num_samples,))
        self.arraylike_input = DummyArrayLike(self.numpy_input)
        self.arraylike_target = DummyArrayLike(self.numpy_target)

        # (features, targets) pairs, shuffled over the full sample count
        # and then batched.
        sliced_pairs = tf.data.Dataset.from_tensor_slices(
            (self.numpy_input, self.numpy_target))
        shuffled_pairs = sliced_pairs.shuffle(num_samples)
        self.dataset_input = shuffled_pairs.batch(self.batch_size)

        def generator():
            # Never terminates. self.batch_size is read lazily on each
            # yield, not captured once.
            while True:
                features = np.zeros((self.batch_size, num_features))
                targets = np.ones(self.batch_size)
                yield (features, targets)

        self.generator_input = generator()
        make_threadsafe_iter = data_utils.threadsafe_generator(generator)
        self.iterator_input = make_threadsafe_iter()
        self.sequence_input = TestSequence(batch_size=self.batch_size,
                                           feature_shape=num_features)

        self.text_input = [['abc']]
        self.bytes_input = [[b'abc']]

        output_layer = keras.layers.Dense(8,
                                          input_shape=(num_features,),
                                          activation='softmax')
        self.model = keras.models.Sequential([output_layer])
Example 2
0
        batch_index = i * batch_size % num_samples
        i += 1
        start = batch_index
        end = start + cur_batch_size
        x = arr_data[start:end]
        y = arr_labels[start:end]
        w = arr_weights[start:end]
        if mode == 1:
            yield x
        elif mode == 2:
            yield x, y
        else:
            yield x, y, w


# ``custom_generator`` wrapped by ``data_utils.threadsafe_generator``;
# presumably the wrapper serializes access so several worker threads can
# draw from the same generator safely — confirm against data_utils.
custom_generator_threads = data_utils.threadsafe_generator(custom_generator)


class TestGeneratorMethods(keras_parameterized.TestCase):
    @keras_parameterized.run_with_all_model_types
    @keras_parameterized.run_all_keras_modes
    @data_utils.dont_use_multiprocessing_pool
    def test_fit_generator_method(self):
        model = testing_utils.get_small_mlp(num_hidden=3,
                                            num_classes=4,
                                            input_dim=2)
        model.compile(loss='mse',
                      optimizer=rmsprop.RMSprop(1e-3),
                      metrics=['mae',
                               metrics_module.CategoricalAccuracy()])