def test_e2e_export_and_query(self):
    """Test that we can export and query the model via tf.serving."""
    FLAGS.t2t_usr_dir = _get_t2t_usr_dir()
    FLAGS.problem = "github_function_docstring"
    FLAGS.data_dir = "/mnt/nfs-east1-d/data"
    FLAGS.tmp_dir = "/mnt/nfs-east1-d/tmp"
    FLAGS.output_dir = tempfile.mkdtemp()
    # FLAGS.export_dir = os.path.join(FLAGS.output_dir, "export")
    FLAGS.model = "similarity_transformer_dev"
    FLAGS.hparams_set = "similarity_transformer_tiny"
    FLAGS.train_steps = 1
    FLAGS.schedule = "train"
    timeout_secs = 10

    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    t2t_trainer.main(None)
    export.main(None)

    # ----
    # Start model server

    # Will start a tf model server on an un-used port and
    # kill process on exit.
    _, server, _ = TensorflowModelServer().RunServer(
        FLAGS.model, FLAGS.output_dir)

    # ----
    # Query the server
    return

    doc_query = [1, 2, 3]  # Dummy encoded doc query
    code_query = [1, 2, 3]  # Dummy encoded code query

    # Alternatively for query, without going through query.main()
    # TODO: Is servable_name the same as model name?
    request_fn = serving_utils.make_grpc_request_fn(
        servable_name=FLAGS.model,
        server=server,
        timeout_secs=timeout_secs)

    # Compute embeddings
    # TODO: May need to customize how these queries are fed in, potentially
    # side-stepping serving_utils.predict.
    encoded_string = serving_utils.predict([doc_query], problem_object,
                                           request_fn)
    encoded_code = serving_utils.predict([code_query], problem_object,
                                         request_fn)

def _do_send_request(self, text_arr):
    """
    Divide the array into batches and send the batches to the backend to be
    processed.
    :param text_arr: individual elements of this array will be grouped into
        batches
    :return:
    """
    outputs = []
    request_fn = serving_utils.make_grpc_request_fn(
        servable_name=self.model, timeout_secs=500, server=self.server)
    for batch in np.array_split(text_arr,
                                ceil(len(text_arr) / self.batch_size)):
        try:
            models.log.debug(f"===== sending batch\n{pformat(batch)}\n")
            outputs += list(
                map(lambda sent_score: sent_score[0],
                    serving_utils.predict(batch.tolist(), self.problem,
                                          request_fn)))
        except:
            # When tensorflow serving restarts, web clients seem to
            # "remember" the channel where the connection has failed.
            # Clearing the session seems to solve that.
            session.clear()
            raise
    return outputs

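# A standalone sketch (not from the original source) of the batching
# arithmetic used in _do_send_request above: np.array_split with
# ceil(len(arr) / batch_size) sections yields batches of at most batch_size
# elements. The sample data here is made up.
from math import ceil

import numpy as np

texts = ["a", "b", "c", "d", "e"]
batch_size = 2
for batch in np.array_split(texts, ceil(len(texts) / batch_size)):
    print(batch.tolist())  # ['a', 'b'] / ['c', 'd'] / ['e']
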
def query(self, sentences):
    """
    :param sentences: str
    :return:
    """
    tmp = []
    tokens = 0
    for sentence in self.cut_sent(sentences):
        sentence = self.mecab.parse(sentence.strip()).split()
        tokens += len(sentence)
        tmp.append(" ".join(sentence))
    inputs = tmp
    del tmp
    outputs = []
    print(inputs)
    start = time.time()
    for i in range(0, len(inputs), FLAGS.batch):
        batch_output = serving_utils.predict(inputs[i:(i + FLAGS.batch)],
                                             self.problem, self.request_fn)
        batch_output = [self.detokenizer(output[0].split(" "))
                        for output in batch_output]
        outputs.extend([{"key": zh, "value": en} for zh, en in
                        zip(inputs[i:(i + FLAGS.batch)], batch_output)])
    end = time.time()
    printstr = ("Sentences: {sentence:d}\tTokens: {tokens:d}"
                "\tTime: {time:.3f}ms\tTokens/time: {per:.3f}ms")
    logging.info(printstr.format(sentence=len(inputs), tokens=tokens,
                                 time=(end - start) * 1000,
                                 per=((end - start) * 1000 / tokens)))
    for output in outputs:
        logging.info("Input:{input:s}".format(input=output["key"]))
        logging.info("Output:{output:s}".format(output=output["value"]))
    return outputs

def face():
    replacements = []
    rules = []
    data = request.json
    confident = 0
    for txt in data["text"]:
        text = txt["text"]
        offset = txt["offset"]
        if unidecode.unidecode(text) == text:
            continue
        outputs = serving_utils.predict([text], problem, request_fn)
        outputs, = outputs
        output, confident = outputs
        rule = AuRule("MORFOLOGIK_RULE_EN_US", "Possible spelling mistake",
                      "misspelling", None)
        replacement = AuReplacement(text, output, offset, len(text), None)
        rules.append(rule)
        replacements.append(replacement)
        # matches.append(match)
    # "Tự động thêm dấu" is Vietnamese for "automatically add diacritics".
    matches = AuMatch("Tự động thêm dấu", "Tự động thêm dấu",
                      replacements, rules)
    ret = AuResult(AuLanguage("Vietnamese (Vi)", "vi-VN"),
                   AuLanguage("Vietnamese (Vi)", "vi-VN"),
                   matches, '%f' % confident)
    return ret.toJSON()

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    if FLAGS.cloud_mlengine_model_name:
        request_fn = serving_utils.make_cloud_mlengine_request_fn(
            credentials=GoogleCredentials.get_application_default(),
            model_name=FLAGS.cloud_mlengine_model_name,
            version=FLAGS.cloud_mlengine_model_version)
    else:
        request_fn = serving_utils.make_grpc_request_fn(
            servable_name=FLAGS.servable_name,
            server=FLAGS.server,
            timeout_secs=FLAGS.timeout_secs)
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        outputs = serving_utils.predict([inputs], problem, request_fn)
        print_str = """
Input:
{inputs}

Output:
{outputs}
"""
        print(print_str.format(inputs=inputs, outputs=outputs[0]))
        if FLAGS.inputs_once:
            break

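# Several snippets below call make_request_fn() without defining it. This is
# a sketch of that helper as it appears in tensor2tensor's query client; it
# simply wraps the same two serving_utils constructors used in main() above.
# The flag names mirror that client and should be treated as assumptions.
def make_request_fn():
    """Returns a request function for either Cloud ML Engine or TF Serving."""
    if FLAGS.cloud_mlengine_model_name:
        return serving_utils.make_cloud_mlengine_request_fn(
            credentials=GoogleCredentials.get_application_default(),
            model_name=FLAGS.cloud_mlengine_model_name,
            version=FLAGS.cloud_mlengine_model_version)
    return serving_utils.make_grpc_request_fn(
        servable_name=FLAGS.servable_name,
        server=FLAGS.server,
        timeout_secs=FLAGS.timeout_secs)
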
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        inputs = preprocess(inputs)
        if FLAGS.word_cut:
            inputs = " ".join(jieba.cut(inputs))
        outputs = serving_utils.predict([inputs], problem, request_fn)
        outputs, = outputs
        output, score = outputs
        output = postprocess(output)
        print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
"""
        print(print_str.format(inputs=inputs, output=output, score=score))
        if FLAGS.inputs_once:
            break

def summary():
    """Main prediction route.

    Provides a machine-generated summary of the given text. Sends a request
    to a live model trained on GitHub issues.
    """
    global problem  # pylint: disable=global-statement
    if problem is None:
        init()
    request_fn = make_tfserving_rest_request_fn()
    if request.method == 'POST':
        issue_text = request.form["issue_text"]
        issue_url = request.form["issue_url"]
        if issue_url:
            print("fetching issue from URL...")
            issue_text = get_issue_body(issue_url)
        tf.logging.info("issue_text: %s", issue_text)
        outputs = serving_utils.predict([issue_text], problem, request_fn)
        outputs, = outputs
        output, score = outputs  # pylint: disable=unused-variable
        tf.logging.info("output: %s", output)
        return jsonify({'summary': output, 'body': issue_text})
    return ('', 204)

def convert_file(file):
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    if os.path.isfile(T2T_Model_Path + "/4b_zh-tokenized-sample-en/" + file):
        print(file + " exists already")
        return None
    else:
        print(file)
        with codecs.open(T2T_Model_Path + "/4b_zh-tokenized-sample-en/" + file,
                         mode='w+', encoding='utf-8') as new_file:
            with codecs.open("./" + file, mode='r',
                             encoding='utf-8') as lines:
                for inputs in lines:
                    try:
                        print(inputs)
                        inputs = ftfy.fix_text(
                            inputs.replace('\n', '')).encode('utf-8')
                        output = "1"
                        outputs = serving_utils.predict([inputs], problem,
                                                        make_request_fn())
                        outputs, = outputs
                        output, score = outputs
                        new_file.write(output + '\n')
                        print(output + '\n')
                    except Exception as error:
                        print("error: " + str(error))
                        print("error input: " + inputs)
                        print("error output: " + output)
            new_file.close()  # redundant: the with-block closes the file
        return file

def grammar_check():
    print("IN!")
    '''
    data format: json
    example:
    {
        'model': 't2t',
        'content': 'This sentence would be checked by the model of grammar
                    error correction.'
    }
    :return: json
    '''
    # source = request.form['model']
    # inputs = request.form['content']
    source = 't2t'
    inputs = "I don't know what are you talking about."
    inputs = ' '.join([token.orth_ for token in nlp(inputs)])
    print(inputs)
    origin = inputs
    inputs = bpe.process_line(inputs)
    print(inputs)
    outputs = serving_utils.predict([inputs], model_list[source].problem,
                                    model_list[source].request)
    outputs, = outputs
    output, score = outputs
    print(outputs)
    output = output[0:output.find('EOS') - 1]
    output = bpe_to_origin.bpe_to_origin_line(output)
    result = {
        'corrected': output,
        'origin': origin
        # 'origin': request.form['content']
    }
    print(result)
    return result

def summary():
    """Main prediction route.

    Provides a machine-generated summary of the given text. Sends a request
    to a live model trained on GitHub issues.
    """
    global problem
    if problem is None:
        init()
    request_fn = make_tfserving_rest_request_fn(servable_name=servable_name,
                                                server=server)
    if request.method == 'POST':
        print("*** Using model %s for predictions ***" % model_name)
        issue_text = request.form["issue_text"]
        issue_url = request.form["issue_url"]
        if issue_url:
            print("fetching issue from URL...")
            issue_text = get_issue_body(issue_url)
        tf.logging.info("issue_text: %s", issue_text)
        print("Running predict in line 120...")
        outputs = serving_utils.predict([issue_text], problem, request_fn)
        outputs, = outputs
        output, score = outputs
        tf.logging.info("output: %s : score %s", output, score)
        print("output: %s : score %s" % (output, score))
        return jsonify({'summary': output, 'body': issue_text})
    return ('', 204)

def send_sentences_to_backend(self, sentences, src, tgt):
    if self.prefix_with:
        prefix = self.prefix_with.format(source=src, target=tgt)
    else:
        prefix = ''
    outputs = []
    request_fn = serving_utils.make_grpc_request_fn(
        servable_name=self.model, timeout_secs=500, server=self.server)
    for batch in np.array_split(
            sentences,
            ceil(len(sentences) / current_app.config['BATCH_SIZE'])):
        try:
            outputs += list(
                map(lambda sent_score: sent_score[0],
                    serving_utils.predict(
                        [prefix + sent for sent in batch.tolist()],
                        self.problem, request_fn)))
        except:
            # When tensorflow serving restarts, web clients seem to
            # "remember" the channel where the connection has failed.
            # Clearing the session seems to solve that.
            session.clear()
            raise
    return outputs

def predict_once(inputs):
    global request_fn
    global problem_hp
    outputs = serving_utils.predict([inputs], problem_hp, request_fn)
    outputs, = outputs
    output, score = outputs
    return output

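# Hypothetical one-off call of predict_once above; assumes the module-level
# request_fn and problem_hp globals have already been initialized.
print(predict_once("hello world"))
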
def translate():
    source = request.form['source']
    inputs = request.form['content']
    outputs = serving_utils.predict([inputs], lang_list[source].problem,
                                    lang_list[source].request)
    outputs, = outputs
    output, score = outputs
    return output[0:output.find('EOS') - 1]

def get_down_couplet(self, input_sentence_raw_list):
    input_sentence_list = self.format_input(input_sentence_raw_list)
    # do inference
    raw_outputs = serving_utils.predict(input_sentence_list, self.problem,
                                        self.request_fn)
    outputs = self.format_output(raw_outputs)
    return outputs

def query(self, text):
    """
    :param text: str
    :return: str
    """
    inputs = re.split(self.delimiter, text)
    inputs = [" ".join(self.tokenizer(sentence)) for sentence in inputs]
    outputs = serving_utils.predict(inputs, self.problem, self.request_fn)
    outputs = [output[0].replace(" ", "") for output in outputs]
    return "".join(outputs)

def query_t2t(self, input_txt, data_dir, problem_name, server_name,
              server_address, t2t_usr_dir):
    usr_dir.import_usr_dir(t2t_usr_dir)
    problem = registry.problem(problem_name)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn(server_name, server_address)
    inputs = input_txt
    outputs = serving_utils.predict([inputs], problem, request_fn)
    output, score = outputs[0]
    return output, score

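# Hypothetical invocation of query_t2t above. Every literal here (instance
# name, paths, problem and server names) is a placeholder, not taken from
# the source.
output, score = handler.query_t2t(
    input_txt="hello world",
    data_dir="~/t2t_data",
    problem_name="translate_ende_wmt32k",
    server_name="transformer",
    server_address="localhost:8500",
    t2t_usr_dir="~/t2t_usr")
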
def translation(inputs):
    logger.info('translate sents:{}'.format(inputs))
    start1 = time.time()
    outputs = serving_utils.predict(inputs, problem, request_fn)
    outputs = ''.join(result for result, score in outputs)
    outputs = outputs.replace('< / EOP >', '\r\n')
    end1 = time.time()
    logger.info('client time: %.2f' % (end1 - start1))
    # return (print_str.format(inputs=input, output=output, score=score))
    return outputs

def translate(self, inputs):
    # Register the problem class
    problem = registry.problem(self.problem)
    # Instantiate the HParams object
    hparams = HParams(data_dir=os.path.expanduser(self.data_dir))
    problem.get_hparams(hparams)
    request_fn = self.make_request_fn()
    # Prediction
    outputs = serving_utils.predict([inputs], problem, request_fn)
    outputs, = outputs
    output, score = outputs
    return {'inputs': inputs, 'outputs': output, 'scores': score}

def _translate(self, src, dodetok, dotok, dosegment):
    """Translate one sentence.

    @param src: source text (one sentence).
    @param dodetok: detokenize output?
    @param dotok: tokenize output?
    @param dosegment: segment the input?
    """

    def _prepare_cmd(cmd, inputValue="", outputValue=""):
        SPACE_SPLIT_ELEMENT = "SPACE_SPLIT_ELEMENT"
        cmd_args = cmd.replace(" ", SPACE_SPLIT_ELEMENT).replace(
            '"$input"', inputValue).replace('"$output"', outputValue)
        return cmd_args.split(SPACE_SPLIT_ELEMENT)

    def _run_cmd(*args):
        try:
            out = subprocess.check_output(args[0]).strip()
            return 0, out
        except subprocess.CalledProcessError as grepexc:
            return grepexc.returncode, grepexc.output

    # tokenize
    src_tokenized = self.tokenizer.tokenize(src) if dotok else src
    if self.preprocess:
        cmd_args = _prepare_cmd(self.preprocess, src)
        (cmd_error, cmd_output) = _run_cmd(cmd_args)
        if cmd_error == 0:
            src_tokenized = cmd_output
        else:
            sys.stderr.write("{0}\n{1}".format(cmd_error, cmd_output))

    # translate
    outputs = serving_utils.predict([src_tokenized], self.problem,
                                    self.request_fn)
    outputs, = outputs
    result, score = outputs

    if self.postprocess:
        cmd_args = _prepare_cmd(self.postprocess, src, result)
        (cmd_error, cmd_output) = _run_cmd(cmd_args)
        if cmd_error == 0:
            result = cmd_output
        else:
            sys.stderr.write("{0}\n{1}".format(cmd_error, cmd_output))

    result = {'src': src, 'translated': result}, score.item()
    return result

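# Quick standalone check (not from the original) of _prepare_cmd in
# _translate above: spaces become the split marker before substitution, so
# substituted values may safely contain spaces themselves.
SPACE_SPLIT_ELEMENT = "SPACE_SPLIT_ELEMENT"
cmd = 'sed -e s/foo/bar/ "$input"'
cmd_args = cmd.replace(" ", SPACE_SPLIT_ELEMENT).replace(
    '"$input"', "some source text")
print(cmd_args.split(SPACE_SPLIT_ELEMENT))
# -> ['sed', '-e', 's/foo/bar/', 'some source text']
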
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = hparam.HParams(data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        t1 = datetime.datetime.now()
        outputs = serving_utils.predict([inputs], problem, request_fn)
        t2 = datetime.datetime.now()
        time_taken_for_response = int((t2 - t1).total_seconds() * 1000)
        # print("time:", time_taken_for_response)
        outputs, = outputs
        output, score = outputs
        if len(score.shape) > 0:  # pylint: disable=g-explicit-length-test
            print_str = """
Input:
{inputs}

Output (Scores [{score}]) (Time [{time}] milliseconds):
{output}
"""
            # time_taken_for_response = (t2 - t1) / 1000.0
            score_text = ",".join(["{:.3f}".format(s) for s in score])
            print(print_str.format(inputs=inputs, output=output,
                                   score=score_text,
                                   time=time_taken_for_response))
        else:
            print_str = """
Input:
{inputs}

Output (Score {score:.3f}) (Time {time} milliseconds):
{output}
"""
            # time_taken_for_response = (t2 - t1) / 1000.0
            print(print_str.format(inputs=inputs, output=output, score=score,
                                   time=time_taken_for_response))
        if FLAGS.inputs_once:
            break

def batch_query_server(self, queries, server, hparams, timeout_secs=5):
    """Query a served model with a batch of multiple queries."""
    problem = hparams.problem
    request_fn = serving_utils.make_grpc_request_fn(
        servable_name=self.model_name,
        server=server,
        timeout_secs=timeout_secs)
    responses = []
    for query in queries:
        response = serving_utils.predict(query, problem, request_fn)
        responses.append(response)
    return responses

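# Hypothetical use of batch_query_server above; the instance, server address
# and queries are placeholders. Note each query is itself a list, since it is
# passed straight through to serving_utils.predict.
responses = tester.batch_query_server(
    queries=[["hello"], ["world"]],
    server="localhost:8500",
    hparams=hparams,
    timeout_secs=5)
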
def grammar_check(inputs):
    print("IN!")
    '''
    data format: json
    example:
    {
        'model': 't2t',
        'content': 'This sentence would be checked by the model of grammar
                    error correction.'
    }
    return:
    {
        'corrected': output,
        'origin': origin
    }
    '''
    # source = request.form['model']
    # inputs = request.form['content']
    source = 't2t'
    # inputs = 'People get certain disease because of genetic changes.'
    # origin = inputs
    inputs = ' '.join([token.orth_ for token in nlp(inputs)])
    print(inputs)
    inputs = bpe.process_line(inputs)
    origin = inputs
    print(inputs)
    outputs = serving_utils.predict([inputs], model_list[source].problem,
                                    model_list[source].request)
    outputs, = outputs
    output, score = outputs
    print(outputs)
    output = output[0:output.find('EOS') - 1]
    output = bpe_to_origin.bpe_to_origin_line(output)
    result = {
        'corrected': output,
        'origin': origin
        # 'origin': request.form['content']
    }
    print(result)
    return result

def entry(inputs, input_data_dir, input_problem, input_servable_name,
          input_server):
    problem = registry.problem(input_problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(input_data_dir))
    problem.get_hparams(hparams)
    request_fn = my_make_request_fn(input_servable_name, input_server)
    start = time.time()
    outputs = serving_utils.predict([inputs], problem, request_fn)
    end = time.time()
    print("predict cost time: %s s" % str(end - start))
    only_one = outputs[0]
    res_content = only_one[0]
    res_score = str(only_one[1])
    info = "input = %s , output = %s ( score : %s )" % (inputs, res_content,
                                                        res_score)
    print(info)
    res = {"output": res_content, "input": inputs, "score": res_score}
    return res

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        outputs = serving_utils.predict([inputs], problem, request_fn)
        print_str = """
Input:
{inputs}

Output:
{outputs}
"""
        print(print_str.format(inputs=inputs, outputs=outputs[0]))
        if FLAGS.inputs_once:
            break

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = HParams(data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        outputs = serving_utils.predict([inputs], problem, request_fn)
        outputs, = outputs
        output, score = outputs
        if len(score.shape) > 0:  # pylint: disable=g-explicit-length-test
            print_str = """
Input:
{inputs}

Output (Scores [{score}]):
{output}
"""
            score_text = ",".join(["{:.3f}".format(s) for s in score])
            print(print_str.format(inputs=inputs, output=output,
                                   score=score_text))
        else:
            print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
"""
            print(print_str.format(inputs=inputs, output=output, score=score))
        if FLAGS.inputs_once:
            break

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    if FLAGS.test_data:
        inputs = []
        with open(FLAGS.test_data, 'r') as f:
            with open(FLAGS.output, 'w+') as fout:
                print("Id,Prediction", file=fout)
                for line in tqdm(f):
                    num, text = line.rstrip().split(',', 1)
                    outputs = serving_utils.predict([text], problem,
                                                    request_fn)
                    print('{},{}'.format(
                        num, "-1" if outputs[0][0] == "neg" else "1"),
                        file=fout)
    else:
        print("Missing test_data and output file")

def main():
    global i
    i += 1
    form = ReviewForm(request.form)
    if request.method == "POST" and form.validate():
        # Get the English text submitted in the form
        inputs = request.form["review"]
        # Get the review's classification result: label and probability
        # Y, lable_Y, proba = classify_review([review_text])
        # Keep the probability as a decimal and convert it to a percentage
        # proba = float("%.4f" % proba) * 100
        # Return the classification result to the page for display
        inputs = mtr.truecase(mtok.tokenize(inputs, return_str=True),
                              return_str=True)
        inputs = sent_tokenizer.tokenize(inputs)
        a = datetime.now()
        outputs = serving_utils.predict(inputs, problem, request_fn)
        outputs = [output for (output, score) in outputs]
        outputs = (''.join(outputs)).replace(' ', '')
        # print(outputs)
        b = datetime.now()
        return render_template("index.html", form=form,
                               label=outputs + " (Time: " +
                               str((b - a).seconds) + "s)")
    return render_template("index.html", form=form)

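# Hedged guess at the mtok / mtr / sent_tokenizer helpers used by main()
# above. sacremoses and NLTK are assumptions, as is the truecase model path;
# the source does not show how these objects are constructed.
import nltk
from sacremoses import MosesTokenizer, MosesTruecaser

mtok = MosesTokenizer(lang="en")
mtr = MosesTruecaser("truecase-model.en")  # hypothetical model file
sent_tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
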
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = HParams(data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        if FLAGS.json:
            inputs = json.loads(inputs)
            ret = serving_utils.predict_features([inputs], problem,
                                                 request_fn)
            outputs = ret["outputs"]
        else:
            outputs = serving_utils.predict([inputs], problem, request_fn)
            outputs, = outputs
            output, score = outputs
        if problem.multi_targets:
            print_str = """
Input:
{inputs}

Output (Score {score}):
{output}
"""
        else:
            print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
"""
        print(print_str.format(inputs=inputs, output=output, score=score))
        if FLAGS.inputs_once:
            break

def Translation(input):
    start = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    # if FLAGS.word_cut:
    #     input = " ".join(jieba.cut(input))
    outputs = serving_utils.predict(input, problem, request_fn)
    print('outputs:', outputs)
    # outputs = '.'.join(result for result, score in outputs)
    for result, _ in outputs:
        yield result
    # outputs, = outputs
    # output, score = outputs
    # end = time.time()
    # print('client time:', (end - start))
    # print_str = """

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    while True:
        inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
        outputs = serving_utils.predict([inputs], problem, request_fn)
        outputs, = outputs
        output, score = outputs
        print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
"""
        print(print_str.format(inputs=inputs, output=output, score=score))
        if FLAGS.inputs_once:
            break

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    problem.get_hparams(hparams)
    if FLAGS.TFX == 1:
        os.chdir(T2T_Model_Path + "/4a_zh-tokenized-converted/"
                 + FLAGS.subdir)
        files = []
        for file in glob.glob("*.txt"):
            files.append(file)
        pool = Pool(number_of_workers)
        # We use imap_unordered as we don't care about order; we want the
        # result of each job as soon as it is done.
        iter_ = pool.imap_unordered(job, files)
        while True:
            completed = []
            while len(completed) < chunksize:
                # Collect results from the iterator until we reach the
                # dispatch threshold or until all jobs have been completed.
                try:
                    result = next(iter_)
                except StopIteration:
                    print('all child jobs completed')
                    # Only break out of the inner loop; there might still be
                    # some completed jobs to dispatch.
                    break
                except FailedJob as ex:
                    print('processing of {} job failed'.format(ex.args[0]))
                    sleep(300)
                    headers = {'Content-type': 'application/json'}
                    data = ('{"SERVER ERROR: ":"' + str(FLAGS.subdir) + '-'
                            + str(ex.args[0]) + '"}')
                    response = requests.post(slack_hook, headers=headers,
                                             data=data)
                else:
                    completed.append(result)
            if completed:
                print('completed:', completed)
                if FLAGS.bleualign_upload == 1:
                    for file in filter(None, completed):
                        # Drop empty lines from the model output.
                        with codecs.open(
                                T2T_Model_Path + "/4b_zh-tokenized-sample-en/"
                                + file, mode="r",
                                encoding='utf-8') as infile, \
                            codecs.open(
                                "/root/T2T_Model/temp" + FLAGS.subdir
                                + ".txt", mode="w",
                                encoding='utf-8') as outfile:
                            for line in infile:
                                if not line.strip():
                                    continue  # skip the empty line
                                outfile.write(line)  # non-empty line: keep it
                        copyfile(
                            T2T_Model_Path + "/temp" + FLAGS.subdir + ".txt",
                            T2T_Model_Path + "/4b_zh-tokenized-sample-en/"
                            + file)
                        os.remove(
                            T2T_Model_Path + "/temp" + FLAGS.subdir + ".txt")
                        # Run bleualign.
                        cmd = ("python3 '" + T2T_Model_Path
                               + "/bleualign/bleualign.py' -v 0 -f sentences"
                               + " --filterthreshold 95 -s '"
                               + T2T_Model_Path
                               + "/4a_zh-tokenized-converted/" + file
                               + "' -t '" + T2T_Model_Path
                               + "/3_en-tokenized/" + file
                               + "' --srctotarget '" + T2T_Model_Path
                               + "/4b_zh-tokenized-sample-en/" + file
                               + "' -o '" + T2T_Model_Path + "/5_aligned-zh/"
                               + file + "'")
                        os.system(cmd)
                        completed_files = []
                        completed_files.append(
                            T2T_Model_Path + "/5_aligned-zh/"
                            + file[0:-4] + ".txt-s")
                        completed_files.append(
                            T2T_Model_Path + "/5_aligned-zh/"
                            + file[0:-4] + ".txt-t")
                        for file_large in completed_files:
                            with codecs.open(file_large, mode="r",
                                             encoding='utf-8') as bigfile:
                                for i, lines in enumerate(
                                        chunks(bigfile, max_lines)):
                                    file_split = '{}_{}.{}'.format(
                                        file_large.split('.')[0], i,
                                        file_large.split('.')[1])
                                    with codecs.open(
                                            file_split, mode="w",
                                            encoding='utf-8') as f:
                                        f.writelines(lines)
                            os.remove(file_large)
                        # Upload to S3.
                        cmd = ("aws s3 sync " + T2T_Model_Path
                               + "/5_aligned-zh "
                               + bleualign_upload_location)
                        os.system(cmd)
            if len(completed) < chunksize:
                print('all jobs completed and all job completion'
                      ' notifications dispatched to central server')
                return
        # Notify 'erik' on Slack when done.
        headers = {'Content-type': 'application/json'}
        data = '{"SERVER DONE: ":"' + str(FLAGS.subdir) + '"}'
        response = requests.post(slack_hook, headers=headers, data=data)
    else:
        while True:
            inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
            outputs = serving_utils.predict([inputs], problem,
                                            make_request_fn())
            outputs, = outputs
            output, score = outputs
            print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
"""
            print(print_str.format(inputs=inputs, output=output,
                                   score=score))
            if FLAGS.inputs_once:
                break

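# The snippet above relies on a chunks(file_obj, max_lines) helper that is
# not shown in this collection. A plausible minimal implementation, assuming
# it should yield successive groups of at most max_lines lines:
from itertools import islice

def chunks(iterable, max_lines):
    iterator = iter(iterable)
    while True:
        block = list(islice(iterator, max_lines))
        if not block:
            return
        yield block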