def test_pipeline(self):
        """Reads one parallel example through ParallelTextInputPipeline.

        Checks that the source sequence gets a trailing SEQUENCE_END token,
        that the target sequence is wrapped in SEQUENCE_START/SEQUENCE_END,
        and that non-ASCII tokens survive the UTF-8 round trip.
        """
        file_source, file_target = test_utils.create_temp_parallel_data(
            sources=["Hello World . 笑"], targets=["Bye 泣"])
        # Close the temp files once the test (and its queue runners) finish.
        self.addCleanup(file_source.close)
        self.addCleanup(file_target.close)

        pipeline = input_pipeline.ParallelTextInputPipeline(
            params={
                "source_files": [file_source.name],
                "target_files": [file_target.name],
                "num_epochs": 5,
                "shuffle": False
            },
            mode=tf.contrib.learn.ModeKeys.TRAIN)

        data_provider = pipeline.make_data_provider()

        features = pipeline.read_from_data_provider(data_provider)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            # NOTE(review): presumably needed for the num_epochs counter,
            # which TF keeps in a local variable -- confirm.
            sess.run(tf.local_variables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                res = sess.run(features)

        # 4 source words + SEQUENCE_END; 2 target words + START/END markers.
        self.assertEqual(res["source_len"], 5)
        self.assertEqual(res["target_len"], 4)
        np.testing.assert_array_equal(
            np.char.decode(res["source_tokens"].astype("S"), "utf-8"),
            ["Hello", "World", ".", "笑", "SEQUENCE_END"])
        np.testing.assert_array_equal(
            np.char.decode(res["target_tokens"].astype("S"), "utf-8"),
            ["SEQUENCE_START", "Bye", "泣", "SEQUENCE_END"])
예제 #2
0
    def _test_pipeline(self, mode, params=None):
        """Helper function to test the full model pipeline.

        Writes one random parallel example to temp files, builds the model via
        ``self.create_model(params)``, feeds it through
        ``training_utils.create_input_fn`` and evaluates every non-None model
        output once.

        Args:
          mode: Passed to the model call as its fourth positional argument
            (presumably a ``tf.contrib.learn.ModeKeys`` value -- confirm).
          params: Optional model params forwarded to ``self.create_model``.

        Returns:
          A ``(model, fetches_)`` tuple; ``fetches_`` holds the evaluated
          values of the model's non-None outputs, in order.
        """
        # Create source and target example. The target is longer than
        # self.max_decode_length (TODO confirm: presumably to exercise
        # decode-length truncation).
        source_len = 10
        target_len = self.max_decode_length + 10
        source = " ".join(np.random.choice(self.vocab_list, source_len))
        target = " ".join(np.random.choice(self.vocab_list, target_len))
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=[source], targets=[target])

        try:
            # Build model graph
            model = self.create_model(params)
            data_provider = lambda: data_utils.make_parallel_data_provider(
                [sources_file.name], [targets_file.name])
            input_fn = training_utils.create_input_fn(data_provider,
                                                      self.batch_size)
            features, labels = input_fn()
            fetches = model(features, labels, None, mode)
            fetches = [fetch for fetch in fetches if fetch is not None]

            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())
                sess.run(tf.tables_initializer())
                with tf.contrib.slim.queues.QueueRunners(sess):
                    fetches_ = sess.run(fetches)
        finally:
            # Close the temp files even if graph construction or the session
            # run raises.
            sources_file.close()
            targets_file.close()

        return model, fetches_
예제 #3
0
  def test_pipeline(self):
    """Reads one parallel example through ParallelTextInputPipeline.

    Checks that the source sequence gets a trailing SEQUENCE_END token, that
    the target sequence is wrapped in SEQUENCE_START/SEQUENCE_END, and that
    non-ASCII tokens survive the UTF-8 round trip.
    """
    file_source, file_target = test_utils.create_temp_parallel_data(
        sources=["Hello World . 笑"], targets=["Bye 泣"])
    # Close the temp files once the test (and its queue runners) finish.
    self.addCleanup(file_source.close)
    self.addCleanup(file_target.close)

    pipeline = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [file_source.name],
            "target_files": [file_target.name],
            "num_epochs": 5,
            "shuffle": False
        },
        mode=tf.contrib.learn.ModeKeys.TRAIN)

    data_provider = pipeline.make_data_provider()

    features = pipeline.read_from_data_provider(data_provider)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      # NOTE(review): presumably needed for the num_epochs counter, which TF
      # keeps in a local variable -- confirm.
      sess.run(tf.local_variables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        res = sess.run(features)

    # 4 source words + SEQUENCE_END; 2 target words + START/END markers.
    self.assertEqual(res["source_len"], 5)
    self.assertEqual(res["target_len"], 4)
    np.testing.assert_array_equal(
        np.char.decode(res["source_tokens"].astype("S"), "utf-8"),
        ["Hello", "World", ".", "笑", "SEQUENCE_END"])
    np.testing.assert_array_equal(
        np.char.decode(res["target_tokens"].astype("S"), "utf-8"),
        ["SEQUENCE_START", "Bye", "泣", "SEQUENCE_END"])
    def _test_with_args(self, **kwargs):
        """Helper function to test create_input_fn with keyword arguments.

        Builds an input_fn over a one-line parallel corpus and checks that it
        yields features with exactly the source keys and labels with exactly
        the target keys.

        Args:
          **kwargs: Extra keyword arguments forwarded to
            ``training_utils.create_input_fn``.
        """
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=["Hello World ."], targets=["Goodbye ."])
        try:
            data_provider_fn = lambda: data_utils.make_parallel_data_provider(
                [sources_file.name], [targets_file.name])
            input_fn = training_utils.create_input_fn(
                data_provider_fn=data_provider_fn, **kwargs)
            features, labels = input_fn()

            with self.test_session() as sess:
                with tf.contrib.slim.queues.QueueRunners(sess):
                    features_, labels_ = sess.run([features, labels])
        finally:
            # Close the temp files once the session has finished reading them.
            sources_file.close()
            targets_file.close()

        self.assertEqual(set(features_), {"source_tokens", "source_len"})
        self.assertEqual(set(labels_), {"target_tokens", "target_len"})
예제 #5
0
  def _test_with_args(self, **kwargs):
    """Helper function to test create_input_fn with keyword arguments.

    Builds an input_fn from a ParallelTextInputPipeline over a one-line
    parallel corpus and checks that it yields features with exactly the source
    keys and labels with exactly the target keys.

    Args:
      **kwargs: Extra keyword arguments forwarded to
        ``training_utils.create_input_fn``.
    """
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=["Hello World ."], targets=["Goodbye ."])
    try:
      pipeline = input_pipeline.ParallelTextInputPipeline(
          params={
              "source_files": [sources_file.name],
              "target_files": [targets_file.name]
          },
          mode=tf.contrib.learn.ModeKeys.TRAIN)
      input_fn = training_utils.create_input_fn(pipeline=pipeline, **kwargs)
      features, labels = input_fn()

      with self.test_session() as sess:
        with tf.contrib.slim.queues.QueueRunners(sess):
          features_, labels_ = sess.run([features, labels])
    finally:
      # Close the temp files once the session has finished reading them.
      sources_file.close()
      targets_file.close()

    self.assertEqual(set(features_), {"source_tokens", "source_len"})
    self.assertEqual(set(labels_), {"target_tokens", "target_len"})
예제 #6
0
  def _test_pipeline(self, mode, params=None):
    """Helper function to test the full model pipeline.

    Writes one random parallel example to temp files, builds the model via
    ``self.create_model(mode, params)``, feeds it through a
    ParallelTextInputPipeline and evaluates every non-None model output once.

    Args:
      mode: Model mode, forwarded to ``self.create_model`` and the pipeline.
      params: Optional model params forwarded to ``self.create_model``.

    Returns:
      A ``(model, fetches_)`` tuple; ``fetches_`` holds the evaluated values
      of the model's non-None outputs, in order.
    """
    # Create source and target example; both are longer than
    # self.sequence_length.
    source_len = self.sequence_length + 5
    target_len = self.sequence_length + 10
    source = " ".join(np.random.choice(self.vocab_list, source_len))
    target = " ".join(np.random.choice(self.vocab_list, target_len))
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=[source], targets=[target])

    try:
      # Build model graph
      model = self.create_model(mode, params)
      input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
          params={
              "source_files": [sources_file.name],
              "target_files": [targets_file.name]
          },
          mode=mode)
      input_fn = training_utils.create_input_fn(
          pipeline=input_pipeline_, batch_size=self.batch_size)
      features, labels = input_fn()
      fetches = model(features, labels, None)
      fetches = [fetch for fetch in fetches if fetch is not None]

      with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
          fetches_ = sess.run(fetches)
    finally:
      # Close the temp files even if graph construction or the session run
      # raises.
      sources_file.close()
      targets_file.close()

    return model, fetches_
예제 #7
0
    def test_read_from_data_provider(self):
        """Reads one parallel example via data_utils.read_from_data_provider.

        Checks that the source sequence gets a trailing SEQUENCE_END token,
        that the target sequence is wrapped in SEQUENCE_START/SEQUENCE_END,
        and that non-ASCII tokens survive the UTF-8 round trip.
        """
        file_source, file_target = test_utils.create_temp_parallel_data(
            sources=["Hello World . 笑"], targets=["Bye 泣"])
        # Close the temp files once the test (and its queue runners) finish.
        self.addCleanup(file_source.close)
        self.addCleanup(file_target.close)

        data_provider = data_utils.make_parallel_data_provider(
            data_sources_source=[file_source.name],
            data_sources_target=[file_target.name],
            num_epochs=5,
            shuffle=False)
        features = data_utils.read_from_data_provider(data_provider)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            # NOTE(review): presumably needed for the num_epochs counter,
            # which TF keeps in a local variable -- confirm.
            sess.run(tf.local_variables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                res = sess.run(features)

        # 4 source words + SEQUENCE_END; 2 target words + START/END markers.
        self.assertEqual(res["source_len"], 5)
        self.assertEqual(res["target_len"], 4)
        np.testing.assert_array_equal(
            np.char.decode(res["source_tokens"].astype("S"), "utf-8"),
            ["Hello", "World", ".", "笑", "SEQUENCE_END"])
        np.testing.assert_array_equal(
            np.char.decode(res["target_tokens"].astype("S"), "utf-8"),
            ["SEQUENCE_START", "Bye", "泣", "SEQUENCE_END"])
예제 #8
0
  def test_train_infer(self):
    """Tests training and inference scripts.

    Trains an AttentionSeq2Seq model for 50 steps with bin/train.py, then runs
    bin/infer.py twice against the resulting checkpoint:
      1. DecodeText + DumpAttention, checking the dumped attention scores.
      2. Beam search (width 5) + DumpBeams, checking the beams file exists.
    """
    # Create dummy data
    sources_train, targets_train = test_utils.create_temp_parallel_data(
        sources=["a a a a", "b b b b", "c c c c", "笑 笑 笑 笑"],
        targets=["b b b b", "a a a a", "c c c c", "泣 泣 泣 泣"])
    sources_dev, targets_dev = test_utils.create_temp_parallel_data(
        sources=["a a", "b b", "c c c", "笑 笑 笑"],
        targets=["b b", "a a", "c c c", "泣 泣 泣"])
    vocab_source = test_utils.create_temporary_vocab_file(["a", "b", "c", "笑"])
    vocab_target = test_utils.create_temporary_vocab_file(["a", "b", "c", "泣"])
    # The scripts below read these files by name, so close them only when the
    # whole test has finished.
    for temp_file in (sources_train, targets_train, sources_dev, targets_dev,
                      vocab_source, vocab_target):
      self.addCleanup(temp_file.close)

    # Input pipeline over the dev set, used by both inference runs.
    dev_input_pipeline = """
      class: ParallelTextInputPipeline
      params:
        source_files:
          - {}
        target_files:
          - {}
    """.format(sources_dev.name, targets_dev.name)
    checkpoint_path = os.path.join(self.output_dir, "model.ckpt-50")

    # NOTE(review): _clear_flags is defined elsewhere; presumably it resets
    # tf.app.flags so the freshly loaded script can define its flags again.
    _clear_flags()
    tf.reset_default_graph()
    train_script = imp.load_source("seq2seq.test.train_bin",
                                   os.path.join(BIN_FOLDER, "train.py"))

    # Set training flags
    tf.app.flags.FLAGS.output_dir = self.output_dir
    tf.app.flags.FLAGS.hooks = """
      - class: PrintModelAnalysisHook
      - class: MetadataCaptureHook
      - class: TrainSampleHook
    """
    tf.app.flags.FLAGS.metrics = """
      - class: LogPerplexityMetricSpec
      - class: BleuMetricSpec
      - class: RougeMetricSpec
        params:
          rouge_type: rouge_1/f_score
    """
    tf.app.flags.FLAGS.model = "AttentionSeq2Seq"
    tf.app.flags.FLAGS.model_params = """
    attention.params:
      num_units: 10
    vocab_source: {}
    vocab_target: {}
    """.format(vocab_source.name, vocab_target.name)
    tf.app.flags.FLAGS.batch_size = 2

    # We pass a few flags via a config file
    config_path = os.path.join(self.output_dir, "train_config.yml")
    with gfile.GFile(config_path, "w") as config_file:
      yaml.dump({
          "input_pipeline_train": {
              "class": "ParallelTextInputPipeline",
              "params": {
                  "source_files": [sources_train.name],
                  "target_files": [targets_train.name],
              }
          },
          "input_pipeline_dev": {
              "class": "ParallelTextInputPipeline",
              "params": {
                  "source_files": [sources_dev.name],
                  "target_files": [targets_dev.name],
              }
          },
          "train_steps": 50,
          "model_params": {
              "embedding.dim": 10,
              "decoder.params": {
                  "rnn_cell": {
                      "cell_class": "GRUCell",
                      "cell_params": {
                          "num_units": 8
                      }
                  }
              },
              "encoder.params": {
                  "rnn_cell": {
                      "cell_class": "GRUCell",
                      "cell_params": {
                          "num_units": 8
                      }
                  }
              }
          }
      }, config_file)

    tf.app.flags.FLAGS.config_paths = config_path

    # Run training
    tf.logging.set_verbosity(tf.logging.INFO)
    train_script.main([])

    # Make sure a checkpoint was written (train_steps is 50 above).
    expected_checkpoint = os.path.join(self.output_dir,
                                       "model.ckpt-50.data-00000-of-00001")
    self.assertTrue(os.path.exists(expected_checkpoint))

    # Reset flags and import inference script
    _clear_flags()
    tf.reset_default_graph()
    infer_script = imp.load_source("seq2seq.test.infer_bin",
                                   os.path.join(BIN_FOLDER, "infer.py"))

    # Set inference flags
    attention_dir = os.path.join(self.output_dir, "att")
    tf.app.flags.FLAGS.model_dir = self.output_dir
    tf.app.flags.FLAGS.input_pipeline = dev_input_pipeline
    tf.app.flags.FLAGS.batch_size = 2
    tf.app.flags.FLAGS.checkpoint_path = checkpoint_path

    # Use DecodeText Task
    tf.app.flags.FLAGS.tasks = """
    - class: DecodeText
    - class: DumpAttention
      params:
        output_dir: {}
    """.format(attention_dir)

    # Make sure inference runs successfully
    infer_script.main([])

    # Make sure attention scores and visualizations exist
    self.assertTrue(
        os.path.exists(os.path.join(attention_dir, "attention_scores.npz")))
    self.assertTrue(os.path.exists(os.path.join(attention_dir, "00002.png")))

    # Load attention scores and assert shape. NOTE(review): the expected
    # second dimension presumably is the dev source length plus SEQUENCE_END
    # (2+1 for the first two sentences, 3+1 for the last two) -- confirm.
    # The NpzFile is used as a context manager so its handle is closed.
    with np.load(os.path.join(attention_dir, "attention_scores.npz")) as scores:
      self.assertIn("arr_0", scores)
      self.assertEqual(scores["arr_0"].shape[1], 3)
      self.assertIn("arr_1", scores)
      self.assertEqual(scores["arr_1"].shape[1], 3)
      self.assertIn("arr_2", scores)
      self.assertEqual(scores["arr_2"].shape[1], 4)
      self.assertIn("arr_3", scores)
      self.assertEqual(scores["arr_3"].shape[1], 4)

    # Test inference with beam search
    _clear_flags()
    tf.reset_default_graph()
    infer_script = imp.load_source("seq2seq.test.infer_bin",
                                   os.path.join(BIN_FOLDER, "infer.py"))

    # Set inference flags
    tf.app.flags.FLAGS.model_dir = self.output_dir
    tf.app.flags.FLAGS.input_pipeline = dev_input_pipeline
    tf.app.flags.FLAGS.batch_size = 2
    tf.app.flags.FLAGS.checkpoint_path = checkpoint_path
    tf.app.flags.FLAGS.model_params = """
      inference.beam_search.beam_width: 5
    """
    tf.app.flags.FLAGS.tasks = """
    - class: DecodeText
      params:
        postproc_fn: seq2seq.data.postproc.decode_sentencepiece
    - class: DumpBeams
      params:
        file: {}
    """.format(os.path.join(self.output_dir, "beams.npz"))

    # Run inference w/ beam search
    infer_script.main([])
    self.assertTrue(os.path.exists(os.path.join(self.output_dir, "beams.npz")))
예제 #9
0
  def test_train_infer(self):
    """Tests training and inference scripts.

    Trains an AttentionSeq2Seq model for 50 steps with bin/train.py, then runs
    bin/infer.py twice against the resulting checkpoint:
      1. DecodeText + DumpAttention, checking the dumped attention scores.
      2. Beam search (width 5) + DumpBeams, checking the beams file exists.
    """
    # Create dummy data
    sources_train, targets_train = test_utils.create_temp_parallel_data(
        sources=["a a a a", "b b b b", "c c c c", "笑 笑 笑 笑"],
        targets=["b b b b", "a a a a", "c c c c", "泣 泣 泣 泣"])
    sources_dev, targets_dev = test_utils.create_temp_parallel_data(
        sources=["a a", "b b", "c c c", "笑 笑 笑"],
        targets=["b b", "a a", "c c c", "泣 泣 泣"])
    vocab_source = test_utils.create_temporary_vocab_file(["a", "b", "c", "笑"])
    vocab_target = test_utils.create_temporary_vocab_file(["a", "b", "c", "泣"])
    # The scripts below read these files by name, so close them only when the
    # whole test has finished.
    for temp_file in (sources_train, targets_train, sources_dev, targets_dev,
                      vocab_source, vocab_target):
      self.addCleanup(temp_file.close)

    # Input pipeline over the dev set, used by both inference runs.
    dev_input_pipeline = """
      class: ParallelTextInputPipeline
      params:
        source_files:
          - {}
        target_files:
          - {}
    """.format(sources_dev.name, targets_dev.name)
    checkpoint_path = os.path.join(self.output_dir, "model.ckpt-50")

    # NOTE(review): _clear_flags is defined elsewhere; presumably it resets
    # tf.app.flags so the freshly loaded script can define its flags again.
    _clear_flags()
    tf.reset_default_graph()
    train_script = imp.load_source("seq2seq.test.train_bin",
                                   os.path.join(BIN_FOLDER, "train.py"))

    # Set training flags
    tf.app.flags.FLAGS.output_dir = self.output_dir
    tf.app.flags.FLAGS.hooks = """
      - class: PrintModelAnalysisHook
      - class: MetadataCaptureHook
      - class: TrainSampleHook
    """
    tf.app.flags.FLAGS.metrics = """
      - class: LogPerplexityMetricSpec
      - class: BleuMetricSpec
      - class: RougeMetricSpec
        params:
          rouge_type: rouge_1/f_score
    """
    tf.app.flags.FLAGS.model = "AttentionSeq2Seq"
    tf.app.flags.FLAGS.model_params = """
    attention.params:
      num_units: 10
    vocab_source: {}
    vocab_target: {}
    """.format(vocab_source.name, vocab_target.name)
    tf.app.flags.FLAGS.batch_size = 2

    # We pass a few flags via a config file
    config_path = os.path.join(self.output_dir, "train_config.yml")
    with gfile.GFile(config_path, "w") as config_file:
      yaml.dump({
          "input_pipeline_train": {
              "class": "ParallelTextInputPipeline",
              "params": {
                  "source_files": [sources_train.name],
                  "target_files": [targets_train.name],
              }
          },
          "input_pipeline_dev": {
              "class": "ParallelTextInputPipeline",
              "params": {
                  "source_files": [sources_dev.name],
                  "target_files": [targets_dev.name],
              }
          },
          "train_steps": 50,
          "model_params": {
              "embedding.dim": 10,
              "decoder.params": {
                  "rnn_cell": {
                      "cell_class": "GRUCell",
                      "cell_params": {
                          "num_units": 8
                      }
                  }
              },
              "encoder.params": {
                  "rnn_cell": {
                      "cell_class": "GRUCell",
                      "cell_params": {
                          "num_units": 8
                      }
                  }
              }
          }
      }, config_file)

    tf.app.flags.FLAGS.config_paths = config_path

    # Run training
    tf.logging.set_verbosity(tf.logging.INFO)
    train_script.main([])

    # Make sure a checkpoint was written (train_steps is 50 above).
    expected_checkpoint = os.path.join(self.output_dir,
                                       "model.ckpt-50.data-00000-of-00001")
    self.assertTrue(os.path.exists(expected_checkpoint))

    # Reset flags and import inference script
    _clear_flags()
    tf.reset_default_graph()
    infer_script = imp.load_source("seq2seq.test.infer_bin",
                                   os.path.join(BIN_FOLDER, "infer.py"))

    # Set inference flags
    attention_dir = os.path.join(self.output_dir, "att")
    tf.app.flags.FLAGS.model_dir = self.output_dir
    tf.app.flags.FLAGS.input_pipeline = dev_input_pipeline
    tf.app.flags.FLAGS.batch_size = 2
    tf.app.flags.FLAGS.checkpoint_path = checkpoint_path

    # Use DecodeText Task
    tf.app.flags.FLAGS.tasks = """
    - class: DecodeText
    - class: DumpAttention
      params:
        output_dir: {}
    """.format(attention_dir)

    # Make sure inference runs successfully
    infer_script.main([])

    # Make sure attention scores and visualizations exist
    self.assertTrue(
        os.path.exists(os.path.join(attention_dir, "attention_scores.npz")))
    self.assertTrue(os.path.exists(os.path.join(attention_dir, "00002.png")))

    # Load attention scores and assert shape. NOTE(review): the expected
    # second dimension presumably is the dev source length plus SEQUENCE_END
    # (2+1 for the first two sentences, 3+1 for the last two) -- confirm.
    # The NpzFile is used as a context manager so its handle is closed.
    with np.load(os.path.join(attention_dir, "attention_scores.npz")) as scores:
      self.assertIn("arr_0", scores)
      self.assertEqual(scores["arr_0"].shape[1], 3)
      self.assertIn("arr_1", scores)
      self.assertEqual(scores["arr_1"].shape[1], 3)
      self.assertIn("arr_2", scores)
      self.assertEqual(scores["arr_2"].shape[1], 4)
      self.assertIn("arr_3", scores)
      self.assertEqual(scores["arr_3"].shape[1], 4)

    # Test inference with beam search
    _clear_flags()
    tf.reset_default_graph()
    infer_script = imp.load_source("seq2seq.test.infer_bin",
                                   os.path.join(BIN_FOLDER, "infer.py"))

    # Set inference flags
    tf.app.flags.FLAGS.model_dir = self.output_dir
    tf.app.flags.FLAGS.input_pipeline = dev_input_pipeline
    tf.app.flags.FLAGS.batch_size = 2
    tf.app.flags.FLAGS.checkpoint_path = checkpoint_path
    tf.app.flags.FLAGS.model_params = """
      inference.beam_search.beam_width: 5
    """
    tf.app.flags.FLAGS.tasks = """
    - class: DecodeText
      params:
        postproc_fn: seq2seq.data.postproc.decode_sentencepiece
    - class: DumpBeams
      params:
        file: {}
    """.format(os.path.join(self.output_dir, "beams.npz"))

    # Run inference w/ beam search
    infer_script.main([])
    self.assertTrue(os.path.exists(os.path.join(self.output_dir, "beams.npz")))
예제 #10
0
def test_copy_gen_model(source_path=None, target_path=None, vocab_path=None):
    """Builds a CopyGenSeq2Seq model in TRAIN mode and runs it once.

    When no data paths are given, a one-line random parallel corpus is written
    to temporary files. Its tokens are drawn from the vocabulary plus three
    extra words ("中国", "爱", "你") that are not in the generated vocabulary
    (NOTE(review): presumably so the copy mechanism has OOV words to copy).

    Args:
      source_path: Optional path to a source text file. Must be given together
        with ``target_path``.
      target_path: Optional path to a target text file. Must be given together
        with ``source_path``.
      vocab_path: Optional vocabulary file used for both source and target.
        When omitted, a small temporary vocabulary file is created.

    Returns:
      A ``(model, fetches_)`` tuple; ``fetches_`` holds the evaluated values
      of the model's non-None outputs, in order.

    Raises:
      ValueError: If exactly one of ``source_path``/``target_path`` is given
        (previously ``[None]`` was silently handed to the input pipeline).
    """
    if (source_path is None) != (target_path is None):
        raise ValueError(
            "source_path and target_path must be given together or not at all")

    tf.logging.set_verbosity(tf.logging.INFO)
    batch_size = 2
    sequence_length = 10

    # Temporary files created by this function; closed in the finally below
    # (previously the vocab file was never closed and the data files leaked
    # whenever the session raised).
    temp_files = []
    try:
        if vocab_path is None:
            # Create vocabulary
            vocab_list = [str(_) for _ in range(10)]
            vocab_list += ["笑う", "泣く", "了解", "はい", "^_^"]
            vocab_file = test_utils.create_temporary_vocab_file(vocab_list)
            temp_files.append(vocab_file)
            vocab_path = vocab_file.name
            tf.logging.info(vocab_file.name)
        else:
            vocab_list = get_vocab_list(vocab_path)
        # NOTE(review): the result is unused; the call is kept only in case
        # get_vocab_info validates the vocabulary file.
        vocab.get_vocab_info(vocab_path)

        extend_vocab = vocab_list + ["中国", "爱", "你"]

        tf.contrib.framework.get_or_create_global_step()
        source_len = sequence_length + 5
        target_len = sequence_length + 10
        source = " ".join(np.random.choice(extend_vocab, source_len))
        target = " ".join(np.random.choice(extend_vocab, target_len))

        if source_path is None:
            # Neither path was given (checked above): write a one-line corpus.
            sources_file, targets_file = test_utils.create_temp_parallel_data(
                sources=[source], targets=[target])
            temp_files.extend([sources_file, targets_file])
            source_path = sources_file.name
            target_path = targets_file.name

        # Build model graph
        mode = tf.contrib.learn.ModeKeys.TRAIN
        params_ = CopyGenSeq2Seq.default_params().copy()
        params_.update({
            "vocab_source": vocab_path,
            "vocab_target": vocab_path,
        })
        model = CopyGenSeq2Seq(params=params_, mode=mode)

        tf.logging.info(source_path)
        tf.logging.info(target_path)

        input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
            params={
                "source_files": [source_path],
                "target_files": [target_path]
            },
            mode=mode)
        input_fn = training_utils.create_input_fn(pipeline=input_pipeline_,
                                                  batch_size=batch_size)
        features, labels = input_fn()
        fetches = model(features, labels, None)
        fetches = [fetch for fetch in fetches if fetch is not None]

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            with tf.contrib.slim.queues.QueueRunners(sess):
                fetches_ = sess.run(fetches)
    finally:
        for temp_file in temp_files:
            temp_file.close()

    return model, fetches_
    def test_train_infer(self):
        """Tests training and inference scripts.

        Trains an AttentionSeq2Seq model for 50 steps with bin/train.py, then
        runs bin/infer.py on the dev sources with attention dumping enabled
        and checks the dumped attention scores and visualizations.
        """
        # Create dummy data
        sources_train, targets_train = test_utils.create_temp_parallel_data(
            sources=["a a a a", "b b b b", "c c c c", "笑 笑 笑 笑"],
            targets=["b b b b", "a a a a", "c c c c", "泣 泣 泣 泣"])
        sources_dev, targets_dev = test_utils.create_temp_parallel_data(
            sources=["a a", "b b", "c c c", "笑 笑 笑"],
            targets=["b b", "a a", "c c c", "泣 泣 泣"])
        vocab_source = test_utils.create_temporary_vocab_file(
            ["a", "b", "c", "笑"])
        vocab_target = test_utils.create_temporary_vocab_file(
            ["a", "b", "c", "泣"])
        # The scripts below read these files by name, so close them only
        # when the whole test has finished.
        for temp_file in (sources_train, targets_train, sources_dev,
                          targets_dev, vocab_source, vocab_target):
            self.addCleanup(temp_file.close)

        # NOTE(review): _clear_flags is defined elsewhere; presumably it
        # resets tf.app.flags so the freshly loaded script can define its
        # flags again.
        _clear_flags()
        tf.reset_default_graph()
        train_script = imp.load_source("seq2seq.test.train_bin",
                                       os.path.join(BIN_FOLDER, "train.py"))

        # Set training flags
        tf.app.flags.FLAGS.output_dir = self.output_dir
        tf.app.flags.FLAGS.train_source = sources_train.name
        tf.app.flags.FLAGS.train_target = targets_train.name
        tf.app.flags.FLAGS.vocab_source = vocab_source.name
        tf.app.flags.FLAGS.vocab_target = vocab_target.name
        tf.app.flags.FLAGS.model = "AttentionSeq2Seq"
        tf.app.flags.FLAGS.batch_size = 2

        # We pass a few flags via a config file
        config_path = os.path.join(self.output_dir, "train_config.yml")
        with gfile.GFile(config_path, "w") as config_file:
            yaml.dump(
                {
                    "dev_source": sources_dev.name,
                    "dev_target": targets_dev.name,
                    "train_steps": 50,
                    "hparams": {
                        "embedding.dim": 64,
                        "attention.dim": 16,
                        "decoder.rnn_cell.cell_spec": {
                            "class": "GRUCell",
                            "num_units": 32
                        }
                    }
                }, config_file)

        tf.app.flags.FLAGS.config_path = config_path

        # Run training
        tf.logging.set_verbosity(tf.logging.INFO)
        train_script.main([])

        # Make sure a checkpoint was written (train_steps is 50 above).
        expected_checkpoint = os.path.join(
            self.output_dir, "model.ckpt-50.data-00000-of-00001")
        self.assertTrue(os.path.exists(expected_checkpoint))

        # Reset flags and import inference script
        _clear_flags()
        tf.reset_default_graph()
        infer_script = imp.load_source("seq2seq.test.infer_bin",
                                       os.path.join(BIN_FOLDER, "infer.py"))

        # Set inference flags
        attention_dir = os.path.join(self.output_dir, "att")
        tf.app.flags.FLAGS.model_dir = self.output_dir
        tf.app.flags.FLAGS.source = sources_dev.name
        tf.app.flags.FLAGS.batch_size = 2
        tf.app.flags.FLAGS.checkpoint_path = os.path.join(
            self.output_dir, "model.ckpt-50")
        tf.app.flags.FLAGS.dump_attention_dir = attention_dir

        # Make sure inference runs successfully
        infer_script.main([])

        # Make sure attention scores and visualizations exist
        self.assertTrue(
            os.path.exists(os.path.join(attention_dir,
                                        "attention_scores.npz")))
        self.assertTrue(
            os.path.exists(os.path.join(attention_dir, "00002.png")))

        # Load attention scores and assert shape. NOTE(review): the expected
        # second dimension presumably is the dev source length plus
        # SEQUENCE_END (2+1 for the first two sentences, 3+1 for the last
        # two) -- confirm. The NpzFile is used as a context manager so its
        # file handle is closed.
        scores_path = os.path.join(attention_dir, "attention_scores.npz")
        with np.load(scores_path) as scores:
            self.assertIn("arr_0", scores)
            self.assertEqual(scores["arr_0"].shape[1], 3)
            self.assertIn("arr_1", scores)
            self.assertEqual(scores["arr_1"].shape[1], 3)
            self.assertIn("arr_2", scores)
            self.assertEqual(scores["arr_2"].shape[1], 4)
            self.assertIn("arr_3", scores)
            self.assertEqual(scores["arr_3"].shape[1], 4)