コード例 #1
0
 def setUp(self):
   """Build and compile a minimal model wrapping a TrainingStepCounter.

   The model has a single 1-feature Keras input fed through one
   dynamics.TrainingStepCounter layer. It is compiled with the 'sgd'
   optimizer and 'mse' loss and stored on self.model. The Keras backend
   session is stored on self.sess.

   NOTE(review): an earlier comment said the ground-truth data is a tensor of
   ones and the predictions are zeros. No such data is built here; presumably
   the individual tests supply it. Confirm.
   """
   # The base-class setUp is said to prepare the TensorFlow environment, so
   # it must run before anything below touches TF.
   super().setUp()
   self.sess = tf.keras.backend.get_session()

   model_input = tf.keras.Input(1)
   counted = dynamics.TrainingStepCounter()(model_input)
   model = self.model = tf.keras.Model(inputs=model_input, outputs=counted)
   model.compile('sgd', 'mse')
コード例 #2
0
    def testMultipleCallsAreCountedOnce(self):
        """Calling the same layer twice should not increase the counter twice.

        Builds a model whose two outputs both come from one shared
        TrainingStepCounter instance and trains it for a known number of
        steps. It then checks that the counter's first weight equals
        num_epochs * steps_per_epoch, not twice that value.
        """
        # One TrainingStepCounter instance applied twice to the same input,
        # so both model outputs go through the same layer object.
        counter = dynamics.TrainingStepCounter()
        inputs = tf.keras.Input(1)
        output1 = counter(inputs)
        output2 = counter(inputs)
        model = tf.keras.Model(inputs=inputs, outputs=[output1, output2])
        model.compile('sgd', 'mse')

        # Train for a known number of steps.
        num_epochs = 2
        steps_per_epoch = 5
        model.fit(x=np.zeros(1),
                  y=[np.zeros(1), np.zeros(1)],  # one target per model output
                  epochs=num_epochs,
                  steps_per_epoch=steps_per_epoch)

        # Read the counter through the `counter` reference held above rather
        # than model.layers[-1]. This way the check does not depend on how
        # Keras orders the model's layers.
        # NOTE(review): assumes the step count is the layer's first weight
        # and that self.sess is the TF session created in setUp. Confirm.
        step_count = self.sess.run(counter.weights[0])
        self.assertEqual(step_count, num_epochs * steps_per_epoch)