Code example #1
  def testGetInferIterator(self):
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        eos="eos",
        sos="sos")
    batch_size = 2
    dataset = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=batch_size,
        eos=hparams.eos)
    table_initializer = tf.tables_initializer()
    iterator = dataset.make_initializable_iterator()
    get_next = iterator.get_next()
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer)
      features = sess.run(get_next)

      self.assertAllEqual(
          [
              [2, 2, 0],  # c c a
              [2, 0, 3],  # c a eos
          ],
          features["source"])
      self.assertAllEqual([3, 2], features["source_sequence_length"])
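For orientation, here is a minimal sketch of the dataset-returning variant of get_infer_iterator that this test exercises, reconstructed from the tokenize / lookup / pad-with-EOS behavior the assertions imply (the exact implementation lives in iterator_utils and may differ):

def get_infer_iterator(src_dataset, src_vocab_table, batch_size, eos):
    # The EOS id doubles as the padding value, which is why "c a" pads to [2, 0, 3].
    src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
    # Split each line into tokens, then map tokens to vocabulary ids.
    dataset = src_dataset.map(lambda src: tf.string_split([src]).values)
    dataset = dataset.map(lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32))
    # Keep the true (pre-padding) length alongside the ids.
    dataset = dataset.map(
        lambda src: {"source": src, "source_sequence_length": tf.size(src)})
    return dataset.padded_batch(
        batch_size,
        padded_shapes={"source": tf.TensorShape([None]),
                       "source_sequence_length": tf.TensorShape([])},
        padding_values={"source": src_eos_id, "source_sequence_length": 0})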
Code example #2
    def _input_fn(params):
        """Input function."""
        del params

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
            tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
        else:
            if hparams.mode == "translate":
                src_file = hparams.translate_file + ".tok"
                tgt_file = hparams.translate_file + ".tok"
            else:
                src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
                tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        src_vocab_file = hparams.src_vocab_file
        tgt_vocab_file = hparams.tgt_vocab_file
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            # Run one epoch and stop if running train_and_eval.
            if hparams.mode == "train_and_eval":
                # In this mode the input pipeline is restarted every epoch,
                # so choose a different random_seed.
                num_repeat = 1
                random_seed = hparams.random_seed + int(time.time()) % 100
            else:
                num_repeat = 8
                random_seed = hparams.random_seed
            return iterator_utils.get_iterator(
                src_dataset,
                tgt_dataset,
                src_vocab_table,
                tgt_vocab_table,
                batch_size=hparams.batch_size,
                sos=hparams.sos,
                eos=hparams.eos,
                random_seed=random_seed,
                num_buckets=hparams.num_buckets,
                src_max_len=hparams.src_max_len,
                tgt_max_len=hparams.tgt_max_len,
                output_buffer_size=None,
                skip_count=None,
                num_shards=1,  # flags.num_workers
                shard_index=0,  # flags.jobid
                reshuffle_each_iteration=True,
                use_char_encode=hparams.use_char_encode,
                num_repeat=num_repeat,
                filter_oversized_sequences=True
            )  # need to update get_effective_train_epoch_size() if this flag flips.
        else:
            return iterator_utils.get_infer_iterator(
                src_dataset,
                src_vocab_table,
                batch_size=hparams.infer_batch_size,
                eos=hparams.eos,
                src_max_len=hparams.src_max_len,
                use_char_encode=hparams.use_char_encode)
Code example #3
    def _input_fn(params):
        """Input function."""
        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
            tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
        else:
            src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
            tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        src_vocab_file = hparams.src_vocab_file
        tgt_vocab_file = hparams.tgt_vocab_file
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            if "context" in params:
                batch_size = params["batch_size"]
                num_hosts = params["context"].num_hosts
                # TODO(dehao): update to use current_host once available in API.
                current_host = params["context"].current_input_fn_deployment(
                )[1]
            else:
                num_hosts = 1
                current_host = 0
                batch_size = hparams.batch_size
            mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE,
                                  value=batch_size)
            mlperf_log.gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN,
                                  value=hparams.src_max_len)
            return iterator_utils.get_iterator(
                src_dataset,
                tgt_dataset,
                src_vocab_table,
                tgt_vocab_table,
                batch_size=batch_size,
                sos=hparams.sos,
                eos=hparams.eos,
                random_seed=hparams.random_seed,
                num_buckets=hparams.num_buckets,
                src_max_len=hparams.src_max_len,
                tgt_max_len=hparams.tgt_max_len,
                output_buffer_size=None,
                skip_count=None,
                num_shards=num_hosts,
                shard_index=current_host,
                reshuffle_each_iteration=True,
                use_char_encode=hparams.use_char_encode,
                filter_oversized_sequences=True)
        else:
            return iterator_utils.get_infer_iterator(
                src_dataset,
                src_vocab_table,
                batch_size=hparams.infer_batch_size,
                eos=hparams.eos,
                src_max_len=hparams.src_max_len,
                use_char_encode=hparams.use_char_encode)
Code example #4
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model."""
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "infer"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        data_src_placeholder = tf.placeholder(shape=[None],
                                              dtype=tf.string,
                                              name="src_ph")
        kb_placeholder = tf.placeholder(shape=[None],
                                        dtype=tf.string,
                                        name="kb_ph")
        batch_size_placeholder = tf.placeholder(shape=[],
                                                dtype=tf.int64,
                                                name="bs_ph")

        data_src_dataset = tf.data.Dataset.from_tensor_slices(
            data_src_placeholder)
        kb_dataset = tf.data.Dataset.from_tensor_slices(kb_placeholder)

        # this is the actual infer iterator
        infer_iterator = iterator_utils.get_infer_iterator(
            data_src_dataset,
            kb_dataset,
            vocab_table,
            batch_size=batch_size_placeholder,
            eod=hparams.eod,
            len_action=hparams.len_action)

        # this is the placeholder infer iterator
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, infer_iterator.output_types, infer_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(iterator)

        model = model_creator(hparams,
                              iterator=batched_iterator,
                              handle=handle,
                              mode=tf.contrib.learn.ModeKeys.INFER,
                              vocab_table=vocab_table,
                              reverse_vocab_table=reverse_vocab_table,
                              scope=scope,
                              extra_args=extra_args)

    return InferModel(graph=graph,
                      model=model,
                      placeholder_iterator=iterator,
                      placeholder_handle=handle,
                      infer_iterator=infer_iterator,
                      data_src_placeholder=data_src_placeholder,
                      kb_placeholder=kb_placeholder,
                      batch_size_placeholder=batch_size_placeholder)
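The handle-based indirection above lets one graph-side iterator read from whichever concrete pipeline is selected at run time. A standalone illustration of the tf.data.Iterator.from_string_handle pattern (not project code; toy datasets only):

import tensorflow as tf

ds_a = tf.data.Dataset.range(3)
ds_b = tf.data.Dataset.range(100, 103)

# One generic iterator in the graph; the actual source is picked per sess.run.
handle = tf.placeholder(tf.string, shape=[])
generic_iterator = tf.data.Iterator.from_string_handle(
    handle, ds_a.output_types, ds_a.output_shapes)
next_element = generic_iterator.get_next()

iter_a = ds_a.make_one_shot_iterator()
iter_b = ds_b.make_one_shot_iterator()
with tf.Session() as sess:
    handle_a, handle_b = sess.run([iter_a.string_handle(), iter_b.string_handle()])
    print(sess.run(next_element, feed_dict={handle: handle_a}))  # 0
    print(sess.run(next_element, feed_dict={handle: handle_b}))  # 100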
Code example #5
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model."""
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    speaker_file = hparams.speaker_file

    with graph.as_default(), tf.container(scope or "infer"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        speaker_table, _ = vocab_utils.create_vocab_tables(
            speaker_file, speaker_file, True)
        reverse_speaker_table = lookup_ops.index_to_string_table_from_file(
            speaker_file, default_value=vocab_utils.UNK)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        src_spkr_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        tgt_spkr_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
        src_spkr_dataset = tf.data.Dataset.from_tensor_slices(
            src_spkr_placeholder)
        tgt_spkr_dataset = tf.data.Dataset.from_tensor_slices(
            tgt_spkr_placeholder)
        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            src_spkr_dataset,
            tgt_spkr_dataset,
            speaker_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            speaker_table=speaker_table,
            reverse_speaker_table=reverse_speaker_table,
            scope=scope,
            extra_args=extra_args)
    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      src_speaker_placeholder=src_spkr_placeholder,
                      tgt_speaker_placeholder=tgt_spkr_placeholder,
                      iterator=iterator)
Code example #6
    def testGetInferIterator(self):
        src_vocab_table = lookup_ops.index_table_from_tensor(
            tf.constant(["a", "b", "c", "eos", "sos"]))
        src_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["c c a", "c a", "d", "f e a g"]))
        hparams = tf.contrib.training.HParams(random_seed=3,
                                              eos="eos",
                                              sos="sos")
        batch_size = 2
        src_max_len = 3
        iterator = iterator_utils.get_infer_iterator(
            src_dataset=src_dataset,
            src_vocab_table=src_vocab_table,
            batch_size=batch_size,
            eos=hparams.eos,
            src_max_len=src_max_len)
        table_initializer = tf.tables_initializer()
        source = iterator.source
        seq_len = iterator.source_sequence_length
        self.assertEqual([None, None], source.shape.as_list())
        self.assertEqual([None], seq_len.shape.as_list())
        with self.test_session() as sess:
            sess.run(table_initializer)
            sess.run(iterator.initializer)

            (source_v, seq_len_v) = sess.run((source, seq_len))
            self.assertAllEqual(
                [
                    [2, 2, 0],  # c c a
                    [2, 0, 3],  # c a eos
                ],
                source_v)
            self.assertAllEqual([3, 2], seq_len_v)

            (source_v, seq_len_v) = sess.run((source, seq_len))
            self.assertAllEqual(
                [
                    [-1, 3, 3],  # "d" == unknown, eos eos
                    [-1, -1, 0],  # "f" == unknown, "e" == unknown, a
                ],
                source_v)
            self.assertAllEqual([1, 3], seq_len_v)

            with self.assertRaisesOpError("End of sequence"):
                sess.run((source, seq_len))
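For reference, a sketch of the BatchedInput-returning get_infer_iterator this test drives, following the tensorflow/nmt reference implementation (BatchedInput is assumed to be a namedtuple with the fields used below):

def get_infer_iterator(src_dataset, src_vocab_table, batch_size, eos,
                       src_max_len=None):
    src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
    src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values)
    if src_max_len:
        # Truncation is why "f e a g" above yields only three tokens.
        src_dataset = src_dataset.map(lambda src: src[:src_max_len])
    src_dataset = src_dataset.map(
        lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32))
    # Pair each id sequence with its true length before padding.
    src_dataset = src_dataset.map(lambda src: (src, tf.size(src)))
    batched_dataset = src_dataset.padded_batch(
        batch_size,
        padded_shapes=(tf.TensorShape([None]), tf.TensorShape([])),
        padding_values=(src_eos_id, 0))
    batched_iter = batched_dataset.make_initializable_iterator()
    src_ids, src_seq_len = batched_iter.get_next()
    return BatchedInput(initializer=batched_iter.initializer,
                        source=src_ids,
                        target_input=None,
                        target_output=None,
                        source_sequence_length=src_seq_len)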
Code example #7
def mytest_infer_iterator():
    src_dataset = tf.data.TextLineDataset(hparam.train_src)
    myinput = get_infer_iterator(src_dataset, hparam)
    # Map the batched ids back to strings for a human-readable check.
    ss = myinput.reverse_table.lookup(myinput.src)
    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        sess.run(myinput.initializer)
        for _ in range(5):
            try:
                _src, _src_seq_len, cc = sess.run(
                    [myinput.src, myinput.src_seq_len, ss])
                print('src', _src)
                print('src_seq_len', _src_seq_len)
                print('reverse')
                for i in range(len(cc)):
                    print(get_translation(cc, i, hparam.EOS))
            except tf.errors.OutOfRangeError:
                print('end of dataset, reinitializing')
                sess.run(myinput.initializer)
Code example #8
def createInferModel(hparam, modelFunc=None, src_path=None, tgt_path=None):
    def _get_config_proto():
        conf = tf.ConfigProto(allow_soft_placement=True,
                              log_device_placement=False)
        return conf

    graph = tf.Graph()
    with graph.as_default():
        if src_path:
            src_dataset = tf.data.TextLineDataset(src_path)
        else:
            src_dataset = tf.data.TextLineDataset(hparam.test_src)
        if tgt_path:
            ref_file = tgt_path
        else:
            ref_file = hparam.test_tgt

        batch_input = get_infer_iterator(src_dataset, hparam)
        sess = tf.Session(config=_get_config_proto())
        model = modelFunc(batch_input, 'infer', hparam)

        return InferModel(model, sess, graph, batch_input, ref_file)
Code example #9
File: common_test_utils.py, Project: gplays/slackBot
def create_test_iterator(hparams, mode):
    """Create test iterator."""
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant([hparams.eos, "a", "b", "c", "d"]))
    tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
    tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)
    if mode == tf.contrib.learn.ModeKeys.INFER:
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
            tgt_vocab_mapping)

    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["a a b b c", "a b b"]))

    if mode != tf.contrib.learn.ModeKeys.INFER:
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["a b c b c", "a b c b"]))
        return (
            # TODO: change to accommodate new inputs
            iterator_utils.get_iterator(src_dataset=src_dataset,
                                        tgt_dataset=tgt_dataset,
                                        src_vocab_table=src_vocab_table,
                                        tgt_vocab_table=tgt_vocab_table,
                                        batch_size=hparams.batch_size,
                                        sos=hparams.sos,
                                        eos=hparams.eos,
                                        random_seed=hparams.random_seed,
                                        num_buckets=hparams.num_buckets),
            src_vocab_table,
            tgt_vocab_table)
    else:
        return (
            # TODO: change to accommodate new inputs
            iterator_utils.get_infer_iterator(src_dataset=src_dataset,
                                              src_vocab_table=src_vocab_table,
                                              eos=hparams.eos,
                                              batch_size=hparams.batch_size),
            src_vocab_table,
            tgt_vocab_table,
            reverse_tgt_vocab_table)
Code example #10
def create_infer_model(model_creator,
                       hparams,
                       scope=None,
                       single_cell_fn=None):
    """Create inference model."""
    graph = tf.Graph()

    tgt_vocab_file = hparams.tgt_vocab_file

    with graph.as_default():

        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
        # Convert to a reverse (index-to-string) lookup table.
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)

        src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
            src_placeholder)
        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            source_reverse=hparams.source_reverse,
            src_max_len=hparams.src_max_len_infer)

        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            single_cell_fn=single_cell_fn)

    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      iterator=iterator)
Code example #11
def create_infer_model(model_creator, hparams):
    """Create inference model."""
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    with graph.as_default(), tf.container("infer"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)

        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)

        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.estimator.ModeKeys.PREDICT,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table)

    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)
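A minimal usage sketch for the model built above, following the usual nmt inference flow; the sample sentences are placeholders, and checkpoint restoration is elided:

infer_model = create_infer_model(model_creator, hparams)
with tf.Session(graph=infer_model.graph) as sess:
    sess.run(tf.tables_initializer())
    # Feed raw sentences and a batch size straight into the placeholders.
    sess.run(infer_model.iterator.initializer,
             feed_dict={
                 infer_model.src_placeholder: ["a b c", "c a b"],
                 infer_model.batch_size_placeholder: 2,
             })
    # ... restore a checkpoint into infer_model.model and decode here.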
Code example #12
def self_play_iterator_creator(hparams, num_workers, jobid):
    """create a self play iterator. There are iterators that will be created here.
  A supervised training iterator used for supervised learning. A full text
  iterator and structured iterator used for reinforcement learning self play.
  Full text iterators feeds data from text files while structured iterators
  are initialized directly from objects. The former one is used for traiing.
  The later one is used for self play dialogue generation to eliminate the
  need of serializing them into actual text
  files.
  """
    vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
    data_dataset = tf.data.TextLineDataset(hparams.train_data)
    kb_dataset = tf.data.TextLineDataset(hparams.train_kb)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    # this is the actual iterator for supervised training
    train_iterator = iterator_utils.get_iterator(
        data_dataset,
        kb_dataset,
        vocab_table,
        batch_size=hparams.batch_size,
        t1=hparams.t1,
        t2=hparams.t2,
        eod=hparams.eod,
        len_action=hparams.len_action,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        max_dialogue_len=hparams.max_dialogue_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # this is the actual iterator for self_play_fulltext_iterator
    data_placeholder = tf.placeholder(shape=[None],
                                      dtype=tf.string,
                                      name="src_ph")
    kb_placeholder = tf.placeholder(shape=[None],
                                    dtype=tf.string,
                                    name="kb_ph")
    batch_size_placeholder = tf.placeholder(shape=[],
                                            dtype=tf.int64,
                                            name="bs_ph")

    dataset_data = tf.data.Dataset.from_tensor_slices(data_placeholder)
    kb_dataset = tf.data.Dataset.from_tensor_slices(kb_placeholder)

    self_play_fulltext_iterator = iterator_utils.get_infer_iterator(
        dataset_data,
        kb_dataset,
        vocab_table,
        batch_size=batch_size_placeholder,
        eod=hparams.eod,
        len_action=hparams.len_action,
        self_play=True)

    # this is the actual iterator for self_play_structured_iterator
    self_play_structured_iterator = tf.data.Iterator.from_structure(
        self_play_fulltext_iterator.output_types,
        self_play_fulltext_iterator.output_shapes)
    iterators = [
        train_iterator, self_play_fulltext_iterator,
        self_play_structured_iterator
    ]

    # this is the list of placeholders
    placeholders = [
        data_placeholder, kb_placeholder, batch_size_placeholder,
        skip_count_placeholder
    ]
    return iterators, placeholders
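The self_play_structured_iterator above uses tf.data.Iterator.from_structure, which builds a dataset-agnostic iterator that can later be bound to any structurally compatible dataset. A standalone illustration (not project code):

import tensorflow as tf

ds_train = tf.data.Dataset.range(3)
ds_play = tf.data.Dataset.range(10, 13)

# The iterator is defined by structure only; no dataset is attached yet.
iterator = tf.data.Iterator.from_structure(ds_train.output_types,
                                           ds_train.output_shapes)
next_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.make_initializer(ds_train))
    print(sess.run(next_element))  # 0
    sess.run(iterator.make_initializer(ds_play))
    print(sess.run(next_element))  # 10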
Code example #13
    def _input_fn(params):
        """Input function."""
        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
            tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
        else:
            src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
            tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        src_vocab_file = hparams.src_vocab_file
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file)

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            if "context" in params:
                batch_size = params["batch_size"]
                global_batch_size = batch_size
                num_hosts = params["context"].num_hosts
                # TODO(dehao): update to use current_host once available in API.
                current_host = params["context"].current_input_fn_deployment(
                )[1]
            else:
                if "dataset_index" in params:
                    current_host = params["dataset_index"]
                    num_hosts = params["dataset_num_shards"]
                    batch_size = params["batch_size"]
                    global_batch_size = hparams.batch_size
                else:
                    num_hosts = 1
                    current_host = 0
                    batch_size = hparams.batch_size
                    global_batch_size = batch_size
            if not hparams.use_preprocessed_data:
                src_dataset = tf.data.TextLineDataset(src_file)
                tgt_dataset = tf.data.TextLineDataset(tgt_file)
                return iterator_utils.get_iterator(
                    src_dataset,
                    tgt_dataset,
                    src_vocab_table,
                    tgt_vocab_table,
                    batch_size=batch_size,
                    global_batch_size=global_batch_size,
                    sos=hparams.sos,
                    eos=hparams.eos,
                    random_seed=hparams.random_seed,
                    num_buckets=hparams.num_buckets,
                    src_max_len=hparams.src_max_len,
                    tgt_max_len=hparams.tgt_max_len,
                    output_buffer_size=None,
                    skip_count=None,
                    num_shards=num_hosts,
                    shard_index=current_host,
                    reshuffle_each_iteration=True,
                    filter_oversized_sequences=True)
            else:
                return iterator_utils.get_preprocessed_iterator(
                    hparams.train_prefix + "*",
                    batch_size=batch_size,
                    random_seed=hparams.random_seed,
                    max_seq_len=hparams.src_max_len,
                    num_buckets=hparams.num_buckets,
                    shard_index=current_host,
                    num_shards=num_hosts)
        else:
            if "dataset_index" in params:
                current_host = params["dataset_index"]
                num_hosts = params["dataset_num_shards"]
            else:
                num_hosts = 1
                current_host = 0
            if "infer_batch_size" in params:
                batch_size = params["infer_batch_size"]
            else:
                batch_size = hparams.infer_batch_size
            src_dataset = tf.data.TextLineDataset(src_file)
            src_dataset = src_dataset.repeat().batch(
                hparams.infer_batch_size // num_hosts).shard(
                    num_hosts, current_host).apply(tf.contrib.data.unbatch())
            return iterator_utils.get_infer_iterator(
                src_dataset,
                src_vocab_table,
                batch_size=batch_size,
                eos=hparams.eos,
                sos=hparams.sos,
                src_max_len=hparams.src_max_len_infer)
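The repeat().batch().shard().unbatch() chain at the end deals each host a disjoint, batch-aligned slice of the global stream. A toy demonstration of the same trick with assumed sizes (2 hosts, per-host batch of 2):

import tensorflow as tf

num_hosts = 2
per_host_batch = 2
ds = tf.data.Dataset.range(8)
# batch -> shard -> unbatch: host 0 sees elements 0, 1, 4, 5 of each pass.
ds = ds.batch(per_host_batch).shard(num_hosts, 0)
ds = ds.apply(tf.contrib.data.unbatch())
elem = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for _ in range(4):
        print(sess.run(elem))  # 0, 1, 4, 5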