예제 #1
0
  def test_spec(self, specs, disjoint_init_and_next):
    """Checks the optimizer against differently structured variables.

    Variables and gradients are built as all-ones tensors matching `specs`.
    After three `next` calls with an SGD optimizer created with 0.1, every
    entry of the resulting variables is expected to be close to 0.7.
    """

    def _ones(spec):
      # All-ones tensor with the shape and dtype described by `spec`.
      return tf.ones(spec.shape, spec.dtype)

    def _make_sgd():
      return tf.keras.optimizers.SGD(0.1)

    weights = tf.nest.map_structure(
        lambda spec: tf.Variable(_ones(spec)), specs)
    grads = tf.nest.map_structure(_ones, specs)

    optimizer = keras_optimizer.KerasOptimizer(
        _make_sgd, weights, disjoint_init_and_next=disjoint_init_and_next)
    state = optimizer.initialize(specs)

    num_steps = 3
    for _ in range(num_steps):
      state, weights = optimizer.next(state, weights, grads)

    # Starting at 1.0, three steps of 0.1 * 1.0 leave 0.7 in every entry.
    want = tf.nest.map_structure(lambda spec: 0.7 * _ones(spec), specs)
    self.assertAllClose(want, weights)
예제 #2
0
 def next_fn(state, initial_weights):
     """Wraps `initial_weights` in a variable and runs `single_step` on it.

     A new `KerasOptimizer` with `disjoint_init_and_next=True` is constructed
     on every call and handed to `single_step` together with `state`.
     """
     weights = tf.Variable(initial_weights)
     return single_step(
         keras_optimizer.KerasOptimizer(
             optimizer_fn, weights, disjoint_init_and_next=True),
         state,
         weights)
예제 #3
0
 def initialize_fn():
     """Returns the optimizer state initialized for a zero [5, 1] variable.

     The optimizer is built with `disjoint_init_and_next=True`, and
     `initialize` receives a `tf.TensorSpec` mirroring the variable's shape
     and dtype.
     """
     weights = tf.Variable(tf.zeros([5, 1]))
     opt = keras_optimizer.KerasOptimizer(
         optimizer_fn, weights, disjoint_init_and_next=True)
     weights_spec = tf.TensorSpec(weights.shape, weights.dtype)
     return opt.initialize(weights_spec)
예제 #4
0
 def test_computation(initial_weights):
     """Runs `training_loop` on a variable created from `initial_weights`.

     The `KerasOptimizer` passed to the loop is constructed with
     `disjoint_init_and_next=False`.
     """
     weights = tf.Variable(initial_weights)
     return training_loop(
         keras_optimizer.KerasOptimizer(
             optimizer_fn, weights, disjoint_init_and_next=False),
         weights)