def testSNLIClassifierAndTrainer(self):
    """Smoke-test spinn.SNLIClassifier and spinn.SNLIClassifierTrainer.

    Builds a tiny classifier over a random embedding matrix and one synthetic
    SNLI batch, then checks:
      * the model's logits are float32 with shape (batch_size, d_out) when
        called with training=False and with training=True;
      * trainer.loss returns a float32 scalar;
      * trainer.train_batch returns a float32 scalar loss plus float32 logits
        of shape (batch_size, d_out);
      * the loss returned by train_batch differs from the loss computed
        directly from the training-mode logits.
    """
    with tf.device(self._test_device):
        vocab_size, batch_size, d_embed = 40, 2, 10
        sequence_length, d_out = 15, 4
        config = _test_spinn_config(d_embed, d_out)
        # Random values stand in for a real embedding matrix.
        embed = tf.random_normal((vocab_size, d_embed))
        model = spinn.SNLIClassifier(config, embed)
        trainer = spinn.SNLIClassifierTrainer(model, config.lr)

        batch = _generate_synthetic_snli_data_batch(
            sequence_length, batch_size, vocab_size)
        labels, prem, prem_trans, hypo, hypo_trans = batch
        inputs = (prem, prem_trans, hypo, hypo_trans)
        expected_logits_shape = (batch_size, d_out)

        def check_logits(out):
            # Logits must be float32 with one row per example, d_out columns.
            self.assertEqual(tf.float32, out.dtype)
            self.assertEqual(expected_logits_shape, out.shape)

        def check_scalar_loss(value):
            # Losses must be float32 scalars (empty shape).
            self.assertEqual(tf.float32, value.dtype)
            self.assertEqual((), value.shape)

        # Inference mode first, then training mode; after the loop `logits`
        # holds the training-mode output.
        for is_training in (False, True):
            logits = model(*inputs, training=is_training)
            check_logits(logits)

        loss_from_logits = trainer.loss(labels, logits)
        check_scalar_loss(loss_from_logits)

        loss_from_step, logits = trainer.train_batch(labels, *inputs)
        check_scalar_loss(loss_from_step)
        check_logits(logits)

        # The original author's intent: running a training step should change
        # the loss. NOTE(review): whether train_batch returns the pre- or
        # post-update loss is defined in spinn.SNLIClassifierTrainer, which is
        # not visible here -- confirm.
        self.assertNotEqual(loss_from_logits.numpy(), loss_from_step.numpy())
def benchmarkEagerSpinnSNLIClassifier(self):
    """Benchmark spinn.SNLIClassifierTrainer.train_batch in eager mode.

    Runs on "gpu:0" when tfe.num_gpus() is non-zero, otherwise on "cpu:0".
    The same synthetic SNLI batch is reused for every step: a few untimed
    burn-in steps run first, then garbage is collected, then a fixed number of
    steps are timed with time.time(). The result is reported through
    self.report_benchmark under "Eager_SPINN_SNLIClassifier_Benchmark".
    """
    test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
    with tf.device(test_device):
        burn_in_iterations = 2
        benchmark_iterations = 10
        vocab_size = 1000
        batch_size = 128
        sequence_length = 15
        d_embed = 200
        d_out = 4

        embed = tf.random_normal((vocab_size, d_embed))
        config = _test_spinn_config(d_embed, d_out)
        model = spinn.SNLIClassifier(config, embed)
        trainer = spinn.SNLIClassifierTrainer(model, config.lr)

        (labels, prem, prem_trans, hypo,
         hypo_trans) = _generate_synthetic_snli_data_batch(
             sequence_length, batch_size, vocab_size)

        # Untimed warm-up steps so one-time setup cost is not measured.
        for _ in range(burn_in_iterations):
            trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)

        # Collect garbage before timing so a collection pass is less likely
        # to land inside the measured loop.
        gc.collect()
        start_time = time.time()
        # Use `range` (not the Python-2-only `xrange`) for consistency with
        # the burn-in loop and so the benchmark runs on Python 3.
        for _ in range(benchmark_iterations):
            trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)
        # NOTE(review): with eager execution on GPU, kernels may run
        # asynchronously; if train_batch does not block, this may not capture
        # all device work -- confirm.
        wall_time = time.time() - start_time

        # Named "examples"_per_sec to conform with other benchmarks.
        # NOTE(review): the value is training steps per second; each step
        # processes batch_size examples. Also, `wall_time` below is the total
        # over all timed iterations, not a per-iteration figure.
        extras = {"examples_per_sec": benchmark_iterations / wall_time}
        self.report_benchmark(
            name="Eager_SPINN_SNLIClassifier_Benchmark",
            iters=benchmark_iterations,
            wall_time=wall_time,
            extras=extras)