Example #1
    def testSnliData(self):
        snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
        fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
        os.makedirs(snli_1_0_dir)
        self._createFakeSnliData(fake_train_file)

        glove_dir = os.path.join(self._temp_data_dir, "glove")
        os.makedirs(glove_dir)
        glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
        self._createFakeGloveData(glove_file)

        vocab = data.load_vocabulary(self._temp_data_dir)
        word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

        train_data = data.SnliData(fake_train_file, word2index)
        self.assertEqual(4, train_data.num_batches(1))
        self.assertEqual(2, train_data.num_batches(2))
        self.assertEqual(2, train_data.num_batches(3))
        self.assertEqual(1, train_data.num_batches(4))

        generator = train_data.get_generator(2)()
        for _ in range(2):
            label, prem, prem_trans, hypo, hypo_trans = next(generator)
            self.assertEqual(2, len(label))
            self.assertEqual((4, 2), prem.shape)
            self.assertEqual((5, 2), prem_trans.shape)
            self.assertEqual((3, 2), hypo.shape)
            self.assertEqual((3, 2), hypo_trans.shape)
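
Examples like the one above call _createFakeSnliData and _createFakeGloveData without showing their bodies. Below is a minimal sketch of what such helpers might write, modeled on the inline fixtures in the longer testTrainSpinn/testSnliData listings further down; the method names come from the tests, but these exact bodies are an assumption.

  def _createFakeSnliData(self, fake_train_file):
    # Header row plus four fake examples; columns mirror the real
    # snli_1.0_train.txt layout. Row contents follow the inline fixtures below.
    rows = [
        ("neutral", "( ( Foo bar ) . )", "( ( foo . )"),
        ("contradiction", "( ( Bar foo ) . )", "( ( baz . )"),
        ("entailment", "( ( Quux quuz ) . )", "( ( grault . )"),
        ("entailment", "( ( Quuz quux ) . )", "( ( garply . )"),
    ]
    with open(fake_train_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      for gold_label, parse1, parse2 in rows:
        f.write("%s\t%s\t%s\t"
                "DummySentence1Parse\tDummySentence2Parse\t"
                "Foo bar.\tfoo baz.\t"
                "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
                "neutral\tentailment\tneutral\tneutral\tneutral\n"
                % (gold_label, parse1, parse2))

  def _createFakeGloveData(self, glove_file):
    # One line per word in the plain-text GloVe format:
    # "<word> <float> <float> ... <float>\n" with data.WORD_VECTOR_LEN floats.
    words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
    with open(glove_file, "wt") as f:
      for i, word in enumerate(words):
        vector = " ".join(
            "%.5f" % (i * 0.1) for _ in range(data.WORD_VECTOR_LEN))
        f.write("%s %s\n" % (word, vector))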
Example #2
  def testLoadVoacbulary(self):
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt")
    os.makedirs(snli_1_0_dir)

    with open(fake_train_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
    with open(fake_dev_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Quux quuz?\t.Corge grault!\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

    vocab = data.load_vocabulary(self._temp_data_dir)
    self.assertSetEqual(
        {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"},
        vocab)
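
For reference, the vocabulary asserted above can be reproduced from the plain-text sentence1/sentence2 columns of the two fake files with a simple lowercase-and-split rule. This is only a sketch of an equivalent tokenization, not the library's actual implementation:

import re

def tokens(text):
  # Lowercase, then split out alphabetic runs and the punctuation marks . ? !
  return re.findall(r"[a-z]+|[.?!]", text.lower())

sentences = ["Foo bar.", "foo baz.", "Quux quuz?", ".Corge grault!"]
vocab = set()
for sentence in sentences:
  vocab.update(tokens(sentence))
print(sorted(vocab))
# ['!', '.', '?', 'bar', 'baz', 'corge', 'foo', 'grault', 'quux', 'quuz']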
Example #3
  def testSnliData(self):
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)
    self._createFakeSnliData(fake_train_file)

    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
    self._createFakeGloveData(glove_file)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    self.assertEqual(4, train_data.num_batches(1))
    self.assertEqual(2, train_data.num_batches(2))
    self.assertEqual(2, train_data.num_batches(3))
    self.assertEqual(1, train_data.num_batches(4))

    generator = train_data.get_generator(2)()
    for _ in range(2):
      label, prem, prem_trans, hypo, hypo_trans = next(generator)
      self.assertEqual(2, len(label))
      self.assertEqual((4, 2), prem.shape)
      self.assertEqual((5, 2), prem_trans.shape)
      self.assertEqual((3, 2), hypo.shape)
      self.assertEqual((3, 2), hypo_trans.shape)
Example #4
def main(_):
  config = FLAGS

  # Load embedding vectors.
  vocab = data.load_vocabulary(FLAGS.data_root)
  word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)

  if not (config.inference_premise or config.inference_hypothesis):
    print("Loading train, dev and test data...")
    train_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
    dev_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
    test_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
  else:
    train_data = None
    dev_data = None
    test_data = None

  train_or_infer_spinn(
      embed, word2index, train_data, dev_data, test_data, config)
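
main() above reads FLAGS.data_root, FLAGS.sentence_len_limit, and the optional inference_premise/inference_hypothesis pair, but the flag declarations are not shown. One plausible way to wire them up, assuming absl-style flags; the flag names match what main() reads, while the types, defaults, and help strings here are assumptions and the real spinn.py may declare them differently:

from absl import app, flags

flags.DEFINE_string("data_root", "/tmp/spinn-data",
                    "Directory containing the snli/ and glove/ subdirectories.")
flags.DEFINE_integer("sentence_len_limit", -1,
                     "Maximum sentence length; -1 for no limit (assumed).")
flags.DEFINE_string("inference_premise", None,
                    "Binary-parsed premise for single-pair inference.")
flags.DEFINE_string("inference_hypothesis", None,
                    "Binary-parsed hypothesis for single-pair inference.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
  app.run(main)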
Example #5
  def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self):
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"),
        inference_sentences=("( foo ( bar . ) )", None))
    with self.assertRaises(ValueError):
      spinn.train_or_infer_spinn(embed, word2index, None, None, None, config)
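
_test_spinn_config is used by the tests here but never shown. The sketch below is hypothetical: the two positional sizes, the logdir and inference_sentences arguments, and the epochs/inference_premise/inference_hypothesis fields are grounded in how the examples use the config, while every other field and default is an assumption.

import types

def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None):
  # Hypothetical stand-in: bundle a small set of hyperparameters into a
  # namespace that train_or_infer_spinn can read.
  if inference_sentences:
    premise, hypothesis = inference_sentences
  else:
    premise, hypothesis = None, None
  return types.SimpleNamespace(
      d_embed=d_embed,    # embedding size (data.WORD_VECTOR_LEN in the tests)
      d_out=d_out,        # output size (passed as 4 in the tests)
      logdir=logdir,      # where checkpoints and event files are written
      epochs=2,           # tiny value so tests finish quickly (assumed)
      batch_size=2,       # assumed toy value
      lr=2e-3,            # assumed toy learning rate
      inference_premise=premise,
      inference_hypothesis=hypothesis)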
Example #6
    def testTrainSpinn(self):
        """Test with fake toy SNLI data and GloVe vectors."""

        # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
        snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
        fake_train_file = self._create_test_data(snli_1_0_dir)

        vocab = data.load_vocabulary(self._temp_data_dir)
        word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

        train_data = data.SnliData(fake_train_file, word2index)
        dev_data = data.SnliData(fake_train_file, word2index)
        test_data = data.SnliData(fake_train_file, word2index)

        # 2. Create a fake config.
        config = _test_spinn_config(data.WORD_VECTOR_LEN,
                                    4,
                                    logdir=os.path.join(
                                        self._temp_data_dir, "logdir"))

        # 3. Test training of a SPINN model.
        trainer = spinn.train_or_infer_spinn(embed, word2index, train_data,
                                             dev_data, test_data, config)

        # 4. Load train loss values from the summary files and verify that they
        #    decrease with training.
        summary_file = glob.glob(os.path.join(config.logdir,
                                              "events.out.*"))[0]
        events = summary_test_util.events_from_file(summary_file)
        train_losses = [
            event.summary.value[0].simple_value for event in events if
            event.summary.value and event.summary.value[0].tag == "train/loss"
        ]
        self.assertEqual(config.epochs, len(train_losses))

        # 5. Verify that checkpoints exist and contain all the expected variables.
        self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
        object_graph_string = checkpoint_utils.load_variable(
            config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
        object_graph = (
            checkpointable_object_graph_pb2.CheckpointableObjectGraph())
        object_graph.ParseFromString(object_graph_string)
        ckpt_variable_names = set()
        for node in object_graph.nodes:
            for attribute in node.attributes:
                ckpt_variable_names.add(attribute.full_name)
        self.assertIn("global_step", ckpt_variable_names)
        for v in trainer.variables:
            variable_name = (
                v.name[:v.name.index(":")] if ":" in v.name else v.name)
            self.assertIn(variable_name, ckpt_variable_names)
Example #7
  def testInferSpinnWorks(self):
    """Test inference with the spinn model."""
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"),
        inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )"))
    logits = spinn.train_or_infer_spinn(
        embed, word2index, None, None, None, config)
    self.assertEqual(tf.float32, logits.dtype)
    self.assertEqual((3,), logits.shape)
Example #8
  def testLoadVoacbularyWithoutFileRaisesError(self):
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)

    os.makedirs(os.path.join(self._temp_data_dir, "snli"))
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)

    os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0"))
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)
Example #9
  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors."""

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    trainer = spinn.train_or_infer_spinn(
        embed, word2index, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary files and verify that they
    #    decrease with training.
    summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
    events = summary_test_util.events_from_file(summary_file)
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    self.assertEqual(config.epochs, len(train_losses))

    # 5. Verify that checkpoints exist and contain all the expected variables.
    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
    object_graph_string = checkpoint_utils.load_variable(
        config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
    object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
    object_graph.ParseFromString(object_graph_string)
    ckpt_variable_names = set()
    for node in object_graph.nodes:
      for attribute in node.attributes:
        ckpt_variable_names.add(attribute.full_name)
    self.assertIn("global_step", ckpt_variable_names)
    for v in trainer.variables:
      variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
      self.assertIn(variable_name, ckpt_variable_names)
Example #10
def main(_):
  config = FLAGS

  # Load embedding vectors.
  vocab = data.load_vocabulary(FLAGS.data_root)
  word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)

  print("Loading train, dev and test data...")
  train_data = data.SnliData(
      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
  dev_data = data.SnliData(
      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
  test_data = data.SnliData(
      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
      word2index, sentence_len_limit=FLAGS.sentence_len_limit)

  train_spinn(embed, train_data, dev_data, test_data, config)
Example #11
  def testEncodeSingleSentence(self):
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)
    self._createFakeSnliData(fake_train_file)
    vocab = data.load_vocabulary(self._temp_data_dir)
    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
    self._createFakeGloveData(glove_file)
    word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

    sentence_variants = [
        "( Foo ( ( bar baz ) . ) )",
        " ( Foo ( ( bar baz ) . ) ) ",
        "( Foo ( ( bar baz ) . )  )"]
    for sentence in sentence_variants:
      word_indices, shift_reduce = data.encode_sentence(sentence, word2index)
      self.assertEqual(np.int64, word_indices.dtype)
      self.assertEqual((5, 1), word_indices.shape)
      self.assertAllClose(
          np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce)
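
The expected transition sequence in the assertion above can be read directly off the binary parse. Assuming the convention implied by the asserted values, 3 for a shift (word token) and 2 for a reduce (closing parenthesis), a minimal sketch of the derivation is:

def parse_to_shift_reduce(binary_parse):
  # Skip "(", emit 2 (reduce) for ")" and 3 (shift) for every word token.
  ops = []
  for token in binary_parse.split():
    if token == "(":
      continue
    ops.append(2 if token == ")" else 3)
  return ops

print(parse_to_shift_reduce("( Foo ( ( bar baz ) . ) )"))
# [3, 3, 3, 2, 3, 2, 2], matching the expected shift_reduce column above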
Example #12
File: spinn.py  Project: lengjia/RRL
def main(_):
    config = FLAGS

    # Load embedding vectors.
    vocab = data.load_vocabulary(FLAGS.data_root)
    word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)

    print("Loading train, dev and test data...")
    train_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
        word2index,
        sentence_len_limit=FLAGS.sentence_len_limit)
    dev_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
        word2index,
        sentence_len_limit=FLAGS.sentence_len_limit)
    test_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
        word2index,
        sentence_len_limit=FLAGS.sentence_len_limit)

    train_spinn(embed, train_data, dev_data, test_data, config)
Example #13
    def testEncodeSingleSentence(self):
        snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
        fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
        os.makedirs(snli_1_0_dir)
        self._createFakeSnliData(fake_train_file)
        vocab = data.load_vocabulary(self._temp_data_dir)
        glove_dir = os.path.join(self._temp_data_dir, "glove")
        os.makedirs(glove_dir)
        glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
        self._createFakeGloveData(glove_file)
        word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

        sentence_variants = [
            "( Foo ( ( bar baz ) . ) )", " ( Foo ( ( bar baz ) . ) ) ",
            "( Foo ( ( bar baz ) . )  )"
        ]
        for sentence in sentence_variants:
            word_indices, shift_reduce = data.encode_sentence(
                sentence, word2index)
            self.assertEqual(np.int64, word_indices.dtype)
            self.assertEqual((5, 1), word_indices.shape)
            self.assertAllClose(
                np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T,
                shift_reduce)
Example #14
  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors."""

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)

    # Four sentences in total.
    with open(fake_train_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")

    words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
    with open(glove_file, "wt") as f:
      for i, word in enumerate(words):
        f.write("%s " % word)
        for j in range(data.WORD_VECTOR_LEN):
          f.write("%.5f" % (i * 0.1))
          if j < data.WORD_VECTOR_LEN - 1:
            f.write(" ")
          else:
            f.write("\n")

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)
    print(embed)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    spinn.train_spinn(embed, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary files and verify that they
    #    decrease with training.
    summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
    events = summary_test_util.events_from_file(summary_file)
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    self.assertEqual(config.epochs, len(train_losses))
    self.assertLess(train_losses[-1], train_losses[0])
Example #15
    def testTrainSpinn(self):
        """Test with fake toy SNLI data and GloVe vectors."""

        # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
        snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
        fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
        os.makedirs(snli_1_0_dir)

        # Four sentences in total.
        with open(fake_train_file, "wt") as f:
            f.write(
                "gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
                "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
                "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
            f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
                    "DummySentence1Parse\tDummySentence2Parse\t"
                    "Foo bar.\tfoo baz.\t"
                    "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
                    "neutral\tentailment\tneutral\tneutral\tneutral\n")
            f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
                    "DummySentence1Parse\tDummySentence2Parse\t"
                    "Foo bar.\tfoo baz.\t"
                    "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
                    "neutral\tentailment\tneutral\tneutral\tneutral\n")
            f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
                    "DummySentence1Parse\tDummySentence2Parse\t"
                    "Foo bar.\tfoo baz.\t"
                    "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
                    "neutral\tentailment\tneutral\tneutral\tneutral\n")
            f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
                    "DummySentence1Parse\tDummySentence2Parse\t"
                    "Foo bar.\tfoo baz.\t"
                    "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
                    "neutral\tentailment\tneutral\tneutral\tneutral\n")

        glove_dir = os.path.join(self._temp_data_dir, "glove")
        os.makedirs(glove_dir)
        glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")

        words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
        with open(glove_file, "wt") as f:
            for i, word in enumerate(words):
                f.write("%s " % word)
                for j in range(data.WORD_VECTOR_LEN):
                    f.write("%.5f" % (i * 0.1))
                    if j < data.WORD_VECTOR_LEN - 1:
                        f.write(" ")
                    else:
                        f.write("\n")

        vocab = data.load_vocabulary(self._temp_data_dir)
        word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

        train_data = data.SnliData(fake_train_file, word2index)
        dev_data = data.SnliData(fake_train_file, word2index)
        test_data = data.SnliData(fake_train_file, word2index)
        print(embed)

        # 2. Create a fake config.
        config = _test_spinn_config(data.WORD_VECTOR_LEN,
                                    4,
                                    logdir=os.path.join(
                                        self._temp_data_dir, "logdir"))

        # 3. Test training of a SPINN model.
        spinn.train_spinn(embed, train_data, dev_data, test_data, config)

        # 4. Load train loss values from the summary files and verify that they
        #    decrease with training.
        summary_file = glob.glob(os.path.join(config.logdir,
                                              "events.out.*"))[0]
        events = summary_test_util.events_from_file(summary_file)
        train_losses = [
            event.summary.value[0].simple_value for event in events if
            event.summary.value and event.summary.value[0].tag == "train/loss"
        ]
        self.assertEqual(config.epochs, len(train_losses))
        self.assertLess(train_losses[-1], train_losses[0])
Example #16
  def testSnliData(self):
    """Unit test for SnliData objects."""
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)

    # Four sentences in total.
    with open(fake_train_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")

    words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
    with open(glove_file, "wt") as f:
      for i, word in enumerate(words):
        f.write("%s " % word)
        for j in range(data.WORD_VECTOR_LEN):
          f.write("%.5f" % (i * 0.1))
          if j < data.WORD_VECTOR_LEN - 1:
            f.write(" ")
          else:
            f.write("\n")

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    self.assertEqual(4, train_data.num_batches(1))
    self.assertEqual(2, train_data.num_batches(2))
    self.assertEqual(2, train_data.num_batches(3))
    self.assertEqual(1, train_data.num_batches(4))

    generator = train_data.get_generator(2)()
    for i in range(2):
      label, prem, prem_trans, hypo, hypo_trans = next(generator)
      self.assertEqual(2, len(label))
      self.assertEqual((4, 2), prem.shape)
      self.assertEqual((5, 2), prem_trans.shape)
      self.assertEqual((3, 2), hypo.shape)
      self.assertEqual((3, 2), hypo_trans.shape)
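
Outside of a test, the same SnliData API exercised above can drive a training loop directly. A short usage sketch, with an arbitrary batch size of 2:

batch_size = 2
generator = train_data.get_generator(batch_size)()
for _ in range(train_data.num_batches(batch_size)):
  label, prem, prem_trans, hypo, hypo_trans = next(generator)
  # label: gold labels for the batch; prem/hypo: word-index matrices of shape
  # (max_sequence_len, batch_size); prem_trans/hypo_trans: the matching
  # shift/reduce transition sequences.
  print(len(label), prem.shape, hypo.shape)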