def testReuse(self):
  """Verifies that rebuilding rev_block with reuse=True adds no variables.

  The block is first built under variable scope "test" (creating its
  variables), then rebuilt twice under the same scope with reuse=True:
  once before and once after gradients of a loss on the outputs have been
  constructed. The number of global variables must stay unchanged after
  each rebuild.
  """

  def f(x):
    return core_layers.dense(x, self.CHANNELS // 2)

  def g(x):
    return core_layers.dense(x, self.CHANNELS // 2)

  inputs = random_ops.random_uniform(
      [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32)
  left, right = array_ops.split(inputs, 2, axis=-1)

  def build(reuse):
    # reuse=None is variable_scope's default, i.e. the "create" build.
    with variable_scope.variable_scope("test", reuse=reuse):
      return rev_block_lib.rev_block(
          left, right, f, g, num_layers=self.NUM_LAYERS)

  build(reuse=None)
  expected_count = len(variables.global_variables())

  out_left, out_right = build(reuse=True)
  self.assertEqual(expected_count, len(variables.global_variables()))

  # Constructing gradients must not break subsequent variable reuse.
  loss = math_ops.reduce_mean(out_left + out_right)
  _ = gradients_impl.gradients(
      loss, [inputs] + variables.trainable_variables())

  build(reuse=True)
  self.assertEqual(expected_count, len(variables.global_variables()))
def _testRevBlock(self,
                  x=None,
                  f=None,
                  g=None,
                  f_side_input=None,
                  g_side_input=None):
  """Compares a default rev_block build against an is_training=False build.

  The block is built under variable scope "rev_test" with rev_block's
  default arguments (outputs y_rev), then rebuilt in the same scope with
  reuse=True and is_training=False (outputs y). The test asserts that no
  new variables are created by the rebuild, that the forward outputs
  match, and that gradients of `reduce_mean(out + 10.)` with respect to the
  input, the side inputs and the block's trainable variables match.

  NOTE(review): an identical `_testRevBlock` definition appears elsewhere in
  this file; if both are in the same class, the later one silently replaces
  the earlier one. Consider deleting the duplicate.

  Args:
    x: Input tensor split in half along the last axis. Defaults to a
      random-uniform float32 tensor of shape [BATCH_SIZE, CHANNELS].
    f: Callable passed to rev_block as `f`. Defaults to a dense layer with
      CHANNELS // 2 units and a bias.
    g: Callable passed to rev_block as `g`. Same default as `f`.
    f_side_input: List of extra tensors passed as `f_side_input`; gradients
      are also checked with respect to them. Defaults to [].
    g_side_input: As `f_side_input`, for `g`. Defaults to [].
  """
  random_seed.set_random_seed(1234)

  if f is None:

    def f(x):  # pylint: disable=function-redefined
      return core_layers.dense(x, self.CHANNELS // 2, use_bias=True)

  if g is None:

    def g(x):  # pylint: disable=function-redefined
      return core_layers.dense(x, self.CHANNELS // 2, use_bias=True)

  if f_side_input is None:
    f_side_input = []

  if g_side_input is None:
    g_side_input = []

  if x is None:
    x = random_ops.random_uniform(
        [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32)
  x1, x2 = array_ops.split(x, 2, axis=-1)

  with variable_scope.variable_scope("rev_test") as vs:
    y1_rev, y2_rev = rev_block_lib.rev_block(
        x1,
        x2,
        f,
        g,
        f_side_input=f_side_input,
        g_side_input=g_side_input,
        num_layers=self.NUM_LAYERS)
    y_rev = array_ops.concat([y1_rev, y2_rev], axis=1)
    fg_vars = vs.trainable_variables()

  num_vars = len(variables.global_variables())
  with variable_scope.variable_scope(vs, reuse=True):
    y1, y2 = rev_block_lib.rev_block(
        x1,
        x2,
        f,
        g,
        f_side_input=f_side_input,
        g_side_input=g_side_input,
        num_layers=self.NUM_LAYERS,
        is_training=False)
    y = array_ops.concat([y1, y2], axis=1)
    # Ensure no new vars were created - full reuse. A test assertion is used
    # instead of a bare `assert`, which would be stripped under `python -O`.
    self.assertEqual(num_vars, len(variables.global_variables()))

  loss_rev = math_ops.reduce_mean(y_rev + 10.)
  loss = math_ops.reduce_mean(y + 10.)

  wrt = [x] + f_side_input + g_side_input + fg_vars
  grads_rev = gradients_impl.gradients(loss_rev, wrt)
  grads = gradients_impl.gradients(loss, wrt)

  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())
    y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
    self.assertAllClose(y_val, yd_val)
    for g1, g2 in zip(gd_val, g_val):
      self.assertAllClose(g1, g2, rtol=1e-5)
def _testRevBlock(self,
                  x=None,
                  f=None,
                  g=None,
                  f_side_input=None,
                  g_side_input=None):
  """Compares a default rev_block build against an is_training=False build.

  The block is built under variable scope "rev_test" with rev_block's
  default arguments (outputs y_rev), then rebuilt in the same scope with
  reuse=True and is_training=False (outputs y). The test asserts that no
  new variables are created by the rebuild, that the forward outputs
  match, and that gradients of `reduce_mean(out + 10.)` with respect to the
  input, the side inputs and the block's trainable variables match.

  NOTE(review): an identical `_testRevBlock` definition appears elsewhere in
  this file; if both are in the same class, this later one silently
  replaces the earlier one. Consider deleting the duplicate.

  Args:
    x: Input tensor split in half along the last axis. Defaults to a
      random-uniform float32 tensor of shape [BATCH_SIZE, CHANNELS].
    f: Callable passed to rev_block as `f`. Defaults to a dense layer with
      CHANNELS // 2 units and a bias.
    g: Callable passed to rev_block as `g`. Same default as `f`.
    f_side_input: List of extra tensors passed as `f_side_input`; gradients
      are also checked with respect to them. Defaults to [].
    g_side_input: As `f_side_input`, for `g`. Defaults to [].
  """
  random_seed.set_random_seed(1234)

  if f is None:

    def f(x):  # pylint: disable=function-redefined
      return core_layers.dense(x, self.CHANNELS // 2, use_bias=True)

  if g is None:

    def g(x):  # pylint: disable=function-redefined
      return core_layers.dense(x, self.CHANNELS // 2, use_bias=True)

  if f_side_input is None:
    f_side_input = []

  if g_side_input is None:
    g_side_input = []

  if x is None:
    x = random_ops.random_uniform(
        [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32)
  x1, x2 = array_ops.split(x, 2, axis=-1)

  with variable_scope.variable_scope("rev_test") as vs:
    y1_rev, y2_rev = rev_block_lib.rev_block(
        x1,
        x2,
        f,
        g,
        f_side_input=f_side_input,
        g_side_input=g_side_input,
        num_layers=self.NUM_LAYERS)
    y_rev = array_ops.concat([y1_rev, y2_rev], axis=1)
    fg_vars = vs.trainable_variables()

  num_vars = len(variables.global_variables())
  with variable_scope.variable_scope(vs, reuse=True):
    y1, y2 = rev_block_lib.rev_block(
        x1,
        x2,
        f,
        g,
        f_side_input=f_side_input,
        g_side_input=g_side_input,
        num_layers=self.NUM_LAYERS,
        is_training=False)
    y = array_ops.concat([y1, y2], axis=1)
    # Ensure no new vars were created - full reuse. A test assertion is used
    # instead of a bare `assert`, which would be stripped under `python -O`.
    self.assertEqual(num_vars, len(variables.global_variables()))

  loss_rev = math_ops.reduce_mean(y_rev + 10.)
  loss = math_ops.reduce_mean(y + 10.)

  wrt = [x] + f_side_input + g_side_input + fg_vars
  grads_rev = gradients_impl.gradients(loss_rev, wrt)
  grads = gradients_impl.gradients(loss, wrt)

  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())
    y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
    self.assertAllClose(y_val, yd_val)
    for g1, g2 in zip(gd_val, g_val):
      self.assertAllClose(g1, g2, rtol=1e-5)