  def setUp(self):
    unittest.TestCase.setUp(self)
    self.SetBookkeeper(prettytensor.bookkeeper_for_new_graph())
  def testArbitraryBatchSizeLstm(self):
    # Tests whether the LSTM / Bookkeeper function when batch_size is not
    # specified at graph creation time (i.e., None).
    self.SetBookkeeper(prettytensor.bookkeeper_for_new_graph())

    # Build a graph. Specify None for the batch_size dimension.
    placeholder = tf.placeholder(tf.float32, [None, 1])
    input_pt = prettytensor.wrap_sequence([placeholder])
    output, _ = (input_pt
                 .sequence_lstm(4)
                 .squash_sequence()
                 .softmax_classifier(2))
    self.sess.run(tf.global_variables_initializer())

    # Use RecurrentRunner for state saving and managing feeds.
    recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=1)

    # Run with a batch size of 1 for 10 steps, save output for reference.
    out_orig = []
    for t in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(1, len(out))
      self.assertEqual(2, len(out[0]))
      out_orig.append(out[0])

    # Test the reset functionality - after a reset, the results must be
    # identical to what we just got above.
    recurrent_runner.reset()
    for t in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(1, len(out))
      self.assertEqual(2, len(out[0]))
      testing.assert_allclose(out[0], out_orig[t])

    # Test whether the recurrent runner detects changes to the default graph.
    # It should raise a ValueError because RecurrentRunner's state saver
    # information (collected during __init__) is no longer valid.
    with tf.Graph().as_default():
      placeholder2 = tf.placeholder(tf.float32, [None, 1])
      input_pt2 = prettytensor.wrap_sequence([placeholder2])
      output2, _ = (input_pt2
                    .sequence_lstm(4)
                    .squash_sequence()
                    .softmax_classifier(2))
      self.assertRaises(ValueError, recurrent_runner.run, [output2.name],
                        None, self.sess)

    # Run with a batch size of 3; the first and third inputs are identical
    # and must yield identical outputs, which must also match the
    # single-batch run above (up to floating point rounding errors).
    recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=3)
    for t in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2], [3.4], [1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(3, len(out))
      self.assertEqual(2, len(out[0]))
      testing.assert_allclose(out[0], out[2], rtol=TOLERANCE)
      testing.assert_allclose(out[0], out_orig[t], rtol=TOLERANCE)
      self.assertFalse((out[0] == out[1]).all())
  def performTestArbitraryBatchSizeRnn(self, cell_type):
    # Tests whether LSTM / GRU / Bookkeeper function when batch_size is not
    # specified at graph creation time (i.e., None).
    self.assertIn(cell_type, ('lstm', 'gru'))
    self.SetBookkeeper(prettytensor.bookkeeper_for_new_graph())

    # Build a graph. Specify None for the batch_size dimension.
    placeholder = tf.placeholder(tf.float32, [None, 1])
    input_pt = prettytensor.wrap_sequence([placeholder])
    if cell_type == 'lstm':
      output, _ = (input_pt
                   .sequence_lstm(4)
                   .squash_sequence()
                   .softmax_classifier(2))
    elif cell_type == 'gru':
      output, _ = (input_pt
                   .sequence_gru(4)
                   .squash_sequence()
                   .softmax_classifier(2))
    self.sess.run(tf.global_variables_initializer())

    # Use RecurrentRunner for state saving and managing feeds.
    recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=1)

    # Run with a batch size of 1 for 10 steps, save output for reference.
    out_orig = []
    for t in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(1, len(out))
      self.assertEqual(2, len(out[0]))
      out_orig.append(out[0])

    # Test the reset functionality - after a reset, the results must be
    # identical to what we just got above.
    recurrent_runner.reset()
    for t in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(1, len(out))
      self.assertEqual(2, len(out[0]))
      testing.assert_allclose(out[0], out_orig[t])

    # Test whether the recurrent runner detects changes to the default graph.
    # It should raise a ValueError because RecurrentRunner's state saver
    # information (collected during __init__) is no longer valid.
    with tf.Graph().as_default():
      placeholder2 = tf.placeholder(tf.float32, [None, 1])
      input_pt2 = prettytensor.wrap_sequence([placeholder2])
      if cell_type == 'lstm':
        output2, _ = (input_pt2
                      .sequence_lstm(4)
                      .squash_sequence()
                      .softmax_classifier(2))
      elif cell_type == 'gru':
        output2, _ = (input_pt2
                      .sequence_gru(4)
                      .squash_sequence()
                      .softmax_classifier(2))
      self.assertRaises(ValueError, recurrent_runner.run, [output2.name],
                        None, self.sess)

    # Run with a batch size of 3; the first and third inputs are identical
    # and must yield identical outputs, which must also match the
    # single-batch run above (up to floating point rounding errors).
    recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=3)
    for t in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2], [3.4], [1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(3, len(out))
      self.assertEqual(2, len(out[0]))
      testing.assert_allclose(out[0], out[2], rtol=TOLERANCE)
      testing.assert_allclose(out[0], out_orig[t], rtol=TOLERANCE)
      # Sanity check to protect against trivial outputs that might hide
      # errors. Skip the check after t = 2 for GRUs, since untrained GRUs
      # have a tendency to converge to large state values, leading to
      # saturated outputs like 1.0, 0.0.
      if cell_type == 'gru' and t > 2:
        continue
      self.assertFalse((out[0] == out[1]).all())
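  # performTestArbitraryBatchSizeRnn is not discovered by the unittest runner
  # because its name does not start with 'test'. A minimal wrapper (a sketch;
  # the method name below is assumed, not taken from the source) exposes the
  # GRU variant as a runnable test case. The LSTM variant is already covered
  # by testArbitraryBatchSizeLstm above.
  def testArbitraryBatchSizeGru(self):
    self.performTestArbitraryBatchSizeRnn('gru')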