def test_training(self, simple_model, loss, optimizer, simple_data):
    """Check that one full training epoch runs to completion.

    The simple model is parsed without a template engine and built, then an
    Executor wired with the given loss and optimizer trains it on
    ``simple_data`` until one epoch has elapsed.
    """
    simple_model.parse(None)
    simple_model.build()

    executor = Executor(
        model=simple_model,
        loss=loss,
        optimizer=optimizer,
    )
    one_epoch = {'epochs': 1}
    executor.train(provider=simple_data, stop_when=one_epoch)
def test_ctc_train(self, ctc_model, ctc_data, ctc_loss, optimizer):
    """Check that a model can be trained with the CTC loss.

    The CTC model is parsed (no template engine), given its data provider,
    and built; an Executor using ``ctc_loss`` then trains it for one epoch.
    """
    ctc_model.parse(None)
    ctc_model.register_provider(ctc_data)
    ctc_model.build()

    executor = Executor(
        model=ctc_model,
        loss=ctc_loss,
        optimizer=optimizer,
    )
    one_epoch = {'epochs': 1}
    executor.train(provider=ctc_data, stop_when=one_epoch)
def test_uber_train(self, uber_model, uber_data, jinja_engine, loss, optimizer):
    """Check that a diverse ("uber") model compiles and trains.

    The model is parsed through the Jinja engine, attached to its data
    provider, and built. The Executor is compiled explicitly before a
    single-epoch training run.
    """
    uber_model.parse(jinja_engine)
    uber_model.register_provider(uber_data)
    uber_model.build()

    executor = Executor(
        model=uber_model,
        loss=loss,
        optimizer=optimizer,
    )
    executor.compile()
    one_epoch = {'epochs': 1}
    executor.train(provider=uber_data, stop_when=one_epoch)
def test_embedding_train(self, embedding_model, embedding_data, loss, optimizer):
    """Check that a model containing an Embedding layer can be trained.

    The model is parsed (no template engine), given its data provider, and
    built; an Executor then trains it on ``embedding_data`` for one epoch.
    """
    embedding_model.parse(None)
    embedding_model.register_provider(embedding_data)
    embedding_model.build()

    executor = Executor(
        model=embedding_model,
        loss=loss,
        optimizer=optimizer,
    )
    one_epoch = {'epochs': 1}
    executor.train(provider=embedding_data, stop_when=one_epoch)
def test_training(self, simple_model, loss, optimizer, simple_data):
    """Check that one full training epoch runs to completion.

    Parse (without a template engine) and build the simple model, then have
    an Executor train it on ``simple_data`` until one epoch has elapsed.
    """
    simple_model.parse(None)
    simple_model.build()

    stop_criteria = {'epochs': 1}
    executor = Executor(model=simple_model, loss=loss, optimizer=optimizer)
    executor.train(provider=simple_data, stop_when=stop_criteria)
def test_embedding_train(self, embedding_model, embedding_data, loss, optimizer):
    """Check that a model containing an Embedding layer can be trained.

    Parse, register the data provider, and build the model, then run a
    single training epoch through an Executor.
    """
    embedding_model.parse(None)
    embedding_model.register_provider(embedding_data)
    embedding_model.build()

    stop_criteria = {'epochs': 1}
    executor = Executor(
        model=embedding_model, loss=loss, optimizer=optimizer
    )
    executor.train(provider=embedding_data, stop_when=stop_criteria)
def test_ctc_train(self, ctc_model, ctc_data, ctc_loss, optimizer):
    """Check that a model can be trained with the CTC loss.

    Parse, register the data provider, and build the CTC model, then train
    it for one epoch through an Executor configured with ``ctc_loss``.
    """
    ctc_model.parse(None)
    ctc_model.register_provider(ctc_data)
    ctc_model.build()

    stop_criteria = {'epochs': 1}
    executor = Executor(
        model=ctc_model, loss=ctc_loss, optimizer=optimizer
    )
    executor.train(provider=ctc_data, stop_when=stop_criteria)
def test_uber_train(self, uber_model, uber_data, jinja_engine, loss, optimizer):
    """Tests that we can compile and train a diverse model.

    The model is parsed through the Jinja engine, given its data provider,
    built, explicitly compiled by the Executor, and trained for one epoch.

    The test is skipped on the Keras 2 backend with the TensorFlow toolchain
    under Python < 3.5, a combination the skip reason records as
    occasionally crashing with SIGSEGV.
    """
    # Fetch the backend once rather than re-querying it for every clause.
    backend = uber_model.get_backend()
    if (backend.get_name() == 'keras'
            and backend.keras_version() == 2
            and backend.get_toolchain() == 'tensorflow'
            and sys.version_info < (3, 5)):
        pytest.skip('Occasional SIGSEGV')

    uber_model.parse(jinja_engine)
    uber_model.register_provider(uber_data)
    uber_model.build()

    trainer = Executor(model=uber_model, loss=loss, optimizer=optimizer)
    trainer.compile()
    trainer.train(provider=uber_data, stop_when={'epochs': 1})
def test_uber_train(self, uber_model, uber_data, jinja_engine, loss, optimizer):
    """Tests that we can compile and train a diverse model.

    The model is parsed through the Jinja engine, given its data provider,
    built, explicitly compiled by the Executor, and trained for one epoch.

    The test is skipped on the Keras 2 backend with the TensorFlow toolchain
    under Python < 3.5, a combination the skip reason records as
    occasionally crashing with SIGSEGV.
    """
    # Fetch the backend once rather than re-querying it for every clause.
    backend = uber_model.get_backend()
    if (backend.get_name() == 'keras'
            and backend.keras_version() == 2
            and backend.get_toolchain() == 'tensorflow'
            and sys.version_info < (3, 5)):
        pytest.skip('Occasional SIGSEGV')

    uber_model.parse(jinja_engine)
    uber_model.register_provider(uber_data)
    uber_model.build()

    trainer = Executor(
        model=uber_model,
        loss=loss,
        optimizer=optimizer
    )
    trainer.compile()
    trainer.train(provider=uber_data, stop_when={'epochs': 1})