コード例 #1
0
ファイル: counters_test.py プロジェクト: tensorflow/federated
 def test_construct(self):
     m = counters.NumExamplesCounter()
     self.assertEqual(m.name, 'num_examples')
     self.assertTrue(m.stateful)
     self.assertEqual(m.dtype, tf.int64)
     self.assertLen(m.variables, 1)
     self.assertEqual(m.total, 0)
     m = counters.NumExamplesCounter('num_examples2')
     self.assertEqual(m.name, 'num_examples2')
コード例 #2
0
ファイル: mnist.py プロジェクト: tensorflow/federated
def create_simple_keras_model(learning_rate=0.1):
    """Returns an instance of `tf.Keras.Model` with just one dense layer.

  Args:
    learning_rate: The learning rate to use with the SGD optimizer.

  Returns:
    An instance of `tf.Keras.Model`.
  """
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(784, )),
        tf.keras.layers.Dense(10, tf.nn.softmax, kernel_initializer='zeros'),
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  optimizer=tf.keras.optimizers.SGD(learning_rate),
                  metrics=[
                      tf.keras.metrics.SparseCategoricalAccuracy(),
                      counters.NumExamplesCounter(),
                  ])
    return model
コード例 #3
0
ファイル: counters_test.py プロジェクト: tensorflow/federated
 def test_reset_to_zero(self):
     m = counters.NumExamplesCounter()
     self.assertGreater(m(tf.zeros([10, 1]), tf.zeros([10])), 0)
     self.assertGreater(m.total, 0)
     m.reset_state()
     self.assertEqual(m.total, 0)
コード例 #4
0
ファイル: counters_test.py プロジェクト: tensorflow/federated
 def test_update_with_sample_weight(self, batch1, batch2):
     m = counters.NumExamplesCounter()
     self.assertEqual(m(batch1, batch1), 10)
     self.assertEqual(m.total, 10)
     self.assertEqual(m.update_state(batch2, batch2), 15)
     self.assertEqual(m.total, 15)
コード例 #5
0
ファイル: counters_test.py プロジェクト: tensorflow/federated
 def test_update_computes_shape_of_first_input_arg(self):
     m = counters.NumExamplesCounter()
     self.assertEqual(m(tf.zeros([10, 1]), tf.zeros([15])), 10)
     self.assertEqual(m.total, 10)
     self.assertEqual(m.update_state(tf.zeros([5, 1]), tf.zeros([700])), 15)
     self.assertEqual(m.total, 15)
コード例 #6
0
 def metrics_fn():
     return [counters.NumExamplesCounter(), NumOverCounter(5.0)]
コード例 #7
0
 def metrics_fn():
     return [
         counters.NumExamplesCounter(),
         counters.NumBatchesCounter(),
         tf.keras.metrics.SparseCategoricalAccuracy()
     ]