Example #1
    def test_downsampling_non_reversible_step(self, distribution):
        bsz, h, w, c = 8, 32, 32, 32
        filters = 64
        strides = 2

        input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
        with distribution.scope():
            f = nn_blocks.ResidualInner(filters=filters // 2,
                                        strides=strides,
                                        batch_norm_first=True)
            g = nn_blocks.ResidualInner(filters=filters // 2,
                                        strides=1,
                                        batch_norm_first=True)
            test_layer = nn_blocks.ReversibleLayer(f, g)
            test_layer.build(input_tensor.shape)
            optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

        @tf.function
        def step_fn():
            with tf.GradientTape() as tape:
                output = test_layer(input_tensor, training=True)
            grads = tape.gradient(output, test_layer.trainable_variables)
            # Test that applying gradients with the optimizer works
            optimizer.apply_gradients(
                zip(grads, test_layer.trainable_variables))

            return output

        replica_output = distribution.run(step_fn)
        outputs = distribution.experimental_local_results(replica_output)

        # Assert forward pass shape
        expected_output_shape = [bsz, h // strides, w // strides, filters]
        for output in outputs:
            self.assertEqual(expected_output_shape, output.shape.as_list())
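
The distribution argument above is injected by a test-combinations decorator rather than constructed inside the test. Below is a minimal sketch of how such a parameterization is typically wired up; the strategy list and the nn_blocks import path are assumptions, not taken from this snippet.

from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations

from official.vision.modeling.layers import nn_blocks  # assumed import path


class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ]))
  def test_downsampling_non_reversible_step(self, distribution):
    ...  # test body as in Example #1 above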
Example #2
  def test_reversible_step(self, distribution):
    # Reversible layers satisfy: (a) strides = 1 (b) in_filter = out_filter
    bsz, h, w, c = 8, 32, 32, 32
    filters = c
    strides = 1

    input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
    with distribution.scope():
      f = nn_blocks.ResidualInner(
          filters=filters // 2,
          strides=strides,
          batch_norm_first=False)
      g = nn_blocks.ResidualInner(
          filters=filters // 2,
          strides=1,
          batch_norm_first=False)
      test_layer = nn_blocks.ReversibleLayer(f, g)
      test_layer(input_tensor, training=False)  # init weights
      optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

    @tf.function
    def step_fn():
      with tf.GradientTape() as tape:
        output = test_layer(input_tensor, training=True)
      grads = tape.gradient(output, test_layer.trainable_variables)
      # Test that applying gradients with the optimizer works
      optimizer.apply_gradients(zip(grads, test_layer.trainable_variables))

      return output

    @tf.function
    def fwd():
      test_layer(input_tensor)

    distribution.run(fwd)  # Initialize variables
    prev_variables = tf.identity_n(test_layer.trainable_variables)
    replica_output = distribution.run(step_fn)
    outputs = distribution.experimental_local_results(replica_output)

    # Assert that the variable values have changed
    for v0, v1 in zip(prev_variables, test_layer.trainable_variables):
      self.assertNotAllEqual(v0, v1)

    # Assert forward pass shape
    expected_output_shape = [bsz, h // strides, w // strides, filters]
    for output in outputs:
      self.assertEqual(expected_output_shape, output.shape.as_list())
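
As a quick eager-mode check of the condition noted in the comment above (strides = 1 and matching filter count), a ReversibleLayer built this way preserves the input shape. The sketch below runs without a distribution strategy; the nn_blocks import path is an assumption based on the TensorFlow Model Garden layout.

import tensorflow as tf
from official.vision.modeling.layers import nn_blocks  # assumed import path

x = tf.random.uniform(shape=[8, 32, 32, 32])
f = nn_blocks.ResidualInner(filters=16, strides=1, batch_norm_first=False)
g = nn_blocks.ResidualInner(filters=16, strides=1, batch_norm_first=False)
layer = nn_blocks.ReversibleLayer(f, g)
y = layer(x, training=False)
assert y.shape.as_list() == [8, 32, 32, 32]  # input shape is preserved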
Example #3
    def test_shape(self, distribution):
        bsz, h, w, c = 8, 32, 32, 32
        filters = 64
        strides = 2

        input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
        with distribution.scope():
            test_layer = nn_blocks.ResidualInner(filters, strides)

        output = test_layer(input_tensor)
        expected_output_shape = [bsz, h // strides, w // strides, filters]
        self.assertEqual(expected_output_shape, output.shape.as_list())
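
With the values above, the expected shape works out to [8, 16, 16, 64]: stride 2 halves the 32x32 spatial dimensions, and the channel count becomes the requested 64 filters. Since ResidualInner is a Keras layer, it can also be stacked directly in a model; the sketch below assumes that usage and the Model Garden import path.

import tensorflow as tf
from official.vision.modeling.layers import nn_blocks  # assumed import path

model = tf.keras.Sequential([
    nn_blocks.ResidualInner(filters=64, strides=2),   # 32x32x32 -> 16x16x64
    nn_blocks.ResidualInner(filters=128, strides=2),  # 16x16x64 -> 8x8x128
])
out = model(tf.random.uniform([8, 32, 32, 32]), training=False)
print(out.shape)  # (8, 8, 8, 128)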