def main(_):
  config = FLAGS

  # Load embedding vectors.
  vocab = data.load_vocabulary(FLAGS.data_root)
  word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)

  if not (config.inference_premise or config.inference_hypothesis):
    print("Loading train, dev and test data...")
    train_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
    dev_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
    test_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
  else:
    train_data = None
    dev_data = None
    test_data = None

  train_or_infer_spinn(
      embed, word2index, train_data, dev_data, test_data, config)
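# The FLAGS object read by main() is defined elsewhere in the module. A
# minimal sketch of how the four flags used above could be declared with
# argparse follows; the defaults and help strings are illustrative
# assumptions, not the module's actual definitions.
import argparse

_parser = argparse.ArgumentParser(description="SPINN SNLI training/inference.")
_parser.add_argument("--data_root", type=str, default="/tmp/spinn-data",
                     help="Directory holding the SNLI and GloVe files.")
_parser.add_argument("--sentence_len_limit", type=int, default=-1,
                     help="Maximum sentence length; assumed -1 means no limit.")
_parser.add_argument("--inference_premise", type=str, default=None,
                     help="If set (with --inference_hypothesis), run "
                          "single-pair inference instead of training.")
_parser.add_argument("--inference_hypothesis", type=str, default=None,
                     help="Hypothesis sentence for single-pair inference.")
# FLAGS, unparsed = _parser.parse_known_args()  # hypothetical wiring.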
def testTrainSpinn(self):
  """Test with fake toy SNLI data and GloVe vectors."""
  # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
  snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
  fake_train_file = self._create_test_data(snli_1_0_dir)

  vocab = data.load_vocabulary(self._temp_data_dir)
  word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

  train_data = data.SnliData(fake_train_file, word2index)
  dev_data = data.SnliData(fake_train_file, word2index)
  test_data = data.SnliData(fake_train_file, word2index)

  # 2. Create a fake config.
  config = _test_spinn_config(
      data.WORD_VECTOR_LEN, 4,
      logdir=os.path.join(self._temp_data_dir, "logdir"))

  # 3. Test training of a SPINN model.
  trainer = spinn.train_or_infer_spinn(
      embed, word2index, train_data, dev_data, test_data, config)

  # 4. Load train loss values from the summary files and verify that one
  #    loss value was recorded per training epoch.
  summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
  events = summary_test_util.events_from_file(summary_file)
  train_losses = [
      event.summary.value[0].simple_value for event in events
      if event.summary.value
      and event.summary.value[0].tag == "train/loss"
  ]
  self.assertEqual(config.epochs, len(train_losses))

  # 5. Verify that checkpoints exist and contain all the expected variables.
  self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
  object_graph_string = checkpoint_utils.load_variable(
      config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
  object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
  object_graph.ParseFromString(object_graph_string)
  ckpt_variable_names = set()
  for node in object_graph.nodes:
    for attribute in node.attributes:
      ckpt_variable_names.add(attribute.full_name)
  self.assertIn("global_step", ckpt_variable_names)
  for v in trainer.variables:
    variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
    self.assertIn(variable_name, ckpt_variable_names)
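# _test_spinn_config() is a helper defined elsewhere in this test file. The
# test above only reads `epochs` and `logdir` from the returned config; a
# minimal sketch under that assumption is below. The positional-argument
# names (d_embed, d_out) and the epoch count are illustrative guesses, and
# the real config presumably carries many more model hyperparameters.
import collections

_SpinnConfigSketch = collections.namedtuple(
    "_SpinnConfigSketch", ["d_embed", "d_out", "epochs", "logdir"])

def _test_spinn_config_sketch(d_embed, d_out, logdir=None):
  return _SpinnConfigSketch(
      d_embed=d_embed,  # word-vector dimensionality (data.WORD_VECTOR_LEN).
      d_out=d_out,      # assumed: output dimensionality of the classifier.
      epochs=2,         # assumed: a tiny epoch count keeps the test fast.
      logdir=logdir)    # where checkpoints and event files are written.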
def testSnliData(self):
  snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
  fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
  os.makedirs(snli_1_0_dir)
  self._createFakeSnliData(fake_train_file)

  glove_dir = os.path.join(self._temp_data_dir, "glove")
  os.makedirs(glove_dir)
  glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
  self._createFakeGloveData(glove_file)

  vocab = data.load_vocabulary(self._temp_data_dir)
  word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

  train_data = data.SnliData(fake_train_file, word2index)
  self.assertEqual(4, train_data.num_batches(1))
  self.assertEqual(2, train_data.num_batches(2))
  self.assertEqual(2, train_data.num_batches(3))
  self.assertEqual(1, train_data.num_batches(4))

  generator = train_data.get_generator(2)()
  for _ in range(2):
    label, prem, prem_trans, hypo, hypo_trans = next(generator)
    self.assertEqual(2, len(label))
    self.assertEqual((4, 2), prem.shape)
    self.assertEqual((5, 2), prem_trans.shape)
    self.assertEqual((3, 2), hypo.shape)
    self.assertEqual((3, 2), hypo_trans.shape)
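# _createFakeSnliData() and _createFakeGloveData() are test-class helpers
# defined elsewhere. A sketch of the GloVe helper, reconstructed from the
# inline file-writing code in the older testTrainSpinn() variant later in
# this section (only the function name is hypothetical):
def _create_fake_glove_data_sketch(glove_file):
  # Vocabulary matching the words in the fake SNLI sentences.
  words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
  with open(glove_file, "wt") as f:
    for i, word in enumerate(words):
      # One line per word: the word, then WORD_VECTOR_LEN space-separated
      # floats; every component of word i is the constant i * 0.1.
      f.write("%s " % word)
      f.write(" ".join(["%.5f" % (i * 0.1)] * data.WORD_VECTOR_LEN))
      f.write("\n")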
def main(_):
  config = FLAGS

  # Load embedding vectors.
  vocab = data.load_vocabulary(FLAGS.data_root)
  word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)

  print("Loading train, dev and test data...")
  train_data = data.SnliData(
      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
  dev_data = data.SnliData(
      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
  test_data = data.SnliData(
      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
      word2index, sentence_len_limit=FLAGS.sentence_len_limit)

  train_spinn(embed, train_data, dev_data, test_data, config)
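# For reference, the paths constructed above imply this layout under
# --data_root (the GloVe location comes from the tests in this section):
#
#   <data_root>/snli/snli_1.0/snli_1.0_train.txt
#   <data_root>/snli/snli_1.0/snli_1.0_dev.txt
#   <data_root>/snli/snli_1.0/snli_1.0_test.txt
#   <data_root>/glove/glove.42B.300d.txt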
def testTrainSpinn(self):
  """Test with fake toy SNLI data and GloVe vectors."""
  # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
  snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
  fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
  os.makedirs(snli_1_0_dir)

  # Four examples (sentence pairs) in total.
  with open(fake_train_file, "wt") as f:
    f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
            "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
            "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
    f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
            "DummySentence1Parse\tDummySentence2Parse\t"
            "Foo bar.\tfoo baz.\t"
            "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
            "neutral\tentailment\tneutral\tneutral\tneutral\n")
    f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
            "DummySentence1Parse\tDummySentence2Parse\t"
            "Foo bar.\tfoo baz.\t"
            "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
            "neutral\tentailment\tneutral\tneutral\tneutral\n")
    f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
            "DummySentence1Parse\tDummySentence2Parse\t"
            "Foo bar.\tfoo baz.\t"
            "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
            "neutral\tentailment\tneutral\tneutral\tneutral\n")
    f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
            "DummySentence1Parse\tDummySentence2Parse\t"
            "Foo bar.\tfoo baz.\t"
            "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
            "neutral\tentailment\tneutral\tneutral\tneutral\n")

  glove_dir = os.path.join(self._temp_data_dir, "glove")
  os.makedirs(glove_dir)
  glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
  words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
  with open(glove_file, "wt") as f:
    for i, word in enumerate(words):
      f.write("%s " % word)
      for j in range(data.WORD_VECTOR_LEN):
        f.write("%.5f" % (i * 0.1))
        if j < data.WORD_VECTOR_LEN - 1:
          f.write(" ")
        else:
          f.write("\n")

  vocab = data.load_vocabulary(self._temp_data_dir)
  word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

  train_data = data.SnliData(fake_train_file, word2index)
  dev_data = data.SnliData(fake_train_file, word2index)
  test_data = data.SnliData(fake_train_file, word2index)

  # 2. Create a fake config.
  config = _test_spinn_config(
      data.WORD_VECTOR_LEN, 4,
      logdir=os.path.join(self._temp_data_dir, "logdir"))

  # 3. Test training of a SPINN model.
  spinn.train_spinn(embed, train_data, dev_data, test_data, config)

  # 4. Load train loss values from the summary files and verify that they
  #    decrease with training.
  summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
  events = summary_test_util.events_from_file(summary_file)
  train_losses = [
      event.summary.value[0].simple_value for event in events
      if event.summary.value
      and event.summary.value[0].tag == "train/loss"
  ]
  self.assertEqual(config.epochs, len(train_losses))
  self.assertLess(train_losses[-1], train_losses[0])
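# summary_test_util.events_from_file() is a TensorFlow-internal test
# utility. An equivalent sketch using the public summary_iterator API
# (the function name here is hypothetical):
import tensorflow as tf

def _events_from_file_sketch(event_file_path):
  # Reads every tf.Event proto recorded in a single events.out.* file.
  return list(tf.train.summary_iterator(event_file_path))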
def testSnliData(self):
  """Unit test for SnliData objects."""
  snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
  fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
  os.makedirs(snli_1_0_dir)

  # Four examples (sentence pairs) in total.
  with open(fake_train_file, "wt") as f:
    f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
            "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
            "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
    f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
            "DummySentence1Parse\tDummySentence2Parse\t"
            "Foo bar.\tfoo baz.\t"
            "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
            "neutral\tentailment\tneutral\tneutral\tneutral\n")
    f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
            "DummySentence1Parse\tDummySentence2Parse\t"
            "Foo bar.\tfoo baz.\t"
            "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
            "neutral\tentailment\tneutral\tneutral\tneutral\n")
    f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
            "DummySentence1Parse\tDummySentence2Parse\t"
            "Foo bar.\tfoo baz.\t"
            "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
            "neutral\tentailment\tneutral\tneutral\tneutral\n")
    f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
            "DummySentence1Parse\tDummySentence2Parse\t"
            "Foo bar.\tfoo baz.\t"
            "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
            "neutral\tentailment\tneutral\tneutral\tneutral\n")

  glove_dir = os.path.join(self._temp_data_dir, "glove")
  os.makedirs(glove_dir)
  glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
  words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
  with open(glove_file, "wt") as f:
    for i, word in enumerate(words):
      f.write("%s " % word)
      for j in range(data.WORD_VECTOR_LEN):
        f.write("%.5f" % (i * 0.1))
        if j < data.WORD_VECTOR_LEN - 1:
          f.write(" ")
        else:
          f.write("\n")

  vocab = data.load_vocabulary(self._temp_data_dir)
  word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

  train_data = data.SnliData(fake_train_file, word2index)
  self.assertEqual(4, train_data.num_batches(1))
  self.assertEqual(2, train_data.num_batches(2))
  self.assertEqual(2, train_data.num_batches(3))
  self.assertEqual(1, train_data.num_batches(4))

  generator = train_data.get_generator(2)()
  for _ in range(2):
    label, prem, prem_trans, hypo, hypo_trans = next(generator)
    self.assertEqual(2, len(label))
    self.assertEqual((4, 2), prem.shape)
    self.assertEqual((5, 2), prem_trans.shape)
    self.assertEqual((3, 2), hypo.shape)
    self.assertEqual((3, 2), hypo_trans.shape)
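# The num_batches() assertions above (4 examples -> 4, 2, 2, 1 batches for
# batch sizes 1 through 4) imply ceiling division. A sketch of that
# semantics, not necessarily SnliData's actual implementation:
import math

def _num_batches_sketch(num_examples, batch_size):
  return int(math.ceil(float(num_examples) / batch_size))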