def test_fit_generator_with_validation(self): # declare a batch generator train, test = train_test_split(self.data, test_size=.2) train_bg = BatchGenerator(train, x_structure=('x1', None), y_structure=('y1', self.y1_enc)) test_bg = BatchGenerator(test, x_structure=('x1', None), y_structure=('y1', self.y1_enc)) inp = keras.layers.Input(shape=(1, )) x = keras.layers.Dense(4, activation='softmax')(inp) keras_model = keras.models.Model(inp, x) keras_model.compile('adam', 'sparse_categorical_crossentropy') keras_model.fit_generator(train_bg, epochs=1, validation_data=test_bg)
def test_mro(self): mro = BatchGenerator.mro() i1 = [i for i, x in enumerate(mro) if x == BaseBatchGenerator] i2 = [i for i, x in enumerate(mro) if x == Sequence] assert len(i1) == 1 assert len(i2) == 1 # check that BaseBatch generator is resolved before SequenceTF assert i1 < i2
def test_predict_without_y(self): keras_bg = BatchGenerator(self.data, x_structure=('x1', None)) inp = keras.layers.Input(shape=(1, )) x = keras.layers.Dense(4, activation='softmax')(inp) keras_model = keras.models.Model(inp, x) keras_model.compile('adam', 'sparse_categorical_crossentropy') pred = keras_model.predict_generator(keras_bg, verbose=1) assert type(pred) == np.ndarray
def test_fit_generator_multiple_epochs(self): # declare a batch generator keras_bg = BatchGenerator(self.data, x_structure=('x1', None), y_structure=('y1', self.y1_enc)) inp = keras.layers.Input(shape=(1, )) x = keras.layers.Dense(4, activation='softmax')(inp) keras_model = keras.models.Model(inp, x) keras_model.compile('adam', 'sparse_categorical_crossentropy') keras_model.fit_generator(keras_bg, epochs=2)
def test_multiple_inputs(self): keras_bg = BatchGenerator(self.data, x_structure=[('x1', None), ('x2', None)], y_structure=('y1', self.y1_enc)) inp1 = keras.layers.Input(shape=(1, )) inp2 = keras.layers.Input(shape=(1, )) x = keras.layers.Concatenate()([inp1, inp2]) x = keras.layers.Dense(4, activation='softmax')(x) keras_model = keras.models.Model([inp1, inp2], x) keras_model.compile('adam', 'sparse_categorical_crossentropy') keras_model.fit_generator(keras_bg, epochs=2)