Example #1
    def _test_pipeline(self, mode, params=None):
        """Helper function to test the full model pipeline.
    """
        # Create source and target example
        source_len = self.sequence_length + 5
        target_len = self.sequence_length + 10
        source = " ".join(np.random.choice(self.vocab_list, source_len))
        target = " ".join(np.random.choice(self.vocab_list, target_len))
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=[source], targets=[target])

        # Build model graph
        model = self.create_model(mode, params)
        input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
            params={
                "source_files": [sources_file.name],
                "target_files": [targets_file.name]
            },
            mode=mode)
        input_fn = training_utils.create_input_fn(pipeline=input_pipeline_,
                                                  batch_size=self.batch_size)
        features, labels = input_fn()
        fetches = model(features, labels, None)
        # Drop any fetches the model did not produce in this mode
        fetches = [f for f in fetches if f is not None]

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                fetches_ = sess.run(fetches)

        sources_file.close()
        targets_file.close()

        return model, fetches_
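
A plausible concrete test built on this helper (a minimal sketch; the mode
key comes from the snippet's own tf.contrib.learn usage, while the test name
and assertion are assumptions for illustration):

    def test_train_pipeline(self):
        model, fetches_ = self._test_pipeline(
            mode=tf.contrib.learn.ModeKeys.TRAIN)
        self.assertTrue(fetches_)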
Example #2

    def test_pipeline(self):
        file_source, file_target = test_utils.create_temp_parallel_data(
            sources=["Hello World . 笑"], targets=["Bye 泣"])

        pipeline = input_pipeline.ParallelTextInputPipeline(
            params={
                "source_files": [file_source.name],
                "target_files": [file_target.name],
                "num_epochs": 5,
                "shuffle": False
            },
            mode=tf.contrib.learn.ModeKeys.TRAIN)

        data_provider = pipeline.make_data_provider()

        features = pipeline.read_from_data_provider(data_provider)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                res = sess.run(features)

        self.assertEqual(res["source_len"], 5)
        self.assertEqual(res["target_len"], 4)
        np.testing.assert_array_equal(
            np.char.decode(res["source_tokens"].astype("S"), "utf-8"),
            ["Hello", "World", ".", "笑", "SEQUENCE_END"])
        np.testing.assert_array_equal(
            np.char.decode(res["target_tokens"].astype("S"), "utf-8"),
            ["SEQUENCE_START", "Bye", "泣", "SEQUENCE_END"])
Example #3
  def _test_with_args(self, **kwargs):
    """Helper function to test create_input_fn with keyword arguments"""
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=["Hello World ."], targets=["Goodbye ."])

    pipeline = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [sources_file.name],
            "target_files": [targets_file.name]
        },
        mode=tf.contrib.learn.ModeKeys.TRAIN)
    input_fn = training_utils.create_input_fn(pipeline=pipeline, **kwargs)
    features, labels = input_fn()

    with self.test_session() as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        features_, labels_ = sess.run([features, labels])

    self.assertEqual(
        set(features_.keys()), set(["source_tokens", "source_len"]))
    self.assertEqual(set(labels_.keys()), set(["target_tokens", "target_len"]))
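
Hypothetical concrete tests built on this helper; batch_size mirrors the
create_input_fn usage in the other snippets, and bucket_boundaries is an
assumption about the keyword arguments that training_utils.create_input_fn
accepts:

  def test_without_buckets(self):
    self._test_with_args(batch_size=10)

  def test_with_buckets(self):
    self._test_with_args(batch_size=10, bucket_boundaries=[5, 10])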
Example #4
def test_model(source_path, target_path, vocab_path):
    """Builds an AttentionSeq2Seq model graph and runs a single step over
    the given parallel data files."""
    tf.logging.set_verbosity(tf.logging.INFO)
    batch_size = 2

    # Build model graph
    mode = tf.contrib.learn.ModeKeys.TRAIN
    params_ = AttentionSeq2Seq.default_params().copy()
    params_.update({
        "vocab_source": vocab_path,
        "vocab_target": vocab_path,
    })
    model = AttentionSeq2Seq(params=params_, mode=mode)

    tf.logging.info(vocab_path)

    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [source_path],
            "target_files": [target_path]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(pipeline=input_pipeline_,
                                              batch_size=batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [f for f in fetches if f is not None]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
            fetches_ = sess.run(fetches)

    return model, fetches_
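
A minimal way to drive this function with throwaway data, reusing the
temp-file helpers that appear in the other snippets (the sentence contents
and the __main__ guard are assumptions for illustration):

    if __name__ == "__main__":
        vocab_file = test_utils.create_temporary_vocab_file(
            ["Hello", "World", "."])
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=["Hello World ."], targets=["World ."])
        test_model(sources_file.name, targets_file.name, vocab_file.name)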
Example #5
def test_copy_gen_model(source_path=None, target_path=None, vocab_path=None):
    """Builds a CopyGenSeq2Seq model graph and runs a single training step.
    When the path arguments are omitted, a vocabulary and a random parallel
    corpus are generated in temporary files."""
    tf.logging.set_verbosity(tf.logging.INFO)
    batch_size = 2
    input_depth = 4
    sequence_length = 10

    if vocab_path is None:
        # Create vocabulary
        vocab_list = [str(i) for i in range(10)]
        vocab_list += ["笑う", "泣く", "了解", "はい", "^_^"]
        vocab_size = len(vocab_list)
        vocab_file = test_utils.create_temporary_vocab_file(vocab_list)
        vocab_info = vocab.get_vocab_info(vocab_file.name)
        vocab_path = vocab_file.name
        tf.logging.info(vocab_file.name)
    else:
        vocab_info = vocab.get_vocab_info(vocab_path)
        vocab_list = get_vocab_list(vocab_path)

    extend_vocab = vocab_list + ["中国", "爱", "你"]

    tf.contrib.framework.get_or_create_global_step()
    source_len = sequence_length + 5
    target_len = sequence_length + 10
    source = " ".join(np.random.choice(extend_vocab, source_len))
    target = " ".join(np.random.choice(extend_vocab, target_len))

    is_tmp_file = False
    if source_path is None and target_path is None:
        is_tmp_file = True
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=[source], targets=[target])
        source_path = sources_file.name
        target_path = targets_file.name

    # Build model graph
    mode = tf.contrib.learn.ModeKeys.TRAIN
    params_ = CopyGenSeq2Seq.default_params().copy()
    params_.update({
        "vocab_source": vocab_path,
        "vocab_target": vocab_path,
    })
    model = CopyGenSeq2Seq(params=params_, mode=mode)

    tf.logging.info(source_path)
    tf.logging.info(target_path)

    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [source_path],
            "target_files": [target_path]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(pipeline=input_pipeline_,
                                              batch_size=batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [f for f in fetches if f is not None]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
            fetches_ = sess.run(fetches)

    if is_tmp_file:
        sources_file.close()
        targets_file.close()

    return model, fetches_
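
Because every argument is optional, a smoke test needs no setup at all; the
function then generates its own vocabulary and a random parallel corpus in
temporary files, as the branches above show:

    model, fetches_ = test_copy_gen_model()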