예제 #1
0
    def _run(self, sess, model_file, embeddings_set, use_preproc=True):
        """Build the classifier graph for export and describe its signature.

        Loads the vocabularies, labels and maximum lengths for ``model_file``,
        optionally adds the preprocessing sub-graph, creates the classifier,
        ranks the labels by softmax score and restores the saved weights.

        NOTE: ``self.task.config_params["model"]`` is updated in place with
        the ``pkeep``, ``sess``, ``maxs`` and ``maxw`` entries.

        Args:
            sess: Session given to the model and to ``restore_model``.
            model_file: Model identifier handed to the vocab/label/length
                loaders and to ``restore_model``.
            embeddings_set: Passed to ``_initialize_embeddings_map``.
            use_preproc: If true, the signature input is built from the
                tensors returned by ``_run_preproc``; otherwise it is built
                from the model itself.

        Returns:
            A ``(sig_input, sig_output, 'predict_text')`` tuple.
        """
        vocab_indices, vocabs = self._create_vocabs(model_file)
        # Any vocabulary that is not a default one must be supplied as an
        # additional feature by the caller of the exported model.
        extra_features_required = [
            name for name in vocabs.keys()
            if name not in TensorFlowExporter.DEFAULT_VOCABS
        ]

        self.assign_char_lookup()

        labels = self.load_labels(model_file)
        max_len, max_wlen = self._get_max_lens(model_file)

        model_params = self.task.config_params["model"]

        if use_preproc:
            # Only the serialized placeholder and the parsed example are
            # needed later; the remaining two results are discarded.
            serialized_tf_example, tf_example, _, _ = self._run_preproc(
                model_params, vocabs, model_file, vocab_indices,
                extra_features_required)

        model_params.update({
            "pkeep": 1,
            "sess": sess,
            "maxs": max_len,
            "maxw": max_wlen,
        })
        print(model_params)

        embeddings = self._initialize_embeddings_map(vocabs, embeddings_set)
        model = baseline.tf.classify.create_model(
            embeddings, labels, **model_params)
        probabilities = tf.nn.softmax(model.logits)

        # k = len(labels): one (score, id) entry is requested per label.
        values, ranked_ids = tf.nn.top_k(probabilities, len(labels))
        label_table = tf.contrib.lookup.index_to_string_table_from_tensor(
            tf.constant(model.labels))
        classes = label_table.lookup(tf.to_int64(ranked_ids))
        self.restore_model(sess, model_file)

        if not use_preproc:
            sig_input = SignatureInput(
                None, None, extra_features_required, model=model)
        else:
            sig_input = SignatureInput(
                serialized_tf_example, tf_example, extra_features_required)

        return sig_input, SignatureOutput(classes, values), 'predict_text'
예제 #2
0
    def _run(self, sess, model_file, embeddings_set, use_preproc=True):
        """Build the tagger graph for export and describe its signature.

        Loads maximum lengths, vocabularies and labels for ``model_file``,
        optionally adds the preprocessing sub-graph, creates the model through
        ``_create_model`` and restores the saved weights.

        NOTE: ``self.task.config_params["model"]`` is updated in place
        (``lengths`` when preprocessing is used, plus ``pkeep``, ``sess``,
        ``maxs``, ``maxw`` and ``span_type``).

        Args:
            sess: Session given to the model and to ``restore_model``.
            model_file: Model identifier handed to the vocab/label/length
                loaders and to ``restore_model``.
            embeddings_set: Passed through to ``_create_model``.
            use_preproc: If true, the signature input is built from the
                tensors returned by ``_run_preproc``; otherwise it is built
                from the model and additionally requires ``'lengths'``.

        Returns:
            A ``(sig_input, sig_output, 'tag_text')`` tuple.
        """
        max_len, max_wlen = self._get_max_lens(model_file)
        vocab_indices, vocabs = self._create_vocabs(model_file)
        self.assign_char_lookup()

        labels = self.load_labels(model_file)

        # Any vocabulary that is not a default one must be supplied as an
        # additional feature by the caller of the exported model.
        extra_features_required = [
            name for name in vocabs.keys()
            if name not in TensorFlowExporter.DEFAULT_VOCABS
        ]
        model_params = self.task.config_params["model"]

        if use_preproc:
            # The third result (the raw posts) is not needed here.
            serialized_tf_example, tf_example, _, seq_lengths = self._run_preproc(
                model_params, vocabs, model_file, vocab_indices,
                extra_features_required)
            model_params["lengths"] = seq_lengths

        model_params.update({
            "pkeep": 1,
            "sess": sess,
            "maxs": max_len,
            "maxw": max_wlen,
        })
        train_config = self.task.config_params['train']
        model_params['span_type'] = train_config.get('span_type')
        print(model_params)

        classes, values, model = self._create_model(
            vocabs, labels, embeddings_set, max_len, model_params)
        self.restore_model(sess, model_file)

        if not use_preproc:
            # Without the preprocessing graph the sequence lengths have to be
            # provided as an input alongside the extra features.
            sig_input = SignatureInput(
                None, None, extra_features_required + ['lengths'],
                model=model)
        else:
            sig_input = SignatureInput(
                serialized_tf_example, tf_example, extra_features_required)

        return sig_input, SignatureOutput(classes, values), 'tag_text'
예제 #3
0
    def _run(self, sess, model_file, embeddings_set):
        """Build the seq2seq graph for export and describe its signature.

        Reads the input and output vocabularies, creates a serialized-example
        placeholder, preprocesses every raw post inside the graph, creates the
        seq2seq model in prediction mode, maps the model's best output ids to
        words and restores the saved weights.

        NOTE: ``self.task.config_params["model"]`` is updated in place with
        the graph tensors and decoding settings assigned below.

        Args:
            sess: Session given to the model and to ``restore_model``.
            model_file: Model identifier handed to the vocabulary readers and
                to ``restore_model``.
            embeddings_set: Passed to ``get_dsz`` to obtain the ``dsz``
                model parameter.

        Returns:
            A ``(sig_input, sig_output, 'suggest_text')`` tuple, where
            ``sig_output`` carries the decoded output words and no values
            tensor.
        """
        self.word2input, vocab1 = Seq2SeqTensorFlowExporter.read_input_vocab(
            model_file)
        self.output2word, vocab2 = Seq2SeqTensorFlowExporter.read_output_vocab(
            model_file)

        # Make the TF example, network input
        serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
        feature_configs = {
            FIELD_NAME: tf.FixedLenFeature(shape=[], dtype=tf.string),
        }
        tf_example = tf.parse_example(serialized_tf_example, feature_configs)
        raw_posts = tf_example[FIELD_NAME]

        # Run the preprocessing function for each post; it yields a pair of
        # int32 tensors that are fed to the model as source and source length.
        dense, length = tf.map_fn(self._preproc_post_creator(),
                                  raw_posts,
                                  dtype=(tf.int32, tf.int32))

        model_params = self.task.config_params["model"]
        model_params["dsz"] = self.get_dsz(embeddings_set)
        model_params["src"] = dense
        model_params["src_len"] = length
        model_params["mx_tgt_len"] = self.task.config_params["preproc"][
            "mxlen"]
        model_params["tgt_len"] = 1
        model_params["pkeep"] = 1
        model_params["sess"] = sess
        model_params["predict"] = True
        print(model_params)
        model = baseline.tf.seq2seq.create_model(vocab1, vocab2,
                                                 **model_params)
        # Translate the model's best output ids into output-vocabulary words.
        output = self.output2word.lookup(tf.cast(model.best, dtype=tf.int64))

        self.restore_model(sess, model_file)

        sig_input = SignatureInput(serialized_tf_example, raw_posts)
        # Export the decoded words computed above; no values tensor is
        # exported for this task, hence None.
        sig_output = SignatureOutput(output, None)

        return sig_input, sig_output, 'suggest_text'
예제 #4
0
    def _run(self, sess, model_file, embeddings_set, output_dir, model_version, use_preproc=True):
        """Build the ELMo-preprocessed tagger graph for export.

        Creates the example input, preprocesses it inside the graph with
        ``ElmoPreprocessorCreator``, creates the tagger, derives the predicted
        label ids (CRF-decoded when ``model.crf is True``), maps them to label
        strings and restores the checkpoint.

        NOTE: ``self.task.config_params["model"]`` is updated in place, and
        ``model.probs`` is replaced by a tensor with an extra leading step.
        ``output_dir``, ``model_version`` and ``use_preproc`` are accepted but
        not read by this method.

        Args:
            sess: Session given to the model and to ``restore_checkpoint``.
            model_file: Model identifier handed to the vocab/label/length
                loaders and to ``restore_checkpoint``.
            embeddings_set: Passed to ``_initialize_embeddings_map``.
            output_dir: Unused here.
            model_version: Unused here.
            use_preproc: Unused here.

        Returns:
            A ``(sig_input, sig_output, "tag_text")`` tuple.
        """
        vocab_indices, vocabs = self._create_vocabs(model_file)

        self.assign_char_lookup()

        labels = self.load_labels(model_file)
        extra_features_required = []

        # Make the TF example, network input. The serialized placeholder is
        # not part of this signature, so only the parsed example is kept.
        _, tf_example = self._create_example(extra_features_required)
        raw_posts = tf_example[FIELD_NAME]

        max_len, max_wlen = self._get_max_lens(model_file)

        preprocessor = ElmoPreprocessorCreator(
            vocab_indices, self.lchars, self.upchars_lut,
            self.task, FIELD_NAME, extra_features_required, max_len, max_wlen
        )

        # One int64 output per vocabulary, plus the mixed-case word feature.
        feature_dtypes = dict.fromkeys(vocab_indices.keys(), tf.int64)
        feature_dtypes['mixed_case_word'] = tf.int64

        # Run the preprocessor for each post.
        features, lengths = tf.map_fn(preprocessor.preproc_post, tf_example,
                                      dtype=(feature_dtypes, tf.int32),
                                      back_prop=False)
        embeddings = self._initialize_embeddings_map(vocabs, embeddings_set)

        model_params = self.task.config_params["model"]
        for param_name, feature_name in (("x", 'mixed_case_word'),
                                         ("xch", 'char'),
                                         ("x_lc", 'word')):
            model_params[param_name] = features[feature_name]

        model_params.update({
            "lengths": lengths,
            "pkeep": 1,
            "sess": sess,
            "maxs": max_len,
            "maxw": max_wlen,
        })
        train_config = self.task.config_params['train']
        model_params['span_type'] = train_config.get('span_type')
        print(model_params)
        model = baseline.tf.tagger.create_model(labels, embeddings, **model_params)
        model.create_loss()

        # Scores and ids are taken from model.probs before the extra leading
        # step is prepended below.
        values, best_ids = tf.nn.top_k(tf.nn.softmax(model.probs), 1)

        # Prepend one step in which '<GO>' scores 0 and every other label
        # scores -1e4.
        go_step = np.full((1, 1, len(labels)), -1e4, dtype=np.float32)
        go_step[:, 0, labels['<GO>']] = 0
        model.probs = tf.concat([tf.constant(go_step), model.probs], 1)

        if model.crf is True:
            # We are assuming the batch size is 1 here: a single sequence
            # length (max_len + 1, counting the prepended step) is supplied.
            decoded, _ = tf.contrib.crf.crf_decode(
                model.probs, model.A, tf.constant([max_len + 1]))
            # Drop the prepended first position from the decoded path.
            best_ids = decoded[:, 1:]

        # Invert the label -> id mapping into a list indexed by id.
        ordered_labels = [''] * len(labels)
        for name, position in labels.items():
            ordered_labels[position] = name

        label_table = tf.contrib.lookup.index_to_string_table_from_tensor(
            tf.constant(ordered_labels))
        classes = label_table.lookup(tf.to_int64(best_ids))
        self.restore_checkpoint(sess, model_file)

        return SignatureInput(None, raw_posts), SignatureOutput(classes, values), "tag_text"