예제 #1
0
  def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self):
    """Inference with only one of the two sentences given raises ValueError."""
    # Populate the temp data dir through the test helper before loading.
    self._create_test_data(
        os.path.join(self._temp_data_dir, "snli/snli_1.0"))

    vocabulary = data.load_vocabulary(self._temp_data_dir)
    word_index, embeddings = data.load_word_vectors(
        self._temp_data_dir, vocabulary)

    # The second inference sentence is deliberately left as None.
    sentence_pair = ("( foo ( bar . ) )", None)
    log_path = os.path.join(self._temp_data_dir, "logdir")
    incomplete_config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=log_path, inference_sentences=sentence_pair)

    # Only the call itself is expected to raise; setup stays outside the guard.
    with self.assertRaises(ValueError):
      spinn.train_or_infer_spinn(
          embeddings, word_index, None, None, None, incomplete_config)
예제 #2
0
    def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self):
        """Expect ValueError when one inference sentence is missing."""
        # Populate the temp data dir through the test helper before loading.
        data_root = self._temp_data_dir
        self._create_test_data(os.path.join(data_root, "snli/snli_1.0"))

        vocabulary = data.load_vocabulary(self._temp_data_dir)
        word_index, embeddings = data.load_word_vectors(
            self._temp_data_dir, vocabulary)

        # The second inference sentence is deliberately left as None.
        sentence_pair = ("( foo ( bar . ) )", None)
        log_path = os.path.join(self._temp_data_dir, "logdir")
        incomplete_config = _test_spinn_config(
            data.WORD_VECTOR_LEN, 4,
            logdir=log_path,
            inference_sentences=sentence_pair)

        # Only the call itself is expected to raise; setup stays outside.
        with self.assertRaises(ValueError):
            spinn.train_or_infer_spinn(
                embeddings, word_index, None, None, None, incomplete_config)
예제 #3
0
    def testTrainSpinn(self):
        """Test with fake toy SNLI data and GloVe vectors.

        Trains through `spinn.train_or_infer_spinn`, then checks that the
        number of "train/loss" summary values equals `config.epochs` and that
        the checkpoint in `config.logdir` lists "global_step" plus every
        variable in `trainer.variables`.
        """

        # 1. Create and load a fake SNLI data file and a fake GloVe embedding
        #    file.
        snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
        fake_train_file = self._create_test_data(snli_1_0_dir)

        vocab = data.load_vocabulary(self._temp_data_dir)
        word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

        # The same fake file backs the train, dev and test splits.
        train_data = data.SnliData(fake_train_file, word2index)
        dev_data = data.SnliData(fake_train_file, word2index)
        test_data = data.SnliData(fake_train_file, word2index)

        # 2. Create a fake config.
        config = _test_spinn_config(
            data.WORD_VECTOR_LEN, 4,
            logdir=os.path.join(self._temp_data_dir, "logdir"))

        # 3. Test training of a SPINN model.
        trainer = spinn.train_or_infer_spinn(embed, word2index, train_data,
                                             dev_data, test_data, config)

        # 4. Load train loss values from the summary file and verify that the
        #    number of recorded values matches the number of epochs. (The
        #    values themselves are not compared against each other.)
        summary_files = glob.glob(os.path.join(config.logdir, "events.out.*"))
        # Fail with a readable message rather than an IndexError when training
        # wrote no summary file at all.
        self.assertTrue(
            summary_files,
            "No summary event file found in %s" % config.logdir)
        events = summary_test_util.events_from_file(summary_files[0])
        train_losses = [
            event.summary.value[0].simple_value
            for event in events
            if event.summary.value
            and event.summary.value[0].tag == "train/loss"
        ]
        self.assertEqual(config.epochs, len(train_losses))

        # 5. Verify that checkpoints exist and contain all the expected
        #    variables.
        self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
        object_graph_string = checkpoint_utils.load_variable(
            config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
        object_graph = (
            checkpointable_object_graph_pb2.CheckpointableObjectGraph())
        object_graph.ParseFromString(object_graph_string)
        ckpt_variable_names = {
            attribute.full_name
            for node in object_graph.nodes
            for attribute in node.attributes
        }
        self.assertIn("global_step", ckpt_variable_names)
        for v in trainer.variables:
            # Keep only the part of the name before the first ":"; names
            # without a ":" are used unchanged.
            variable_name = v.name.partition(":")[0]
            self.assertIn(variable_name, ckpt_variable_names)
예제 #4
0
  def testInferSpinnWorks(self):
    """Run SPINN inference on a sentence pair and check the returned logits."""
    # Populate the temp data dir through the test helper before loading.
    self._create_test_data(
        os.path.join(self._temp_data_dir, "snli/snli_1.0"))

    vocabulary = data.load_vocabulary(self._temp_data_dir)
    word_index, embeddings = data.load_word_vectors(
        self._temp_data_dir, vocabulary)

    # Both inference sentences are supplied this time.
    sentence_pair = ("( foo ( bar . ) )", "( bar ( foo . ) )")
    log_path = os.path.join(self._temp_data_dir, "logdir")
    infer_config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=log_path, inference_sentences=sentence_pair)

    # No train/dev/test data is passed, only the config with the sentences.
    result = spinn.train_or_infer_spinn(
        embeddings, word_index, None, None, None, infer_config)

    # The result must be a float32 vector with three entries.
    self.assertEqual(tf.float32, result.dtype)
    self.assertEqual((3,), result.shape)
예제 #5
0
  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors.

    Trains through `spinn.train_or_infer_spinn`, then checks that the number
    of "train/loss" summary values equals `config.epochs` and that the
    checkpoint in `config.logdir` lists "global_step" plus every variable in
    `trainer.variables`.
    """

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    # The same fake file backs the train, dev and test splits.
    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    trainer = spinn.train_or_infer_spinn(
        embed, word2index, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary file and verify that the
    #    number of recorded values matches the number of epochs. (The values
    #    themselves are not compared against each other.)
    summary_files = glob.glob(os.path.join(config.logdir, "events.out.*"))
    # Fail with a readable message rather than an IndexError when training
    # wrote no summary file at all.
    self.assertTrue(
        summary_files, "No summary event file found in %s" % config.logdir)
    events = summary_test_util.events_from_file(summary_files[0])
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    self.assertEqual(config.epochs, len(train_losses))

    # 5. Verify that checkpoints exist and contain all the expected variables.
    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
    object_graph_string = checkpoint_utils.load_variable(
        config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
    object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
    object_graph.ParseFromString(object_graph_string)
    ckpt_variable_names = {
        attribute.full_name
        for node in object_graph.nodes
        for attribute in node.attributes
    }
    self.assertIn("global_step", ckpt_variable_names)
    for v in trainer.variables:
      # Keep only the part of the name before the first ":"; names without a
      # ":" are used unchanged.
      variable_name = v.name.partition(":")[0]
      self.assertIn(variable_name, ckpt_variable_names)