def _run(self, sess, model_file, embeddings_set, use_preproc=True):
    """Build the classification export graph and return its serving signature.

    Loads vocabularies, labels and max lengths for ``model_file``, optionally
    wires in the in-graph preprocessing, builds the classifier, ranks every
    class by softmax probability and maps the ranked ids to label strings,
    then restores the trained weights into ``sess``.

    :param sess: TF session the model is built in and restored into.
    :param model_file: Model identifier passed to the vocab, label and
        max-length loaders and to ``restore_model``.
    :param embeddings_set: Passed to ``_initialize_embeddings_map``.
    :param use_preproc: When true, the signature input is built from the
        serialized/parsed examples produced by ``_run_preproc``; otherwise
        the model itself is handed to ``SignatureInput``.
    :return: ``(SignatureInput, SignatureOutput, 'predict_text')``.
    """
    indices, vocabs = self._create_vocabs(model_file)
    # Vocab names that are not among the defaults are reported as extra features.
    extra_features_required = [
        name for name in vocabs if name not in TensorFlowExporter.DEFAULT_VOCABS
    ]
    self.assign_char_lookup()
    labels = self.load_labels(model_file)
    mxlen, mxwlen = self._get_max_lens(model_file)

    model_params = self.task.config_params["model"]
    if use_preproc:
        serialized_example, example, _, _ = self._run_preproc(
            model_params, vocabs, model_file, indices, extra_features_required
        )

    # Inference-time settings; this mutates the task's model config in place.
    model_params.update({
        "pkeep": 1,
        "sess": sess,
        "maxs": mxlen,
        "maxw": mxwlen,
    })
    print(model_params)

    embeddings = self._initialize_embeddings_map(vocabs, embeddings_set)
    model = baseline.tf.classify.create_model(embeddings, labels, **model_params)

    # Rank all classes by probability and translate the ranked ids to names.
    probabilities = tf.nn.softmax(model.logits)
    top_values, top_ids = tf.nn.top_k(probabilities, len(labels))
    lookup_table = tf.contrib.lookup.index_to_string_table_from_tensor(
        tf.constant(model.labels)
    )
    ranked_classes = lookup_table.lookup(tf.to_int64(top_ids))

    self.restore_model(sess, model_file)

    if use_preproc:
        sig_input = SignatureInput(serialized_example, example, extra_features_required)
    else:
        sig_input = SignatureInput(None, None, extra_features_required, model=model)
    return sig_input, SignatureOutput(ranked_classes, top_values), 'predict_text'
def _run(self, sess, model_file, embeddings_set, use_preproc=True):
    """Build the tagger export graph and return its serving signature.

    Loads max lengths, vocabularies and labels for ``model_file``, optionally
    wires in the in-graph preprocessing (whose sequence lengths are fed to the
    model), builds the model via ``self._create_model`` and restores the
    trained weights into ``sess``.

    :param sess: TF session the model is built in and restored into.
    :param model_file: Model identifier passed to the loaders and to
        ``restore_model``.
    :param embeddings_set: Passed through to ``_create_model``.
    :param use_preproc: When true, the signature input is built from the
        serialized/parsed examples produced by ``_run_preproc``; otherwise
        the model is handed to ``SignatureInput`` and ``'lengths'`` is
        listed as an extra feature.
    :return: ``(SignatureInput, SignatureOutput, 'tag_text')``.
    """
    mxlen, mxwlen = self._get_max_lens(model_file)
    indices, vocabs = self._create_vocabs(model_file)
    self.assign_char_lookup()
    labels = self.load_labels(model_file)
    # Vocab names that are not among the defaults are reported as extra features.
    extra_features_required = [
        name for name in vocabs if name not in TensorFlowExporter.DEFAULT_VOCABS
    ]

    model_params = self.task.config_params["model"]
    if use_preproc:
        serialized_example, example, _, seq_lengths = self._run_preproc(
            model_params, vocabs, model_file, indices, extra_features_required
        )
        model_params["lengths"] = seq_lengths

    # Inference-time settings; this mutates the task's model config in place.
    model_params.update({
        "pkeep": 1,
        "sess": sess,
        "maxs": mxlen,
        "maxw": mxwlen,
    })
    model_params['span_type'] = self.task.config_params['train'].get('span_type')
    print(model_params)

    classes, values, model = self._create_model(
        vocabs, labels, embeddings_set, mxlen, model_params
    )
    self.restore_model(sess, model_file)

    if use_preproc:
        sig_input = SignatureInput(serialized_example, example, extra_features_required)
    else:
        # 'lengths' is listed alongside the non-default vocab features.
        sig_input = SignatureInput(
            None, None, extra_features_required + ['lengths'], model=model
        )
    return sig_input, SignatureOutput(classes, values), 'tag_text'
def _run(self, sess, model_file, embeddings_set):
    """Build the seq2seq export graph and return its serving signature.

    Reads the input and output vocabularies for ``model_file`` and stores
    their lookups on ``self.word2input`` / ``self.output2word``. It then
    parses serialized ``tf.Example`` protos that carry one string feature,
    ``FIELD_NAME``, and preprocesses each post in-graph with ``tf.map_fn``.
    Finally it builds the seq2seq model with ``predict=True``, maps
    ``model.best`` back to output words and restores the trained weights
    into ``sess``.

    :param sess: TF session the model is built in and restored into.
    :param model_file: Model identifier passed to the vocab readers and to
        ``restore_model``.
    :param embeddings_set: Used to compute the ``dsz`` model parameter.
    :return: ``(SignatureInput, SignatureOutput, 'suggest_text')``. The
        signature output holds the decoded words and no scores.
    """
    self.word2input, vocab1 = Seq2SeqTensorFlowExporter.read_input_vocab(model_file)
    self.output2word, vocab2 = Seq2SeqTensorFlowExporter.read_output_vocab(model_file)

    # Make the TF example, network input
    serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
    feature_configs = {
        FIELD_NAME: tf.FixedLenFeature(shape=[], dtype=tf.string),
    }
    tf_example = tf.parse_example(serialized_tf_example, feature_configs)
    raw_posts = tf_example[FIELD_NAME]

    # Run the preprocessor on each post. It yields two int32 tensors,
    # which are fed below as ``src`` and ``src_len``.
    dense, length = tf.map_fn(self._preproc_post_creator(), raw_posts,
                              dtype=(tf.int32, tf.int32))

    # Inference-time settings; this mutates the task's model config in place.
    model_params = self.task.config_params["model"]
    model_params["dsz"] = self.get_dsz(embeddings_set)
    model_params["src"] = dense
    model_params["src_len"] = length
    # The target length limit is taken from the preprocessing ``mxlen`` setting.
    model_params["mx_tgt_len"] = self.task.config_params["preproc"]["mxlen"]
    model_params["tgt_len"] = 1
    model_params["pkeep"] = 1
    model_params["sess"] = sess
    model_params["predict"] = True
    print(model_params)

    model = baseline.tf.seq2seq.create_model(vocab1, vocab2, **model_params)
    # Map the best decoded ids back to output vocabulary words.
    output = self.output2word.lookup(tf.cast(model.best, dtype=tf.int64))

    self.restore_model(sess, model_file)

    sig_input = SignatureInput(serialized_tf_example, raw_posts)
    # Expose the decoded words as the signature output (no scores).
    sig_output = SignatureOutput(output, None)
    return sig_input, sig_output, 'suggest_text'
def _run(self, sess, model_file, embeddings_set, output_dir, model_version, use_preproc=True):
    """Build the ELMo tagger export graph and return its serving signature.

    The steps are:

    - Parse examples via ``self._create_example``.
    - Preprocess each example with ``ElmoPreprocessorCreator.preproc_post``
      under ``tf.map_fn``.
    - Build the tagger with ``baseline.tf.tagger.create_model``.
    - Convert the predicted label ids into label strings. When
      ``model.crf is True``, the ids come from ``tf.contrib.crf.crf_decode``;
      otherwise they are the top-1 ids of a softmax over ``model.probs``.
    - Restore the weights into ``sess`` with ``self.restore_checkpoint``.

    :param sess: TF session the model is built in and restored into.
    :param model_file: Model identifier passed to the vocab, label and
        max-length loaders and to ``restore_checkpoint``.
    :param embeddings_set: Passed to ``_initialize_embeddings_map``.
    :param output_dir: Not used in this method body.
    :param model_version: Not used in this method body.
    :param use_preproc: Not used in this method body; the preprocessing is
        always built into the graph here.
    :return: ``(SignatureInput, SignatureOutput, "tag_text")``.

    Side effects: mutates ``self.task.config_params["model"]`` (bound to
    ``model_params``), reassigns ``model.probs`` and prints the model
    parameters.
    """
    indices, vocabs = self._create_vocabs(model_file)

    self.assign_char_lookup()

    labels = self.load_labels(model_file)
    extra_features_required = []

    # Make the TF example, network input
    serialized_tf_example, tf_example = self._create_example(extra_features_required)
    raw_posts = tf_example[FIELD_NAME]

    mxlen, mxwlen = self._get_max_lens(model_file)

    preprocessor = ElmoPreprocessorCreator(
        indices, self.lchars, self.upchars_lut, self.task, FIELD_NAME, extra_features_required, mxlen, mxwlen
    )

    # Output dtypes for map_fn: one int64 tensor per vocab key plus
    # 'mixed_case_word'.
    types = {k: tf.int64 for k in indices.keys()}
    types.update({'mixed_case_word': tf.int64})

    # Run for each post
    preprocessed, lengths = tf.map_fn(preprocessor.preproc_post, tf_example, dtype=(types, tf.int32), back_prop=False)

    embeddings = self._initialize_embeddings_map(vocabs, embeddings_set)

    # Feed the preprocessed tensors directly as the model inputs, together
    # with inference-time settings.
    model_params = self.task.config_params["model"]
    model_params["x"] = preprocessed['mixed_case_word']
    model_params["xch"] = preprocessed['char']
    model_params["x_lc"] = preprocessed['word']
    model_params["lengths"] = lengths
    model_params["pkeep"] = 1
    model_params["sess"] = sess
    model_params["maxs"] = mxlen
    model_params["maxw"] = mxwlen
    model_params['span_type'] = self.task.config_params['train'].get('span_type')
    print(model_params)
    model = baseline.tf.tagger.create_model(labels, embeddings, **model_params)
    # NOTE(review): presumably required so that CRF attributes such as
    # model.A exist before decoding — confirm.
    model.create_loss()

    # Computed from the un-prefixed scores, before model.probs is reassigned
    # below. `values` always comes from this top-1 softmax, even when
    # `indices` is later replaced by the CRF decode. This rebinding of
    # `indices` discards the vocab indices, which are no longer needed.
    softmax_output = tf.nn.softmax(model.probs)
    values, indices = tf.nn.top_k(softmax_output, 1)

    # One extra time step scoring -1e4 for every label except '<GO>' (0).
    # It is prepended on axis 1, which biases the CRF decode to begin at
    # '<GO>'. Raises KeyError if labels has no '<GO>' entry.
    start_np = np.full((1, 1, len(labels)), -1e4, dtype=np.float32)
    start_np[:, 0, labels['<GO>']] = 0
    start = tf.constant(start_np)
    model.probs = tf.concat([start, model.probs], 1)

    if model.crf is True:
        # We are assuming the batch size is 1 here: the sequence-length vector
        # holds a single value, mxlen + 1, which counts the prefixed step.
        indices, _ = tf.contrib.crf.crf_decode(model.probs, model.A, tf.constant([mxlen + 1]))
        # Drop the prefixed '<GO>' step from the decoded ids.
        indices = indices[:, 1:]

    # Invert labels (label -> id) into an id-indexed list for the lookup
    # table. Assumes the ids are 0..len(labels)-1; any other id raises
    # IndexError.
    list_of_labels = [''] * len(labels)
    for label, idval in labels.items():
        list_of_labels[idval] = label

    class_tensor = tf.constant(list_of_labels)
    table = tf.contrib.lookup.index_to_string_table_from_tensor(class_tensor)
    classes = table.lookup(tf.to_int64(indices))
    self.restore_checkpoint(sess, model_file)

    # NOTE(review): serialized_tf_example is created above but is not passed
    # here — confirm this is intended.
    sig_input = SignatureInput(None, raw_posts)
    sig_output = SignatureOutput(classes, values)
    return sig_input, sig_output, "tag_text"