def setUp(self):
    """Build a minimal compiled Keras model wrapping one TrainingStepCounter.

    Stores the Keras backend session on ``self.sess`` so tests can evaluate
    variables with ``self.sess.run``. Stores the compiled model on
    ``self.model``. That model feeds a per-sample input of shape ``(1,)``
    through ``dynamics.TrainingStepCounter`` and is compiled with the
    ``'sgd'`` optimizer and ``'mse'`` loss.
    """
    # The base-class setUp prepares the TensorFlow environment. Run it before
    # fetching the session or creating any graph objects below.
    super().setUp()
    self.sess = tf.keras.backend.get_session()

    # NOTE(review): an earlier comment said the ground truth is ones and the
    # prediction is zeros. No such data is built here; presumably tests that
    # fit self.model supply it. Confirm against those tests.
    model_input = tf.keras.Input(1)
    counted = dynamics.TrainingStepCounter()(model_input)
    self.model = tf.keras.Model(inputs=model_input, outputs=counted)
    self.model.compile('sgd', 'mse')
def testMultipleCallsAreCountedOnce(self):
    """Calling the same layer twice should not increase the counter twice."""
    # One TrainingStepCounter instance is applied to the same input twice.
    # The model therefore has two outputs that share a single layer and its
    # weights.
    shared_counter = dynamics.TrainingStepCounter()
    model_input = tf.keras.Input(1)
    first_out = shared_counter(model_input)
    second_out = shared_counter(model_input)
    model = tf.keras.Model(inputs=model_input,
                           outputs=[first_out, second_out])
    model.compile('sgd', 'mse')

    # Train for a fixed number of steps. The data content is irrelevant; only
    # the number of training steps matters.
    epochs, steps = 2, 5
    model.fit(x=np.zeros(1),
              y=[np.zeros(1), np.zeros(1)],
              epochs=epochs,
              steps_per_epoch=steps)

    # Read the step count from the shared counter's first weight. This is the
    # same layer as model.layers[-1]; using the direct reference avoids
    # depending on the model's layer ordering. Expect one increment per
    # training step, not one per call.
    counted_steps = self.sess.run(shared_counter.weights[0])
    self.assertEqual(counted_steps, epochs * steps)