示例#1
0
def create_gends_iterator_from_file(hparams,
                                    entity_vocab_table,
                                    is_eval=False):
    """Build vocab lookup tables and a file-fed dataset iterator for GenDS.

    Args:
        hparams: dict of hyperparameters — vocab file paths, batch_size,
            random_seed, num_buckets, src/tgt_max_len, share_vocab.
        entity_vocab_table: NOTE(review): this parameter is immediately
            shadowed by the table created below and is never read — confirm
            whether callers rely on passing it, or drop it.
        is_eval: when True, shuffling is disabled and a single bucket is used.

    Returns:
        Tuple of (iterator, skip_count_placeholder, five scalar file-path
        placeholders, src/tgt vocab tables, and two reverse lookup tables
        for decoding ids back to strings).
    """
    src_vocab_table, tgt_vocab_table = model_helper.create_vocab_from_file(
        hparams['src_vocab'], hparams['tgt_vocab'], hparams['share_vocab'])
    # Shadows the `entity_vocab_table` parameter (see docstring note).
    entity_vocab_table, _ = model_helper.create_vocab_from_file(
        hparams['entity_path'], hparams['entity_path'], True)
    union_vocab_table, _ = model_helper.create_vocab_from_file(
        hparams['uni_vocab'], hparams['uni_vocab'], True)
    # relative_vocab = src_vocab_table + ENT0_ENT5
    relative_vocab_table, _ = model_helper.create_vocab_from_file(
        hparams['relative_vocab'], hparams['relative_vocab'], True)
    # NOTE(review): the "tgt" reverse table is built from relative_vocab,
    # not tgt_vocab — presumably intentional (decoded ids live in the
    # relative vocab space); verify against the decoding code.
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        hparams['relative_vocab'], default_value=vocab_utils.UNK)
    reverse_uni_vocab_table = lookup_ops.index_to_string_table_from_file(
        hparams['uni_vocab'], default_value=vocab_utils.UNK)

    # Scalar string placeholders holding file paths, fed at session run time;
    # each backs a TextLineDataset below.
    src_file_placeholder = tf.placeholder(dtype=tf.string,
                                          shape=[],
                                          name='src_placeholder')
    entity_file_placeholder = tf.placeholder(dtype=tf.string,
                                             shape=[],
                                             name='src_entity_placeholder')
    tgt_in_file_placeholder = tf.placeholder(dtype=tf.string,
                                             shape=[],
                                             name='tgt_in_placeholder')
    tgt_out_file_placeholder = tf.placeholder(dtype=tf.string,
                                              shape=[],
                                              name='tgt_out_placeholder')
    fact_file_placeholder = tf.placeholder(dtype=tf.string,
                                           shape=[],
                                           name='fact_placeholder')
    src_dataset = tf.data.TextLineDataset(src_file_placeholder)
    ent_dataset = tf.data.TextLineDataset(entity_file_placeholder)
    tgt_in_dataset = tf.data.TextLineDataset(tgt_in_file_placeholder)
    tgt_out_dataset = tf.data.TextLineDataset(tgt_out_file_placeholder)
    fact_dataset = tf.data.TextLineDataset(fact_file_placeholder)
    # Lets training resume mid-epoch by skipping already-consumed lines.
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    # Bucketing by length only helps training; eval uses one bucket.
    num_buckets = hparams['num_buckets'] if not is_eval else 1
    iterator = iterator_utils.get_iterator(src_dataset,
                                           ent_dataset,
                                           tgt_in_dataset,
                                           tgt_out_dataset,
                                           fact_dataset,
                                           src_vocab_table,
                                           entity_vocab_table,
                                           union_vocab_table,
                                           relative_vocab_table,
                                           random_seed=hparams['random_seed'],
                                           num_buckets=num_buckets,
                                           batch_size=hparams['batch_size'],
                                           sos=vocab_utils.SOS,
                                           eos=vocab_utils.EOS,
                                           src_max_len=hparams['src_max_len'],
                                           tgt_max_len=hparams['tgt_max_len'],
                                           shuffle=not is_eval,
                                           skip_count=skip_count_placeholder)

    return iterator, skip_count_placeholder, src_file_placeholder, entity_file_placeholder, tgt_in_file_placeholder, tgt_out_file_placeholder, fact_file_placeholder, src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table, reverse_uni_vocab_table
示例#2
0
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model.

    Builds a fresh graph containing src/tgt vocab tables, feedable text
    placeholders, one infer iterator per side, and the model itself.

    Args:
        model_creator: callable constructing the model object.
        hparams: hyperparameters (vocab files, eos, share_vocab,
            src/tgt_max_len_infer, ...).
        scope: optional container/variable scope name; defaults to "infer".
        extra_args: passed through unchanged to `model_creator`.

    Returns:
        An `InferModel` tuple bundling graph, model, placeholders and both
        iterators.
    """
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    with graph.as_default(), tf.container(scope or "infer"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        # Reverse tables map ids back to token strings when decoding output.
        reverse_src_vocab_table = lookup_ops.index_to_string_table_from_file(
            src_vocab_file, default_value=vocab_utils.UNK)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        # Inference feeds raw text lines directly instead of file paths.
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(
            src_placeholder)
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tgt_placeholder)
        # set real values
        iterator_src = iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)
        # Same helper is reused for the target side, so the length cap is
        # still passed via the `src_max_len` keyword.
        iterator_tgt = iterator_utils.get_infer_iterator(
            tgt_dataset,
            tgt_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.tgt_max_len_infer)
        # NOTE(review): s2s/s2t share iterator_src and t2t/t2s share
        # iterator_tgt — presumably intentional for this dual-direction
        # setup; confirm against the model implementation.
        model = model_creator(
            hparams,
            iterator_s2s=iterator_src,
            iterator_s2t=iterator_src,
            iterator_t2t=iterator_tgt,
            iterator_t2s=iterator_tgt,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            reverse_source_vocab_table=reverse_src_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return InferModel(
        graph=graph,
        model=model,
        src_placeholder=src_placeholder,
        tgt_placeholder=tgt_placeholder,
        batch_size_placeholder=batch_size_placeholder,
        iterator_src=iterator_src,
        iterator_tgt=iterator_tgt)
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create eval graph, model, src/tgt/lbl file holders, and iterator.

    Args:
        model_creator: callable constructing the model object.
        hparams: hyperparameters (vocab files, batch_size, sos/eos, ...).
        scope: optional container/variable scope name; defaults to "eval".
        extra_args: passed through unchanged to `model_creator`.

    Returns:
        An `EvalModel` tuple bundling graph, model, the three scalar
        file-path placeholders and the iterator.
    """
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    lbl_vocab_file = hparams.lbl_vocab_file
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        src_vocab_table, tgt_vocab_table, lbl_vocab_table = \
          vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file,
                                          lbl_vocab_file, hparams.share_vocab)
        # Reverse tables decode predicted token/intent ids back to strings.
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)
        reverse_lbl_vocab_table = lookup_ops.index_to_string_table_from_file(
            lbl_vocab_file, default_value=vocab_utils.UNK)

        # Scalar string placeholders holding the eval file paths.
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        lbl_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
        lbl_dataset = tf.data.TextLineDataset(lbl_file_placeholder)

        # NOTE(review): eval truncates with the *_max_len_infer limits
        # rather than the training limits — confirm this is intended.
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            lbl_dataset,
            src_vocab_table,
            tgt_vocab_table,
            lbl_vocab_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            label_vocab_table=lbl_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            reverse_target_intent_vocab_table=reverse_lbl_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     lbl_file_placeholder=lbl_file_placeholder,
                     iterator=iterator)
示例#4
0
    def __init__(self, hparams, training=True):
        """Set up vocab/emotion lookup tables and, when training, datasets.

        Args:
            hparams: hyperparameters providing src/tgt_max_len and
                unk_id/unk_token.
            training: when True, interactively asks which record file set
                to use and loads train/dev/test TFRecord datasets plus
                emotion class weights; otherwise only the lookup tables
                are built and `case_table` is left as None.
        """
        self.training = training
        self.hparams = hparams

        self.src_max_len = self.hparams.src_max_len
        self.tgt_max_len = self.hparams.tgt_max_len

        # Vocabulary / emotion-label sizes and raw word lists.
        self.vocab_size, self.vocab_list = check_vocab(VOCAB_FILE)
        self.emotion_size, self.emotion_list = check_vocab(EMOTION_FILE)

        # Forward (string -> id) and reverse (id -> string) lookup tables.
        self.vocab_table = lookup_ops.index_table_from_file(
            VOCAB_FILE, default_value=self.hparams.unk_id)
        self.reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            VOCAB_FILE, default_value=self.hparams.unk_token)

        self.emotion_table = lookup_ops.index_table_from_file(
            EMOTION_FILE, default_value=self.hparams.unk_id)
        self.reverse_emotion_table = lookup_ops.index_to_string_table_from_file(
            EMOTION_FILE, default_value=self.hparams.unk_token)

        if self.training:
            print('--------------------------------------------------')
            for index, name in enumerate(RECORD_FILE_NAME_LIST):
                print('= {} - {}'.format(index, name))
            # NOTE(review): blocks on stdin inside the constructor —
            # breaks non-interactive runs; consider passing the index in.
            RECORD_INDEX = int(input("# Input record file index: "))
            print('--------------------------------------------------')

            # Emotion class weights computed from the training split's
            # per-emotion counts (to counter class imbalance — presumably;
            # confirm in get_emotion_weight).
            batch_lists = self.get_file_batch_lists('{}_train.json'.format(
                RECORD_FILE_NAME_LIST[RECORD_INDEX]))
            emotion_num_dict = self.get_emotion_num(batch_lists)
            self.emotion_weight_dict = self.get_emotion_weight(
                emotion_num_dict)

            self.case_table = prepare_case_table()
            self.dev_dataset = self.load_record(
                os.path.join(
                    RECORD_DIR, '{}_dev.tfrecords'.format(
                        RECORD_FILE_NAME_LIST[RECORD_INDEX])))
            self.test_dataset = self.load_record(
                os.path.join(
                    RECORD_DIR, '{}_test.tfrecords'.format(
                        RECORD_FILE_NAME_LIST[RECORD_INDEX])))
            self.train_dataset = self.load_record(
                os.path.join(
                    RECORD_DIR, '{}_train.tfrecords'.format(
                        RECORD_FILE_NAME_LIST[RECORD_INDEX])))
        else:
            self.case_table = None
示例#5
0
    def __init__(self, dataset_dir, hparams=None, training=True):
        """Load vocabularies (and, for inference, test data) from a dataset dir.

        Args:
            dataset_dir: directory containing 'vocab.enc'/'vocab.dec' and,
                for inference, 'test.enc'/'test.dec'.
            hparams: optional hyperparameters; loaded from `dataset_dir`
                via utils.load_hparams when None.
            training: when True no reverse table is built; when False a
                reverse (id -> token) table and the sample test data are
                prepared for decoding.
        """
        if hparams is None:
            self.hparams = utils.load_hparams(dataset_dir)
        else:
            self.hparams = hparams

        self.src_max_len = self.hparams.src_max_len
        self.tgt_max_len = self.hparams.tgt_max_len
        # Populated later by _load_dataset / _convert_to_tokens.
        self.dataset = None
        self.dataset_ids = None

        src_vocab_file = os.path.join(dataset_dir, 'vocab.enc')
        tgt_vocab_file = os.path.join(dataset_dir, 'vocab.dec')

        self.src_vocab_size, _ = check_vocab(src_vocab_file)
        self.tgt_vocab_size, _ = check_vocab(tgt_vocab_file)

        self.src_vocab_table, self.tgt_vocab_table = create_vocab_tables(
            src_vocab_file, tgt_vocab_file, self.hparams.unk_id)
        if training:
            # Reverse lookup only needed when decoding model output.
            self.reverse_vocab_table = None
        else:
            self.reverse_vocab_table = \
                lookup_ops.index_to_string_table_from_file(
                    tgt_vocab_file, default_value=self.hparams.unk_token)
            with open(os.path.join(dataset_dir, 'test.enc'), 'r') as enc, \
                    open(os.path.join(dataset_dir, 'test.dec'), 'r') as dec:
                self.sample_src_data = enc.read().splitlines()
                self.sample_tgt_data = dec.read().splitlines()

        self._load_dataset(dataset_dir)
        self._convert_to_tokens()
示例#6
0
 def __init__(self):
     """Build tree-index lookup tables and embeddings from FLAGS settings."""
     # Presumably the height of a complete binary tree holding x - 1
     # indices (log2 rounded up, plus root) — TODO confirm the intent.
     TreeHeight = lambda x: int(math.log(x - 1) / math.log(2)) + 2
     # Number of entries in the tree index file.
     indexCnt = count_idx(FLAGS.input_previous_model_path + "/" +
                          FLAGS.tree_index_file)
     self.tree_height = TreeHeight(indexCnt + 1)
     # Out-of-vocabulary keys map to indexCnt (one past the last index).
     self.tree_index = lookup_ops.index_table_from_file(
         FLAGS.input_previous_model_path + "/" + FLAGS.tree_index_file,
         default_value=indexCnt)
     self.reverse_tree_index = lookup_ops.index_to_string_table_from_file(
         FLAGS.input_previous_model_path + "/" + FLAGS.tree_index_file,
         default_value='<unk>')
     self.dims = parse_dims(FLAGS.semantic_model_dims)
     # One embedding row per node at the deepest level: 2^(height-1).
     self.layer_embedding = tf.get_variable(
         name='tree_node_emb',
         shape=[pow(2, self.tree_height - 1), self.dims[-1]])
     if not FLAGS.leaf_content_emb:
         # Separate leaf embeddings when leaves do not embed their content.
         self.leaf_embedding = tf.get_variable(
             name='leaf_node_emb',
             shape=[pow(2, self.tree_height - 1), self.dims[-1]])
     # Tokenizer backend selection: 1 = mstf native op, -1 = python
     # xletter preprocessor, anything else = none.
     if FLAGS.use_mstf_ops == 1:
         self.op_dict = mstf.dssm_dict(FLAGS.xletter_dict)
     elif FLAGS.use_mstf_ops == -1:
         self.op_dict = XletterPreprocessor(FLAGS.xletter_dict,
                                            FLAGS.xletter_win_size)
     else:
         self.op_dict = None
示例#7
0
 def assets_module_fn():
     """Hub module_fn: map int64 indices to vocab strings via a file asset."""
     # Module input: arbitrary-shape int64 indices.
     idx_input = tf.compat.v1.placeholder(dtype=tf.int64, name="indices")
     # Reverse vocab table backed by the vocabulary file asset; unknown
     # indices resolve to "UNKNOWN".
     id_to_string = index_to_string_table_from_file(
         vocabulary_file=vocab_filename, default_value="UNKNOWN")
     looked_up = id_to_string.lookup(idx_input)
     hub.add_signature(inputs=idx_input, outputs=looked_up)
示例#8
0
    def build_graph_dist_strategy(self, features, labels, mode, params):
        """Model function.

        Builds a GNMT model for the given estimator `mode` and returns a
        6-tuple (loss, vars, grads, predictions, train_op, scaffold) with
        the slots that do not apply to the mode set to None.

        Args:
            features: input features passed to the model.
            labels: unused (deleted immediately).
            mode: tf.contrib.learn.ModeKeys — only INFER and TRAIN are
                supported; anything else raises ValueError.
            params: unused (deleted immediately).

        Raises:
            ValueError: if `mode` is neither INFER nor TRAIN.
        """
        del labels, params
        misc_utils.print_out("Running dist_strategy mode_fn")

        hparams = self.hparams

        # Create a GNMT model for training.
        # assert (hparams.encoder_type == "gnmt" or
        #        hparams.attention_architecture in ["gnmt", "gnmt_v2"])
        with mixed_precision_scope():
            model = gnmt_model.GNMTModel(hparams, mode=mode, features=features)
            if mode == tf.contrib.learn.ModeKeys.INFER:
                sample_ids = model.sample_id
                # Decode predicted ids back to token strings.
                reverse_target_vocab_table = lookup_ops.index_to_string_table_from_file(
                    hparams.tgt_vocab_file, default_value=vocab_utils.UNK)
                sample_words = reverse_target_vocab_table.lookup(
                    tf.to_int64(sample_ids))
                # make sure outputs is of shape [batch_size, time] or [beam_width,
                # batch_size, time] when using beam search.
                if hparams.time_major:
                    sample_words = tf.transpose(sample_words)
                elif sample_words.shape.ndims == 3:
                    # beam search output in [batch_size, time, beam_width] shape.
                    sample_words = tf.transpose(sample_words, [2, 0, 1])
                predictions = {"predictions": sample_words}
                # return loss, vars, grads, predictions, train_op, scaffold
                return None, None, None, predictions, None, None
            elif mode == tf.contrib.learn.ModeKeys.TRAIN:
                loss = model.train_loss
                train_op = model.update
                return loss, model.params, model.grads, None, train_op, None
            else:
                raise ValueError("Unknown mode in model_fn: %s" % mode)
def create_infer_graph(scope=None):
    """Build an inference graph and model fed from in-memory placeholders.

    NOTE(review): relies on module-level names `hparams`, `model`,
    `vocab_utils`, `iterator_utils` and `InferGraph` — confirm `hparams`
    is initialized before this is called, or thread it in as a parameter.

    Args:
        scope: optional variable scope name forwarded to the model.

    Returns:
        An `InferGraph` tuple bundling graph, model, the source text and
        batch-size placeholders, and the iterator.
    """
    graph = tf.Graph()
    with graph.as_default():
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(hparams.src_vocab_file,
                                                                           hparams.tgt_vocab_file,
                                                                           hparams.share_vocab)
        # Reverse table decodes predicted ids back to target tokens.
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(hparams.tgt_vocab_file,
                                                                             default_value=vocab_utils.UNK)

        # Raw source lines are fed directly at inference time.
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
        iterator = iterator_utils.get_infer_iterator(src_dataset,
                                                     src_vocab_table,
                                                     batch_size=batch_size_placeholder,
                                                     eos=hparams.eos,
                                                     src_max_len=hparams.src_max_len_infer)
        final_model = model.Model(hparams,
                                  iterator=iterator,
                                  mode=tf.contrib.learn.ModeKeys.INFER,
                                  source_vocab_table=src_vocab_table,
                                  target_vocab_table=tgt_vocab_table,
                                  reverse_target_vocab_table=reverse_tgt_vocab_table,
                                  scope=scope)
    return InferGraph(graph=graph,
                      model=final_model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)
示例#10
0
def create_serve_model(model_creator, hparams, scope=None, extra_args=None):
    """Create a serving-mode inference model fed from a single placeholder.

    Unlike the batch infer path, batch size is a constant 1 tensor: the
    serving endpoint processes one request at a time.

    Args:
        model_creator: callable constructing the model object.
        hparams: hyperparameters (vocab files, eos, src_max_len_infer, ...).
        scope: optional container/variable scope name; defaults to "serve".
        extra_args: passed through unchanged to `model_creator`.

    Returns:
        An `InferModel` tuple bundling graph, model, placeholders and the
        iterator.
    """
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    with graph.as_default(), tf.container(scope or "serve"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        # src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        # No fixed shape: the serving layer feeds the raw request text.
        src_placeholder = tf.placeholder(dtype=tf.string,
                                         name="src_placeholder")
        # Constant, not a placeholder: serving handles one example at a time.
        batch_size_placeholder = tf.constant(1, tf.int64)

        iterator = pre_process_generative_model(
            src_placeholder,
            src_vocab_table,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)
示例#11
0
def create_infer_model(hparams):
    """Create graph, NMT model and iterator for inference.

    Args:
        hparams: hyperparameters providing src/tgt vocab files and
            src_max_len_infer.

    Returns:
        An `InferModel` tuple bundling graph, model, the source text and
        batch-size placeholders, and the iterator.
    """
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()
    with graph.as_default(), tf.container('infer'):
        # Forward tables map tokens to ids; unknown tokens map to UNK_ID.
        src_vocab_table = lookup_ops.index_table_from_file(
            src_vocab_file, default_value=UNK_ID)
        tgt_vocab_table = lookup_ops.index_table_from_file(
            tgt_vocab_file, default_value=UNK_ID)
        # Reverse table decodes predicted target ids back to tokens.
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=UNK)
        # Raw source lines are fed directly at inference time.
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
        iterator = get_infer_iterator(src_dataset,
                                      src_vocab_table,
                                      batch_size_placeholder,
                                      EOS,
                                      src_max_len=hparams.src_max_len_infer)
        model = NMTModel(hparams, 'infer', iterator, src_vocab_table,
                         tgt_vocab_table, reverse_tgt_vocab_table)
        return InferModel(graph=graph,
                          model=model,
                          src_placeholder=src_placeholder,
                          batch_size_placeholder=batch_size_placeholder,
                          iterator=iterator)
示例#12
0
    def predict(self, sess, hparams):
        """Decode tag sequences for the prediction file and print them.

        Args:
            sess: active tf.Session containing the model graph.
            hparams: hyperparameters providing `pred` (input text file),
                `tgt_vocab_file` (tag vocabulary) and `out_dir_model`
                (checkpoint directory).

        Prints each word with its predicted tag until the input iterator
        is exhausted.
        """
        # Iterator over the raw input text, one line per example.
        file_iter = file_content_iterator(hparams.pred)
        tag_table = lookup_ops.index_to_string_table_from_file(
            hparams.tgt_vocab_file, default_value='<tag-unknown>')
        # Build the id -> tag lookup op ONCE, before the loop. The previous
        # version created a fresh tf.constant + lookup op for every single
        # id inside the loop, which grows the graph without bound and makes
        # each successive sess.run slower.
        tag_ids_ph = tf.placeholder(dtype=tf.int64, shape=[None])
        tag_lookup_op = tag_table.lookup(tag_ids_ph)
        self.create_or_load(sess, hparams.out_dir_model)
        while True:
            try:
                tf_viterbi_sequence = sess.run(self.viterbi_sequence)[0]

            except tf.errors.OutOfRangeError:
                print('Prediction finished!')
                break

            # Map the whole id sequence to tag strings in a single run call.
            tags = list(
                sess.run(tag_lookup_op,
                         feed_dict={tag_ids_ph: tf_viterbi_sequence}))
            raw_content = next(file_iter)
            words = raw_content.split(' ')
            assert len(words) == len(tags)
            for w, t in zip(words, tags):
                print(w, '(' + t.decode("UTF-8") + ')')
            print()
            print('*' * 100)
示例#13
0
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model.

    Builds a fresh graph with the target vocab table, the amino-acid
    weight table, a source text placeholder and the infer iterator, then
    constructs the model.

    Args:
        model_creator: callable constructing the model object.
        hparams: hyperparameters (tgt vocab file, ...).
        scope: optional container/variable scope name; defaults to "infer".
        extra_args: passed through unchanged to `model_creator`.

    Returns:
        An `InferModel` tuple bundling graph, model, placeholders and the
        iterator.
    """
    graph = tf.Graph()
    tgt_vocab_file = hparams.tgt_vocab_file

    with graph.as_default(), tf.container(scope or "infer"):
        tgt_vocab_table = vocab_utils.create_vocab_tables(tgt_vocab_file)
        aa_weight_table = input_config.create_aa_tables()
        # Reverse table decodes predicted ids back to vocab strings.
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        # Raw source lines are fed directly at inference time.
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            tgt_vocab_table,
            batch_size=batch_size_placeholder,
            aa_weight_table=aa_weight_table,
        )
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            aa_weight_table=aa_weight_table,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)
示例#14
0
    def __init__(self,
                 hparams,
                 tokenizer=None,
                 training=True,
                 mode='inference'):
        """Set up tokenizers, lookup tables and (when training) the records.

        Args:
            hparams: hyperparameters providing unk_id/unk_token.
            tokenizer: optional pre-built tokenizer. NOTE(review): when
                given, the SAME object is reused for both the word vocab
                and the emotion vocab — confirm that is intended.
            training: when True, the train/test TFRecord datasets are
                loaded under the "load_record" name scope.
            mode: 'ddpg' selects the plain daily_* records; any other value
                selects the daily_mtem_* records.
        """
        self.training = training
        self.hparams = hparams

        self.tokenizer = tokenizer if tokenizer else Tokenizer(
            self.hparams, VOCAB_FILE)
        self.vocab_size, self.vocab_dict = len(
            self.tokenizer.vocab), self.tokenizer.vocab

        self.emotion_tokenizer = tokenizer if tokenizer else Tokenizer(
            self.hparams, EMOTION_FILE)
        self.emotion_size, self.emotion_list = len(
            self.emotion_tokenizer.vocab), self.emotion_tokenizer.inv_vocab

        with tf.name_scope("data_process"):

            # Forward (string -> id) and reverse (id -> string) tables for
            # both the word vocabulary and the emotion labels.
            self.vocab_table = lookup_ops.index_table_from_file(
                VOCAB_FILE, default_value=self.hparams.unk_id)
            self.reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
                VOCAB_FILE, default_value=self.hparams.unk_token)

            self.emotion_table = lookup_ops.index_table_from_file(
                EMOTION_FILE, default_value=self.hparams.unk_id)
            self.reverse_emotion_table = lookup_ops.index_to_string_table_from_file(
                EMOTION_FILE, default_value=self.hparams.unk_token)

            # Ids of canned "dull" responses (used elsewhere, e.g. as a
            # reward penalty — presumably; see get_dull_response).
            self.dull_response_id = self.get_dull_response(DULL_RESPONSE)

        if self.training:
            with tf.name_scope("load_record"):
                if mode == 'ddpg':
                    train_file = 'daily_train.tfrecords'
                    test_file = 'daily_test.tfrecords'
                else:
                    train_file = 'daily_mtem_train.tfrecords'
                    test_file = 'daily_mtem_test.tfrecords'
#                    train_file = 'friends_train.tfrecords'
#                    test_file = 'friends_test.tfrecords'
                self.train_dataset_count, self.train_dataset = self.load_record(
                    os.path.join(RECORD_DIR, train_file), ELEMENT_LIST)
                self.test_dataset_count, self.test_dataset = self.load_record(
                    os.path.join(RECORD_DIR, test_file), ELEMENT_LIST)
示例#15
0
File: driver.py  Project: piBrain/aura-ml
 def prepare_predict(self, sample_id, beam_width):
     """Build a PREDICT-mode EstimatorSpec mapping sample ids to tokens."""
     # Reverse vocabulary: token id -> token string (unknowns become UNK).
     id_to_token = lookup_ops.index_to_string_table_from_file(
         self.vocab_path, default_value=UNK)
     decoded_tokens = id_to_token.lookup(tf.to_int64(sample_id))
     return tf.estimator.EstimatorSpec(
         mode=tf.estimator.ModeKeys.PREDICT,
         predictions=decoded_tokens,
     )
示例#16
0
def create_train_model(hparams):
    # get src/tgt vocabulary table
    """Build the training graph: vocab tables, iterator, NMT model, saver.

    Args:
        hparams: hyperparameters providing vocab files, train_prefix,
            src/tgt suffixes, batch_size, sos/eos, source_reverse and
            random_seed.

    Returns:
        A `TrainModel` tuple bundling graph, model, encoder, decoder,
        iterator and a Saver over all global variables.
    """
    graph = tf.Graph()
    with graph.as_default() as graph:
        src_vocab_table, tgt_vocab_table = vocab_table_util.get_vocab_table(
            hparams.src_vocab_file, hparams.tgt_vocab_file)
        # Reverse tables decode ids back to tokens on both sides.
        reversed_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.tgt_vocab_file, default_value=vocab_table_util.UNK)
        reversed_src_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.src_vocab_file, default_value=vocab_table_util.UNK)
        with tf.variable_scope("NMTModel",
                               initializer=tf.truncated_normal_initializer(
                                   stddev=0.01)) as nmtmodel_scope:
            with tf.variable_scope("train_iterator"):
                # Training files are "<prefix>.<src>" / "<prefix>.<tgt>".
                src_dataset_file = "%s.%s" % (hparams.train_prefix,
                                              hparams.src)
                tgt_dataset_file = "%s.%s" % (hparams.train_prefix,
                                              hparams.tgt)
                iterator = iterator_utils.get_nmt_iterator(
                    src_dataset_file, tgt_dataset_file, src_vocab_table,
                    tgt_vocab_table, hparams.batch_size, hparams.eos,
                    hparams.sos, hparams.source_reverse, hparams.random_seed)
            # "shared_" scopes: encoder/decoder variables are reused across
            # model instances built under the same names — presumably for
            # eval/infer graphs; confirm against the rest of the project.
            with tf.variable_scope("shared_encoder") as encoder_scope:
                encoder = en.Encoder(hparams,
                                     tf.contrib.learn.ModeKeys.TRAIN,
                                     dtype=tf.float32,
                                     scope=encoder_scope)
            with tf.variable_scope("shared_decoder") as decoder_scope:
                decoder = de.Decoder(hparams,
                                     tf.contrib.learn.ModeKeys.TRAIN,
                                     dtype=tf.float32,
                                     scope=decoder_scope)
            nmt_model = mdl.NMTModel(hparams, src_vocab_table, tgt_vocab_table,
                                     encoder, decoder, iterator,
                                     tf.contrib.learn.ModeKeys.TRAIN,
                                     reversed_tgt_vocab_table,
                                     reversed_src_vocab_table)
        saver = tf.train.Saver(tf.global_variables())
    return TrainModel(graph=graph,
                      model=nmt_model,
                      encoder=encoder,
                      decoder=decoder,
                      iterator=iterator,
                      saver=saver)
示例#17
0
 def _build_predictions(params, predict_ids):
     """Attach human-readable labels to predicted ids for the output dict."""
     # Reverse vocab table over the target label file; unknown ids decode
     # to "<blank>".
     id_to_label = lookup_ops.index_to_string_table_from_file(
         vocabulary_file=params["target_vocab"], default_value="<blank>")
     predict_labels = id_to_label.lookup(keys=tf.cast(predict_ids, tf.int64))
     return {
         "predict_ids": predict_ids,
         "predict_labels": predict_labels,
     }
示例#18
0
def create_selfplay_model(model_creator,
                          is_mutable,
                          num_workers,
                          jobid,
                          hparams,
                          scope=None,
                          extra_args=None):
    """Create self-play models.

    Args:
        model_creator: callable constructing the model object.
        is_mutable: selects the self-play mode — True picks the mutable
            mode, False the immutable one.
        num_workers: number of data-parallel workers (dataset sharding).
        jobid: this worker's shard index.
        hparams: hyperparameters (vocab file, ...).
        scope: optional container/variable scope; defaults to "selfplay".
        extra_args: optional extras; `model_device_fn` is honored when set.

    Returns:
        A `SelfplayModel` tuple bundling graph, model, the handle-based
        placeholder iterator, the three concrete iterators and all feed
        placeholders.
    """
    graph = tf.Graph()
    with graph.as_default(), tf.container(scope or "selfplay"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        # Index into the two-element mode list below: 0 = mutable.
        if is_mutable:
            mutable_index = 0
        else:
            mutable_index = 1

        # get a list of iterators and placeholders
        iterators, placeholders = self_play_iterator_creator(
            hparams, num_workers, jobid)
        train_iterator, self_play_fulltext_iterator, self_play_structured_iterator = iterators
        data_placeholder, kb_placeholder, batch_size_placeholder, skip_count_placeholder = placeholders

        # get an iterator handler
        # A string-handle iterator lets the session switch between the
        # concrete iterators above by feeding a different handle.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_iterator.output_types, train_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(iterator)

        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(hparams,
                                  iterator=batched_iterator,
                                  handle=handle,
                                  mode=[
                                      dialogue_utils.mode_self_play_mutable,
                                      dialogue_utils.mode_self_play_immutable
                                  ][mutable_index],
                                  vocab_table=vocab_table,
                                  reverse_vocab_table=reverse_vocab_table,
                                  scope=scope,
                                  extra_args=extra_args)
    return SelfplayModel(graph=graph,
                         model=model,
                         placeholder_iterator=iterator,
                         placeholder_handle=handle,
                         train_iterator=train_iterator,
                         self_play_ft_iterator=self_play_fulltext_iterator,
                         self_play_st_iterator=self_play_structured_iterator,
                         data_placeholder=data_placeholder,
                         kb_placeholder=kb_placeholder,
                         skip_count_placeholder=skip_count_placeholder,
                         batch_size_placeholder=batch_size_placeholder)
示例#19
0
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
    """Build the training graph, model, and input iterators.

    Args:
        model_creator: Callable that constructs the model object.
        hparams: Hyperparameters (vocab_file, train_data, train_kb, batch
            sizes, bucketing options, ...).
        scope: Optional container/variable scope name.
        num_workers: Number of shards the training data is split across.
        jobid: Index of this worker's shard.
        extra_args: Optional extras; may carry a `model_device_fn`.

    Returns:
        A TrainModel tuple bundling the graph, model, iterators and
        placeholders.
    """
    graph = tf.Graph()
    with graph.as_default(), tf.container(scope or "train"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)
        dialogue_dataset = tf.data.TextLineDataset(hparams.train_data)
        knowledge_dataset = tf.data.TextLineDataset(hparams.train_kb)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        # Concrete training iterator over the (dialogue, kb) line datasets.
        train_iterator = iterator_utils.get_iterator(
            dialogue_dataset,
            knowledge_dataset,
            vocab_table,
            batch_size=hparams.batch_size,
            t1=hparams.t1,
            t2=hparams.t2,
            eod=hparams.eod,
            len_action=hparams.len_action,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            max_dialogue_len=hparams.max_dialogue_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # Feedable iterator: a string handle lets the caller switch between
        # training and evaluation inputs at session-run time.
        handle = tf.placeholder(tf.string, shape=[])
        feedable_iterator = tf.data.Iterator.from_string_handle(
            handle, train_iterator.output_types, train_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(
            feedable_iterator)

        device_fn = extra_args.model_device_fn if extra_args else None
        with tf.device(device_fn):
            model = model_creator(hparams,
                                  iterator=batched_iterator,
                                  handle=handle,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  vocab_table=vocab_table,
                                  scope=scope,
                                  extra_args=extra_args,
                                  reverse_vocab_table=reverse_vocab_table)
    return TrainModel(graph=graph,
                      model=model,
                      placeholder_iterator=feedable_iterator,
                      train_iterator=train_iterator,
                      placeholder_handle=handle,
                      skip_count_placeholder=skip_count_placeholder)
示例#20
0
 def build_predictions(self, predict_ids, params):
     """Map predicted tag ids to tag strings via the tag vocab file.

     Args:
         predict_ids: Tensor of integer tag indices.
         params: Dict with 'tag_vocab' (vocab file path) and 'oov_tag'
             (string used for out-of-vocabulary indices).

     Returns:
         Dict holding both the raw ids and the looked-up tag strings.
     """
     id_to_tag = lookup_ops.index_to_string_table_from_file(
         params['tag_vocab'], default_value=params['oov_tag'])
     tag_strings = id_to_tag.lookup(tf.cast(predict_ids, tf.int64))
     return {
         "predict_ids": predict_ids,
         "predict_tags": tag_strings,
     }
    def __init__(self,
                 corpus_dir,
                 hparams=None,
                 knbase_dir=None,
                 training=True,
                 augment_factor=3,
                 buffer_size=8192):
        """Set up vocab tables and (for training) the tokenized corpus.

        Args:
            corpus_dir: Folder holding the corpus files used for training.
            hparams: Pre-loaded hyperparameter object; built from
                corpus_dir when None.
            knbase_dir: Folder holding knowledge-base data files; required
                (and used) only for inference.
            training: Whether this instance feeds training rather than
                inference.
            augment_factor: Times the training data is repeated; 1 or less
                disables augmentation.
            buffer_size: Buffer size used by the dataset mapping step.
        """
        self.hparams = (HParams(corpus_dir).hparams
                        if hparams is None else hparams)

        self.src_max_len = self.hparams.src_max_len
        self.tgt_max_len = self.hparams.tgt_max_len

        self.training = training
        self.text_set = None
        self.id_set = None

        vocab_path = os.path.join(corpus_dir, VOCAB_FILE)
        self.vocab_size, _ = check_vocab(vocab_path)
        self.vocab_table = lookup_ops.index_table_from_file(
            vocab_path, default_value=self.hparams.unk_id)

        if training:
            # Training needs only the word -> id direction; the knowledge
            # base containers stay empty.
            self.case_table = prepare_case_table()
            self.reverse_vocab_table = None
            self._load_corpus(corpus_dir, augment_factor)
            self._convert_to_tokens(buffer_size)

            self.upper_words = {}
            self.stories = {}
            self.jokes = []
        else:
            # Inference decodes ids back to words and pulls in the
            # knowledge base.
            self.case_table = None
            self.reverse_vocab_table = \
                lookup_ops.index_to_string_table_from_file(
                    vocab_path, default_value=self.hparams.unk_token)
            assert knbase_dir is not None
            kb = KnowledgeBase()
            kb.load_knbase(knbase_dir)
            self.upper_words = kb.upper_words
            self.stories = kb.stories
            self.jokes = kb.jokes
示例#22
0
def create_reverse_vocab_tables(vocab_file):
    """Create one id->string vocab table per output timestep.

    Per-timestep vocab files are derived from ``vocab_file`` by appending
    the timestep index to the stem: ``vocab.txt`` -> ``vocab0.txt``,
    ``vocab1.txt``, ...

    Args:
        vocab_file: Template vocab file name; must contain an extension.

    Returns:
        List of NUM_OUTPUTS_PER_TIMESTEP reverse (index -> string) lookup
        tables, one per timestep.
    """
    # Split on the LAST dot so names such as "./my.vocab.txt" keep their
    # full stem (the old split('.')[0] truncated at the first dot and
    # produced an empty stem for relative "./..." paths).
    stem, ext = vocab_file.rsplit('.', 1)
    vocab_tables = []
    for i in range(NUM_OUTPUTS_PER_TIMESTEP):
        vocab_fname = "%s%d.%s" % (stem, i, ext)
        vocab_tables.append(
            lookup_ops.index_to_string_table_from_file(
                vocab_fname, default_value=UNK))
    return vocab_tables
示例#23
0
  def test_index_to_string_table_with_vocab_size(self):
    """A table truncated by vocab_size maps out-of-range ids to "UNK"."""
    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
    with self.test_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=3)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))

      # Looking up before the table is initialized must fail.
      self.assertRaises(errors_impl.OpError, features.eval)
      lookup_ops.tables_initializer().run()
      # Index 4 lies outside vocab_size=3 and falls back to the default.
      self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
示例#24
0
def _convert_ids_to_strings(tgt_vocab_file, ids):
    """Decode integer prediction ids into target-vocabulary words.

    Args:
        tgt_vocab_file: Path to the target vocab file.
        ids: Array-like of integer token ids.

    Returns:
        The looked-up word strings, as produced by session.run.
    """
    id_tensor = tf.to_int64(tf.convert_to_tensor(np.asarray(ids)))
    with tf.Session() as sess:
        reverse_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)
        sess.run(tf.tables_initializer())
        translations = sess.run(reverse_table.lookup(id_tensor))
    return translations
示例#25
0
 def __init__(self, vocab_f_path, count_f_path=None):
     """Create word<->id lookup tables from a vocabulary file.

     Args:
         vocab_f_path: Path to the vocabulary file (one token per line,
             judging by the lookup-table constructors below).
         count_f_path: Optional path to a token-count file; only stored
             here. NOTE(review): presumably consumed lazily elsewhere —
             not visible in this block.
     """
     self._vocab_f_path = vocab_f_path
     self._count_f_path = count_f_path
     # NOTE(review): initialized to None; appears to be populated lazily
     # outside this block — confirm.
     self._vocab_size = None
     self._vocab_counts = None
     with tf.variable_scope("vocab_lookup"):
         # id -> word and word -> id tables over the same vocab file;
         # unknowns map to UNK / UNK_ID respectively.
         self._id2word_table = lookup_ops.index_to_string_table_from_file(
             self._vocab_f_path, default_value=UNK, name="id2word")
         self._word2id_table = lookup_ops.index_table_from_file(
             self._vocab_f_path, default_value=UNK_ID, name="word2id")
示例#26
0
def create_infer_model(hparams):
    """Build the NMT inference graph: vocab tables, iterator, encoder,
    decoder, model, and a saver over all global variables.

    Args:
        hparams: Hyperparameters (vocab files, dev_prefix, src, batch_size,
            source_reverse, eos, ...).

    Returns:
        An InferModel tuple with the graph and its components.
    """
    graph = tf.Graph()
    with graph.as_default() as graph:
        src_vocab_table, tgt_vocab_table = vocab_table_util.get_vocab_table(
            hparams.src_vocab_file, hparams.tgt_vocab_file)
        # id -> word tables for decoding model output back to text.
        rev_tgt_table = lookup_ops.index_to_string_table_from_file(
            hparams.tgt_vocab_file, default_value=vocab_table_util.UNK)
        rev_src_table = lookup_ops.index_to_string_table_from_file(
            hparams.src_vocab_file, default_value=vocab_table_util.UNK)

        with tf.variable_scope("NMTModel") as nmtmodel_scope:
            with tf.variable_scope("infer_iterator"):
                # Inference reads the dev-set source side.
                src_dataset_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
                iterator = iterator_utils.get_nmt_infer_iterator(
                    src_dataset_file, src_vocab_table, hparams.batch_size,
                    hparams.source_reverse, hparams.eos)
            with tf.variable_scope("shared_encoder") as encoder_scope:
                encoder = en.Encoder(hparams,
                                     tf.contrib.learn.ModeKeys.INFER,
                                     dtype=tf.float32,
                                     scope=encoder_scope)
            with tf.variable_scope("shared_decoder") as decoder_scope:
                decoder = de.Decoder(hparams,
                                     tf.contrib.learn.ModeKeys.INFER,
                                     dtype=tf.float32,
                                     scope=decoder_scope)
            nmt_model = mdl.NMTModel(
                hparams,
                src_vocab_table,
                tgt_vocab_table,
                encoder,
                decoder,
                iterator,
                tf.contrib.learn.ModeKeys.INFER,
                reversed_tgt_vocab_table=rev_tgt_table,
                reversed_src_vocab_table=rev_src_table)
        saver = tf.train.Saver(tf.global_variables())
    return InferModel(graph=graph,
                      model=nmt_model,
                      encoder=encoder,
                      decoder=decoder,
                      iterator=iterator,
                      saver=saver)
示例#27
0
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Build the inference graph, model, and placeholder-fed input pipeline.

    Args:
        model_creator: Callable that constructs the model object.
        hparams: Hyperparameters (vocab_file, eod, len_action, ...).
        scope: Optional container/variable scope name.
        extra_args: Optional extra model arguments.

    Returns:
        An InferModel tuple with the graph, model, iterators, and the
        placeholders used to feed source text, kb text, and batch size.
    """
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "infer"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        # Inference input arrives through placeholders rather than files.
        data_src_placeholder = tf.placeholder(shape=[None],
                                              dtype=tf.string,
                                              name="src_ph")
        kb_placeholder = tf.placeholder(shape=[None],
                                        dtype=tf.string,
                                        name="kb_ph")
        batch_size_placeholder = tf.placeholder(shape=[],
                                                dtype=tf.int64,
                                                name="bs_ph")

        src_dataset = tf.data.Dataset.from_tensor_slices(data_src_placeholder)
        kb_dataset = tf.data.Dataset.from_tensor_slices(kb_placeholder)

        # Concrete inference iterator over the fed tensors.
        infer_iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            kb_dataset,
            vocab_table,
            batch_size=batch_size_placeholder,
            eod=hparams.eod,
            len_action=hparams.len_action)

        # Feedable iterator switched at run time via a string handle.
        handle = tf.placeholder(tf.string, shape=[])
        feedable_iterator = tf.data.Iterator.from_string_handle(
            handle, infer_iterator.output_types, infer_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(
            feedable_iterator)

        model = model_creator(hparams,
                              iterator=batched_iterator,
                              handle=handle,
                              mode=tf.contrib.learn.ModeKeys.INFER,
                              vocab_table=vocab_table,
                              reverse_vocab_table=reverse_vocab_table,
                              scope=scope,
                              extra_args=extra_args)

    return InferModel(graph=graph,
                      model=model,
                      placeholder_iterator=feedable_iterator,
                      placeholder_handle=handle,
                      infer_iterator=infer_iterator,
                      data_src_placeholder=data_src_placeholder,
                      kb_placeholder=kb_placeholder,
                      batch_size_placeholder=batch_size_placeholder)
示例#28
0
    def _build_graph(self, hparams):
        """Build a bidirectional-RNN classifier over the input sequence.

        In TRAIN mode adds a softmax cross-entropy loss, SGD train op,
        accuracy, and summaries; otherwise adds softmax probabilities, an
        argmax class id, and its string tag via an id->string vocab table.
        """
        source = self.input.source
        tgt = self.input.target
        # x: [batch_size, time_step, embedding_size], float32
        self.x = tf.nn.embedding_lookup(self.embedding, source)
        # y: [batch_size, time_step]
        # NOTE(review): the loss/argmax below treat outputs as one logit
        # row per example, which suggests y is one label per example —
        # confirm the original shape comment above.
        self.y = tgt

        cell_forward = self._single_cell(hparams)
        cell_backward = self._single_cell(hparams)

        # Run forward/backward cells over the inputs. (Original note,
        # translated: time_major can be used to adapt input dimensions.)
        outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
            cell_forward,
            cell_backward,
            self.x,
            sequence_length=self.input.source_sequence_length,
            dtype=tf.float32)
        forward_state, backward_state = bi_state
        # Concatenate the final hidden states of both directions:
        # [batch_size, 2 * num_units].
        bi_state = tf.concat([forward_state.h, backward_state.h], axis=1)

        # projection: map the concatenated state onto class logits.
        w = tf.get_variable("projection_w",
                            [2 * hparams.num_units, hparams.class_size])
        b = tf.get_variable("projection_b", [hparams.class_size])
        # x_reshape = tf.reshape(outputs, [-1, 2 * hparams.num_units], name="outputs_x_reshape")
        self.outputs = tf.matmul(bi_state, w) + b

        num_tags = hparams.class_size
        # NOTE(review): transition_params is created but never used in this
        # block — presumably for CRF decoding elsewhere; confirm.
        self.transition_params = tf.get_variable("transitions",
                                                 [num_tags, num_tags])
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.outputs, labels=self.y))
            # Add a training op to tune the parameters.
            learning_rate = hparams.learning_rate
            global_step = self.global_step
            self.train_op = tf.train.GradientDescentOptimizer(
                learning_rate).minimize(self.loss, global_step=global_step)

            correct_prediction = tf.equal(tf.cast(self.y, dtype=tf.int64),
                                          tf.argmax(self.outputs, 1))
            self.accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
            self.train_summary = tf.summary.merge([
                tf.summary.scalar("accuracy", self.accuracy),
                tf.summary.scalar("train_loss", self.loss),
            ])
        else:
            # Inference: softmax probabilities, argmax class id, and the
            # human-readable tag via the id->string vocab table.
            self.projection = tf.nn.softmax(self.outputs)
            self.projection_id = tf.argmax(self.outputs, axis=1)
            tag_table = lookup_ops.index_to_string_table_from_file(
                hparams.tgt_vocab_file, default_value='<unk>')
            self.tag = tag_table.lookup(self.projection_id)
示例#29
0
    def __init__(self, hparams=None, training=True, buffer_size=8192):
        """Load vocab tables, the Word2Vec embedding model, and (when
        training) the tokenized train/test corpora.

        Args:
            hparams: Hyperparameter object (src/tgt max lengths, unk ids).
            training: Whether to load and tokenize the corpora.
            buffer_size: Buffer size used during dataset token mapping.
        """
        self.training = training
        self.hparams = hparams

        self.src_max_len = self.hparams.src_max_len
        self.tgt_max_len = self.hparams.tgt_max_len

        self.vocab_size, self.vocab_list = check_vocab(VOCAB_FILE)
        self.vocab_table = lookup_ops.index_table_from_file(
            VOCAB_FILE, default_value=self.hparams.unk_id)

        # Load a cached Word2Vec model if present; otherwise build and
        # persist one from the corpus.
        if os.path.isfile(MODEL_FILE):
            print("# Load Word2Vec model")
            self.embedding_model = Word2Vec.load(MODEL_FILE)
        else:
            print("# Word2Vec model doesn't exist")
            self.embedding_model = self.create_embedding(
                CORPUS_DIR, self.hparams.embedding_size)
            print("# Save Word2Vec model")
            self.embedding_model.save(MODEL_FILE)

        # Both modes build the id -> token direction for decoding.
        self.reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            VOCAB_FILE, default_value=self.hparams.unk_token)

        if training:
            self.case_table = prepare_case_table()
            train_text_set = self._load_corpus(TRAIN_FILE)
            test_text_set = self._load_corpus(TEST_FILE)
            self.train_id_set = self._convert_to_tokens(
                buffer_size, train_text_set)
            self.test_id_set = self._convert_to_tokens(
                buffer_size, test_text_set)
        else:
            self.case_table = None
示例#30
0
  def test_index_to_string_table_with_vocab_size_too_large(self):
    """A vocab_size exceeding the file's line count fails at init time."""
    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
    with self.test_session():
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=4)
      features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))

      # Looking up before the table is initialized must fail.
      self.assertRaises(errors_impl.OpError, features.eval)
      init = lookup_ops.tables_initializer()
      # Initialization itself rejects the oversized vocab_size.
      self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                              "Invalid vocab_size", init.run)
示例#31
0
File: heads.py  Project: jgung/tf-nlp
 def _prediction(self):
     """
     Called after `_eval_predict` for prediction.

     Maps integer predictions to label strings using the vocab file named
     after this head, then registers them as the export output.
     """
     label_table = index_to_string_table_from_file(
         vocabulary_file=os.path.join(self.params.vocab_path, self.name),
         default_value=self.extractor.unknown_word)
     label_strings = label_table.lookup(
         tf.cast(self.predictions, dtype=tf.int64))
     self.predictions = tf.identity(label_strings, name="labels")
     self.export_outputs = {self.name: self.predictions}
示例#32
0
    def _get_reverse_vocab_table(self):
        """Build id -> token lookup tables for the source and target vocabs.

        Returns:
            Tuple (reverse_src_vocab_table, reverse_tgt_vocab_table); both
            refer to the same table when `share_vocab` is set.

        Raises:
            ValueError: If a required vocab file does not exist.
        """
        assert self._src_vocab_file and self._tgt_vocab_file

        # Guard clauses replace the original nested if/else; error
        # messages fixed ("does not exists" -> "does not exist").
        if not tf.gfile.Exists(self._src_vocab_file):
            raise ValueError("src_vocab_file '%s' does not exist" %
                             self._src_vocab_file)
        reverse_src_vocab_table = lookup_ops.index_to_string_table_from_file(
            self._src_vocab_file, default_value=self.UNK)

        if self.share_vocab:
            # Shared vocab: target reuses the source table.
            reverse_tgt_vocab_table = reverse_src_vocab_table
        else:
            if not tf.gfile.Exists(self._tgt_vocab_file):
                raise ValueError("tgt_vocab_file '%s' does not exist" %
                                 self._tgt_vocab_file)
            reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
                self._tgt_vocab_file, default_value=self.UNK)

        return reverse_src_vocab_table, reverse_tgt_vocab_table
示例#33
0
 def __init__(self, config_dict):
     """Initialize config, the xletter op helper, and decoder vocab tables.

     Args:
         config_dict: Raw configuration consumed by Xletter2SeqConfig.
     """
     super(Xletter2Seq, self).__init__()
     self.cfg = Xletter2SeqConfig(config_dict)
     self.op_helper = self.get_ophelper(self.cfg.input_mode,
                                        self.cfg.xletter_dict,
                                        self.cfg.xletter_win_size)
     # word -> id and id -> word tables over the decoder vocabulary.
     vocab_path = self.cfg.vocab_path + "/" + self.cfg.decoder_vocab_file
     self.decoder_dict = lookup_ops.index_table_from_file(
         vocab_path, default_value=self.cfg.dict_param.unk_id)
     self.reverse_decoder_dict = lookup_ops.index_to_string_table_from_file(
         vocab_path, default_value=self.cfg.dict_param.unk)
示例#34
0
def create_vocab_tables(src_vocab_file, tgt_vocab_file, config):
    """Create source/target word->id tables plus a target id->word table.

    When config.share_vocab is set, the target table is the source table.

    Returns:
        Tuple (src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table).
    """
    src_vocab_table = lookup_ops.index_table_from_file(
        src_vocab_file, default_value=UNK_ID)
    tgt_vocab_table = (src_vocab_table if config.share_vocab else
                       lookup_ops.index_table_from_file(
                           tgt_vocab_file, default_value=UNK_ID))
    # Reverse direction for decoding target ids back to words.
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=config.unk)
    return src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table
示例#35
0
 def test_index_to_string_table_with_vocab_size_too_small(self):
   """Ids beyond a too-small vocab_size fall back to the default value."""
   default_value = b"NONE"
   vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
   with self.test_session():
     table = lookup_ops.index_to_string_table_from_file(
         vocabulary_file=vocabulary_file,
         vocab_size=2,
         default_value=default_value)
     features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
     # Looking up before the table is initialized must fail.
     self.assertRaises(errors_impl.OpError, features.eval)
     lookup_ops.tables_initializer().run()
     # vocab_size=2 keeps only ids 0-1; ids 2 and 4 map to default_value.
     self.assertAllEqual((b"salad", default_value, default_value),
                         features.eval())
示例#36
0
def create_inference_model(model_creator, hparams, scope=None, extra_args=None):
    """Build the inference graph: vocab tables, infer iterator, and model.

    Args:
        model_creator: Callable that constructs the model object.
        hparams: Hyperparameters (vocab files, share_vocab, eos,
            src_max_len_infer, ...).
        scope: Optional container/variable scope name.
        extra_args: Optional extra model arguments.

    Returns:
        An InferModel tuple with the graph, model, iterator, and the
        placeholders used to feed source sentences and batch size.
    """
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "infer"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
                hparams.src_vocab_file, hparams.tgt_vocab_file,
                hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
                hparams.tgt_vocab_file, default_value=vocab_utils.UNK)

        # Source sentences and batch size are fed at session-run time.
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_infer_iterator(
                tf.data.Dataset.from_tensor_slices(src_placeholder),
                src_vocab_table,
                batch_size=batch_size_placeholder,
                eos=hparams.eos,
                src_max_len=hparams.src_max_len_infer)

        model = model_creator(
                hparams,
                iterator=iterator,
                mode=tf.contrib.learn.ModeKeys.INFER,
                source_vocab_table=src_vocab_table,
                target_vocab_table=tgt_vocab_table,
                reverse_target_vocab_table=reverse_tgt_vocab_table,
                scope=scope,
                extra_args=extra_args)

    return InferModel(graph=graph, model=model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)
示例#37
0
File: utils.py  Project: luluyouyue/NER



'''
    以下是做测试用的,不用管。
'''
if __name__ == '__main__':
    #################### Just for testing #########################
    vocab_size = get_src_vocab_size()
    src_unknown_id = tgt_unknown_id = vocab_size
    src_padding = vocab_size + 1

    src_vocab_table, tgt_vocab_table = create_vocab_tables(src_vocab_file, tgt_vocab_file, src_unknown_id, tgt_unknown_id)
    # iterator = get_iterator(src_vocab_table, tgt_vocab_table, vocab_size, 100, random_seed=None)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        src_vocab_file, default_value='<tag-unknown>')

    iterator = get_predict_iterator(src_vocab_table, vocab_size, 1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(iterator.initializer)
        tf.tables_initializer().run()

        # 根据ID查字。
        word = reverse_tgt_vocab_table.lookup(tf.constant(12001, dtype=tf.int64))
        print sess.run(word)
        for i in range(10):
            try:
                # source, target = sess.run([iterator.source, iterator.target_input])
                source = sess.run(iterator.source)
示例#38
0
File: utils.py  Project: luluyouyue/NER
def tag_to_id_table():
    """Return a lookup table over `tgt_vocab_file`; unknown ids map to
    '<tag-unknown>'.

    NOTE(review): index_to_string_table_from_file maps ids -> tag strings,
    so despite the name this is an id-to-tag table; consider renaming at a
    call-site-safe opportunity.
    """
    return lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value='<tag-unknown>')