def predict(self, data):
    """Evaluate the stored model on a single sample and return one output value.

    The model is rebuilt from its Kurfile spec on every call: the spec is
    parsed, the model is compiled, and the weights are restored from
    ``self.get_path('weights')``.

    Args:
        data: A single input sample. It is wrapped in a one-element batch
            under the ``'in'`` key before evaluation.

    Returns:
        ``outputs['out'][0][0]`` from the backend's evaluation, i.e. the first
        element of the first batch entry of the ``'out'`` output.
    """
    spec = Kurfile(self.get_model_path(), JinjaEngine())
    spec.parse()
    net = spec.get_model()

    # Compile under DisableLogging(logging.WARNING); presumably this silences
    # noisy compile-time log output — confirm DisableLogging semantics.
    with DisableLogging(logging.WARNING):
        net.backend.compile(net)

    net.restore(self.get_path('weights'))

    # Wrap the single sample into a batch of one.
    batch = {'in': np.array([data])}
    outputs, _metrics = net.backend.evaluate(net, data=batch)
    return outputs['out'][0][0]
def load(*, spec_file='speech.yml', weights_file='weights', norm_file='norm.yml'):
    """Build the speech model and its companion objects from files on disk.

    Every path defaults to the original hard-coded name, resolved relative to
    the current working directory. Calling ``load()`` with no arguments
    therefore behaves exactly as before, while callers can now point at other
    locations.

    Args:
        spec_file: Kurfile specification describing the model.
        weights_file: Saved weights passed to ``model.restore``.
        norm_file: Saved normalization state passed to ``Normalize.restore``.

    Returns:
        A tuple ``(model, norm, trans, rev, blank)``:
        - ``model``: the model built from the spec, compiled, with its
          weights restored;
        - ``norm``: a ``Normalize(center=True, scale=True, rotate=True)``
          whose state was restored from ``norm_file``;
        - ``trans``: a new ``TranscriptHook``;
        - ``rev``: a dict mapping integer output indices to characters;
        - ``blank``: the index immediately after the last character index
          (28). Presumably this is the CTC blank label; confirm against the
          model spec.
    """
    spec = Kurfile(spec_file, JinjaEngine())
    spec.parse()
    model = spec.get_model()
    model.backend.compile(model)
    model.restore(weights_file)

    norm = Normalize(center=True, scale=True, rotate=True)
    norm.restore(norm_file)

    trans = TranscriptHook()

    # Index -> character map: 0 is space, 1 is apostrophe, 2..27 are 'a'..'z'.
    rev = dict(enumerate(" '" + 'abcdefghijklmnopqrstuvwxyz'))
    # Derive the blank index from the map so the two can never drift apart.
    blank = len(rev)
    return model, norm, trans, rev, blank
def passthrough_engine():
    """Create and return a new Jinja2 engine (``JinjaEngine``) instance.

    NOTE(review): the name suggests a pass-through (non-templating) engine,
    but this returns a ``JinjaEngine``, exactly as ``jinja_engine`` in this
    file does. Confirm whether a pass-through engine class was intended here.
    """
    engine = JinjaEngine()
    return engine
def jinja_engine():
    """Create and return a fresh Jinja2 engine (``JinjaEngine``) instance."""
    engine = JinjaEngine()
    return engine