def setUp(self):
    """Per-test initialization: fresh graph plus its bookkeeper."""
    unittest.TestCase.setUp(self)
    fresh_books = prettytensor.bookkeeper_for_new_graph()
    self.SetBookkeeper(fresh_books)
  def testArbitraryBatchSizeLstm(self):
    """Tests LSTM/Bookkeeper behavior when batch_size is None at build time.

    Covers: (1) running with batch_size=1 via RecurrentRunner, (2) state
    reset reproducing identical outputs, (3) detection of a changed default
    graph, and (4) batch_size=3 matching the single-example results.
    """
    # NOTE(review): super(self.__class__, self) is fragile — it can recurse
    # if this class is subclassed. Kept as-is because the enclosing class
    # name is not visible in this chunk; confirm and name the class directly.
    super(self.__class__, self).SetBookkeeper(
        prettytensor.bookkeeper_for_new_graph())

    # Build a graph. Specify None for the batch_size dimension.
    placeholder = tf.placeholder(tf.float32, [None, 1])
    input_pt = prettytensor.wrap_sequence([placeholder])
    output, _ = (input_pt
                 .sequence_lstm(4)
                 .squash_sequence()
                 .softmax_classifier(2))

    # tf.initialize_all_variables() is deprecated; use the modern name, which
    # also matches performTestArbitraryBatchSizeRnn elsewhere in this file.
    self.sess.run(tf.global_variables_initializer())

    # Use RecurrentRunner for state saving and managing feeds.
    recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=1)

    # Run with a batch size of 1 for 10 steps, save output for reference.
    out_orig = []
    for _ in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(1, len(out))
      self.assertEqual(2, len(out[0]))
      out_orig.append(out[0])

    # Test the reset functionality - after a reset, the results must be
    # identical to what we just got above.
    recurrent_runner.reset()
    for t in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(1, len(out))
      self.assertEqual(2, len(out[0]))
      testing.assert_allclose(out[0], out_orig[t])

    # Test whether the recurrent runner detects changes to the default graph.
    # It should raise because RecurrentRunner's state saver information
    # (collected during __init__) is not valid anymore.
    with tf.Graph().as_default():
      placeholder2 = tf.placeholder(tf.float32, [None, 1])
      input_pt2 = prettytensor.wrap_sequence([placeholder2])
      output2, _ = (input_pt2
                    .sequence_lstm(4)
                    .squash_sequence()
                    .softmax_classifier(2))
      self.assertRaises(ValueError,
                        recurrent_runner.run,
                        [output2.name], None, self.sess)

    # Run with a batch size of 3; first and third input are identical and must
    # yield identical output, and the same output as in the single batch run
    # above (up to floating point rounding errors).
    recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=3)
    for t in xrange(10):
      outs = recurrent_runner.run(
          [output.name],
          {placeholder.name: numpy.array([[1.2], [3.4], [1.2]])},
          sess=self.sess)
      out = outs[0]
      self.assertEqual(3, len(out))
      self.assertEqual(2, len(out[0]))
      testing.assert_allclose(out[0], out[2], rtol=TOLERANCE)
      testing.assert_allclose(out[0], out_orig[t], rtol=TOLERANCE)
      # Sanity check: distinct inputs must give distinct outputs.
      self.assertFalse((out[0] == out[1]).all())
 def setUp(self):
     """Reset the base TestCase and install a bookkeeper for a new graph."""
     unittest.TestCase.setUp(self)
     bookkeeper = prettytensor.bookkeeper_for_new_graph()
     self.SetBookkeeper(bookkeeper)
Example #4
0
    def performTestArbitraryBatchSizeRnn(self, cell_type):
        """Tests LSTM/GRU/Bookkeeper behavior with batch_size=None.

        Covers running with batch_size=1, state reset reproducibility,
        default-graph change detection, and a batch_size=3 run whose
        duplicated rows must produce duplicated outputs.

        Args:
          cell_type: Either 'lstm' or 'gru'; selects the recurrent cell.
        """
        # assertIn yields a clearer failure message than the compound
        # assertTrue(a == x or a == y) form.
        self.assertIn(cell_type, ('lstm', 'gru'))
        # NOTE(review): super(self.__class__, self) can recurse if this class
        # is further subclassed; kept as-is because the enclosing class name
        # is not visible in this chunk — confirm and name the class directly.
        super(self.__class__,
              self).SetBookkeeper(prettytensor.bookkeeper_for_new_graph())

        # Build a graph. Specify None for the batch_size dimension.
        placeholder = tf.placeholder(tf.float32, [None, 1])
        input_pt = prettytensor.wrap_sequence([placeholder])
        if cell_type == 'lstm':
            output, _ = (input_pt.sequence_lstm(
                4).squash_sequence().softmax_classifier(2))
        else:  # 'gru' — guaranteed by the assertIn above.
            output, _ = (input_pt.sequence_gru(
                4).squash_sequence().softmax_classifier(2))

        self.sess.run(tf.global_variables_initializer())

        # Use RecurrentRunner for state saving and managing feeds.
        recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=1)

        # Run with a batch size of 1 for 10 steps, save output for reference.
        out_orig = []
        for _ in xrange(10):
            outs = recurrent_runner.run(
                [output.name], {placeholder.name: numpy.array([[1.2]])},
                sess=self.sess)
            out = outs[0]
            self.assertEqual(1, len(out))
            self.assertEqual(2, len(out[0]))
            out_orig.append(out[0])

        # Test the reset functionality - after a reset, the results must be
        # identical to what we just got above.
        recurrent_runner.reset()
        for t in xrange(10):
            outs = recurrent_runner.run(
                [output.name], {placeholder.name: numpy.array([[1.2]])},
                sess=self.sess)
            out = outs[0]
            self.assertEqual(1, len(out))
            self.assertEqual(2, len(out[0]))
            testing.assert_allclose(out[0], out_orig[t])

        # Test whether the recurrent runner detects changes to the default
        # graph. It should raise because RecurrentRunner's state saver
        # information (collected during __init__) is not valid anymore.
        with tf.Graph().as_default():
            placeholder2 = tf.placeholder(tf.float32, [None, 1])
            input_pt2 = prettytensor.wrap_sequence([placeholder2])
            if cell_type == 'lstm':
                output2, _ = (input_pt2.sequence_lstm(
                    4).squash_sequence().softmax_classifier(2))
            else:  # 'gru'
                output2, _ = (input_pt2.sequence_gru(
                    4).squash_sequence().softmax_classifier(2))
            self.assertRaises(ValueError, recurrent_runner.run, [output2.name],
                              None, self.sess)

        # Run with a batch size of 3; first and third input are identical and
        # must yield identical output, and the same output as in the single
        # batch run above (up to floating point rounding errors).
        recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=3)
        for t in xrange(10):
            outs = recurrent_runner.run(
                [output.name],
                {placeholder.name: numpy.array([[1.2], [3.4], [1.2]])},
                sess=self.sess)
            out = outs[0]
            self.assertEqual(3, len(out))
            self.assertEqual(2, len(out[0]))
            testing.assert_allclose(out[0], out[2], rtol=TOLERANCE)
            testing.assert_allclose(out[0], out_orig[t], rtol=TOLERANCE)
            # Sanity check to protect against trivial outputs that might hide
            # errors. Skip after t = 2 since untrained GRUs have a tendency to
            # converge to large state values, leading to outputs like 1.0, 0.0.
            if cell_type == 'gru' and t > 2:
                continue
            self.assertFalse((out[0] == out[1]).all())