Example #1
0
  def testReuse(self):
    """Building rev_block twice under reuse must not create new variables."""

    def f(inp):
      return core_layers.dense(inp, self.CHANNELS // 2)

    def g(inp):
      return core_layers.dense(inp, self.CHANNELS // 2)

    inputs = random_ops.random_uniform(
        [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32)
    half1, half2 = array_ops.split(inputs, 2, axis=-1)

    # First construction creates the f/g variables inside scope "test".
    with variable_scope.variable_scope("test"):
      out1, out2 = rev_block_lib.rev_block(
          half1, half2, f, g, num_layers=self.NUM_LAYERS)

    baseline_var_count = len(variables.global_variables())

    # Rebuild under reuse=True; the global variable count must stay flat.
    with variable_scope.variable_scope("test", reuse=True):
      out1, out2 = rev_block_lib.rev_block(
          half1, half2, f, g, num_layers=self.NUM_LAYERS)

    self.assertEqual(baseline_var_count, len(variables.global_variables()))

    # Taking gradients (which reconstructs the forward pass internally)
    # must not create variables either.
    loss = math_ops.reduce_mean(out1 + out2)
    _ = gradients_impl.gradients(loss,
                                 [inputs] + variables.trainable_variables())

    with variable_scope.variable_scope("test", reuse=True):
      out1, out2 = rev_block_lib.rev_block(
          half1, half2, f, g, num_layers=self.NUM_LAYERS)

    self.assertEqual(baseline_var_count, len(variables.global_variables()))
Example #2
0
  def testReuse(self):
    """rev_block under reuse=True must reuse, not duplicate, its variables."""

    def f(t):
      return core_layers.dense(t, self.CHANNELS // 2)

    def g(t):
      return core_layers.dense(t, self.CHANNELS // 2)

    x = random_ops.random_uniform(
        [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32)
    x1, x2 = array_ops.split(x, 2, axis=-1)

    # Initial build: creates the variables for f and g.
    with variable_scope.variable_scope("test"):
      y1, y2 = rev_block_lib.rev_block(
          x1, x2, f, g, num_layers=self.NUM_LAYERS)

    num_vars_before = len(variables.global_variables())

    # Second build in the same scope under reuse: count must be unchanged.
    with variable_scope.variable_scope("test", reuse=True):
      y1, y2 = rev_block_lib.rev_block(
          x1, x2, f, g, num_layers=self.NUM_LAYERS)

    num_vars_after = len(variables.global_variables())
    self.assertEqual(num_vars_before, num_vars_after)

    loss = math_ops.reduce_mean(y1 + y2)
    _ = gradients_impl.gradients(
        loss, [x] + variables.trainable_variables())

    # Building once more after the gradient computation should still
    # reuse the existing variables.
    with variable_scope.variable_scope("test", reuse=True):
      y1, y2 = rev_block_lib.rev_block(
          x1, x2, f, g, num_layers=self.NUM_LAYERS)

    num_vars_after = len(variables.global_variables())
    self.assertEqual(num_vars_before, num_vars_after)
Example #3
0
  def _testRevBlock(self,
                    x=None,
                    f=None,
                    g=None,
                    f_side_input=None,
                    g_side_input=None):
    """Compare rev_block outputs and gradients across train/inference builds.

    Builds the block once in training mode (creating variables) and once in
    the same scope with reuse and is_training=False, then asserts that both
    forward outputs and gradients w.r.t. inputs, side inputs, and variables
    agree numerically.
    """
    random_seed.set_random_seed(1234)

    # Default f/g: simple dense layers halving the channel count.
    if f is None:

      def f(x):  # pylint: disable=function-redefined
        return core_layers.dense(x, self.CHANNELS // 2, use_bias=True)

    if g is None:

      def g(x):  # pylint: disable=function-redefined
        return core_layers.dense(x, self.CHANNELS // 2, use_bias=True)

    if f_side_input is None:
      f_side_input = []
    if g_side_input is None:
      g_side_input = []

    if x is None:
      x = random_ops.random_uniform(
          [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32)
    x1, x2 = array_ops.split(x, 2, axis=-1)

    # Training-mode build: creates the f/g variables.
    with variable_scope.variable_scope("rev_test") as vs:
      train_y1, train_y2 = rev_block_lib.rev_block(
          x1, x2, f, g,
          f_side_input=f_side_input,
          g_side_input=g_side_input,
          num_layers=self.NUM_LAYERS)
      y_rev = array_ops.concat([train_y1, train_y2], axis=1)
      fg_vars = vs.trainable_variables()

    num_vars = len(variables.global_variables())
    # Inference-mode build reusing the same scope.
    with variable_scope.variable_scope(vs, reuse=True):
      eval_y1, eval_y2 = rev_block_lib.rev_block(
          x1, x2, f, g,
          f_side_input=f_side_input,
          g_side_input=g_side_input,
          num_layers=self.NUM_LAYERS,
          is_training=False)
      y = array_ops.concat([eval_y1, eval_y2], axis=1)
    # Ensure no new vars were created - full reuse
    assert len(variables.global_variables()) == num_vars

    loss_rev = math_ops.reduce_mean(y_rev + 10.)
    loss = math_ops.reduce_mean(y + 10.)

    # Differentiate both losses w.r.t. the same targets.
    wrt = [x] + f_side_input + g_side_input + fg_vars
    grads_rev = gradients_impl.gradients(loss_rev, wrt)
    grads = gradients_impl.gradients(loss, wrt)

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      y_val, y_rev_val, rev_grad_vals, grad_vals = sess.run(
          [y, y_rev, grads_rev, grads])
      self.assertAllClose(y_val, y_rev_val)
      for grad_a, grad_b in zip(rev_grad_vals, grad_vals):
        self.assertAllClose(grad_a, grad_b, rtol=1e-5)
    def _testRevBlock(self,
                      x=None,
                      f=None,
                      g=None,
                      f_side_input=None,
                      g_side_input=None):
        """Check that rev_block agrees with a reused inference-mode build.

        Constructs the block twice in the same variable scope (training
        mode, then reuse + is_training=False) and asserts that forward
        outputs and gradients match.
        """
        random_seed.set_random_seed(1234)

        # Supply default dense-layer f/g when the caller passes none.
        if f is None:

            def f(x):  # pylint: disable=function-redefined
                return core_layers.dense(x, self.CHANNELS // 2, use_bias=True)

        if g is None:

            def g(x):  # pylint: disable=function-redefined
                return core_layers.dense(x, self.CHANNELS // 2, use_bias=True)

        if f_side_input is None:
            f_side_input = []
        if g_side_input is None:
            g_side_input = []

        if x is None:
            x = random_ops.random_uniform(
                [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32)
        left, right = array_ops.split(x, 2, axis=-1)

        # First build creates the variables.
        with variable_scope.variable_scope("rev_test") as vs:
            out1, out2 = rev_block_lib.rev_block(
                left, right, f, g,
                f_side_input=f_side_input,
                g_side_input=g_side_input,
                num_layers=self.NUM_LAYERS)
            y_rev = array_ops.concat([out1, out2], axis=1)
            fg_vars = vs.trainable_variables()

        var_count = len(variables.global_variables())
        # Reference build: same scope, reused variables, inference mode.
        with variable_scope.variable_scope(vs, reuse=True):
            ref1, ref2 = rev_block_lib.rev_block(
                left, right, f, g,
                f_side_input=f_side_input,
                g_side_input=g_side_input,
                num_layers=self.NUM_LAYERS,
                is_training=False)
            y = array_ops.concat([ref1, ref2], axis=1)
        # Ensure no new vars were created - full reuse
        assert len(variables.global_variables()) == var_count

        loss_rev = math_ops.reduce_mean(y_rev + 10.)
        loss = math_ops.reduce_mean(y + 10.)

        diff_targets = [x] + f_side_input + g_side_input + fg_vars
        grads_rev = gradients_impl.gradients(loss_rev, diff_targets)
        grads = gradients_impl.gradients(loss, diff_targets)

        with self.test_session() as sess:
            sess.run(variables.global_variables_initializer())
            y_val, y_rev_val, rev_grad_vals, grad_vals = sess.run(
                [y, y_rev, grads_rev, grads])
            self.assertAllClose(y_val, y_rev_val)
            for grad_a, grad_b in zip(rev_grad_vals, grad_vals):
                self.assertAllClose(grad_a, grad_b, rtol=1e-5)