def testShape(self):
    """Checks that the SGD network's update has the same shape as its input."""
    dims = [10, 5]
    grads = tf.random_normal(dims)
    sgd_net = networks.Sgd()
    initial_state = sgd_net.initial_state_for_inputs(grads)
    step, _ = sgd_net(grads, initial_state)
    # Static shape of the produced update must mirror the gradient shape.
    produced_dims = step.get_shape().as_list()
    self.assertEqual(produced_dims, dims)
def testNonTrainable(self):
    """Checks that building the SGD network creates no variables in its module."""
    dims = [10, 5]
    grads = tf.random_normal(dims)
    sgd_net = networks.Sgd()
    # Connect the module once so any variables it would create get built.
    sgd_net(grads, sgd_net.initial_state_for_inputs(grads))
    self.assertEqual(len(nn.get_variables_in_module(sgd_net)), 0)
def testResults(self):
    """Tests the network's update equals -learning_rate * gradients.

    Builds an Sgd network with a fixed non-zero learning rate, evaluates the
    update it produces for random gradients, and compares it with the same
    scaling computed in NumPy.
    """
    shape = [10]
    learning_rate = 0.01
    gradients = tf.random_normal(shape)
    net = networks.Sgd(learning_rate=learning_rate)
    state = net.initial_state_for_inputs(gradients)
    update, _ = net(gradients, state)
    with self.test_session() as sess:
        gradients_np, update_np = sess.run([gradients, update])
    # Floating-point results from TF and NumPy are compared with a tolerance
    # rather than bitwise, since the two backends need not round identically.
    self.assertAllClose(update_np, -learning_rate * gradients_np)