Example #1
0
            def fit_generator(wrapper, *args: Any, **kwargs: Any) -> None:
                """Record a training config from a ``model.fit_generator`` call.

                The call's arguments are bound against the signature of
                ``model.fit_generator`` (with defaults applied). The training
                generator and the validation data are each wrapped in
                ``keras.SequenceAdapter`` using the call's ``workers`` and
                ``use_multiprocessing`` settings. The resulting
                ``TFKerasTrainConfig`` is stored on ``self.train_config``, and
                ``train_fn`` is invoked afterwards if it is set.

                Args:
                    wrapper: Not used here; presumably the object this function
                        is bound to in place of ``model.fit_generator`` --
                        TODO confirm against the enclosing code.
                    *args, **kwargs: Arguments as accepted by
                        ``model.fit_generator``.

                Raises:
                    errors.InvalidExperimentException: If ``.compile`` has not
                        been called first, or if ``validation_data`` was not
                        supplied.
                """
                if not self.compile_args:
                    raise errors.InvalidExperimentException(
                        "Must call .compile before calling .fit_generator().")

                fit_generator_args = inspect.signature(
                    model.fit_generator).bind(*args, **kwargs)
                fit_generator_args.apply_defaults()
                arguments = fit_generator_args.arguments

                # When the caller omits validation_data, apply_defaults() fills
                # in the signature default (None for Keras' fit_generator).
                # Fail early with a clear message rather than wrapping a
                # missing value in an adapter that breaks later.
                if arguments["validation_data"] is None:
                    raise errors.InvalidExperimentException(
                        "Must specify validation_data when calling "
                        ".fit_generator().")

                def _adapt(data: Any) -> Any:
                    # Training and validation data share the worker settings
                    # given in the fit_generator call.
                    return keras.SequenceAdapter(
                        data,
                        use_multiprocessing=arguments["use_multiprocessing"],
                        workers=arguments["workers"],
                    )

                self.train_config = TFKerasTrainConfig(
                    training_data=_adapt(arguments["generator"]),
                    validation_data=_adapt(arguments["validation_data"]),
                    callbacks=arguments["callbacks"],
                )

                if train_fn:
                    train_fn()
Example #2
0
def test_sequence_adapter(workers: int, use_multiprocessing: bool,
                          seq: Sequence) -> None:
    """Check that ``keras.SequenceAdapter`` reproduces every batch of ``seq``.

    ``seq`` is wrapped with the given ``workers`` / ``use_multiprocessing``
    settings. The adapter's length must match ``len(seq)``, and each batch
    from its iterator must match the batch at the same index of ``seq``,
    element by element.
    """
    data = keras.SequenceAdapter(seq,
                                 workers=workers,
                                 use_multiprocessing=use_multiprocessing)
    assert len(data) == len(seq)

    data.start()
    # Always stop the adapter, even when an assertion below fails, so that
    # anything started by start() is not left running after the test.
    try:
        iterator = data.get_iterator()
        assert iterator is not None

        for batch_idx in range(len(seq)):
            expected = seq[batch_idx]
            actual = next(iterator)
            assert len(expected) == len(actual)
            # np.array_equal also fails on shape mismatches, which
            # np.equal(...).all() could silently broadcast over.
            for expected_part, actual_part in zip(expected, actual):
                assert np.array_equal(expected_part, actual_part)
    finally:
        data.stop()
 def build_validation_data_loader(self) -> keras.InputData:
     """Build the validation loader from the XOR data sequences.

     ``make_xor_data_sequences`` is called with a batch size of 4. The second
     of its two results is wrapped in ``keras.SequenceAdapter`` with
     ``workers=0``.
     """
     xor_batch_size = 4
     _train_seq, validation_seq = make_xor_data_sequences(
         batch_size=xor_batch_size)
     adapter = keras.SequenceAdapter(validation_seq, workers=0)
     return adapter
 def build_training_data_loader(self) -> keras.InputData:
     """Build the training loader from the XOR data sequences.

     ``make_xor_data_sequences`` is called with a batch size of 4. The first
     of its two results is wrapped in ``keras.SequenceAdapter`` with
     ``workers=0``.
     """
     xor_batch_size = 4
     training_seq, _validation_seq = make_xor_data_sequences(
         batch_size=xor_batch_size)
     adapter = keras.SequenceAdapter(training_seq, workers=0)
     return adapter