def call(self, x, training=True): """Apply residual block to inputs.""" x1, x2 = x f_x2 = self.f(x2, training=training) x1_down = ops.downsample( x1, self.filters // 2, self.strides, axis=self.axis) x2_down = ops.downsample( x2, self.filters // 2, self.strides, axis=self.axis) y1 = f_x2 + x1_down g_y1 = self.g(y1, training=training) y2 = g_y1 + x2_down return y1, y2
def call(self, x, training=True, concat=True):
  """Apply residual block to inputs."""
  x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
  f_x2 = self.f(x2, training=training)
  x1_down = ops.downsample(
      x1, self.filters // 2, self.strides, axis=self.axis)
  x2_down = ops.downsample(
      x2, self.filters // 2, self.strides, axis=self.axis)
  y1 = f_x2 + x1_down
  g_y1 = self.g(y1, training=training)
  y2 = g_y1 + x2_down
  if not concat:  # For correct backward grads
    return y1, y2

  return tf.concat([y1, y2], axis=self.axis)

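# Illustration (not part of the original module): a minimal NumPy sketch of the
# reversible coupling that `call` computes when `ops.downsample` is the identity
# (stride 1, matching filter count):
#   y1 = x1 + f(x2),  y2 = x2 + g(y1)
# The inputs can be reconstructed from the outputs alone, which is what lets
# RevNets recompute activations during the backward pass instead of storing them.
# The `f` and `g` below are toy stand-ins for the residual sub-networks.
import numpy as np

def f(t):  # stand-in for self.f
  return np.tanh(t)

def g(t):  # stand-in for self.g
  return 0.5 * t

x1, x2 = np.random.randn(4, 8), np.random.randn(4, 8)

# Forward coupling.
y1 = x1 + f(x2)
y2 = x2 + g(y1)

# Inverse: recover the inputs from the outputs.
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)

assert np.allclose(x1, x1_rec) and np.allclose(x2, x2_rec)
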
def backward_grads_and_vars(self, x, dy, training=True):
  """Manually compute backward gradients given input and output grads."""
  with tf.GradientTape(persistent=True) as tape:
    x_stop = tf.stop_gradient(x)
    x1, x2 = tf.split(x_stop, num_or_size_splits=2, axis=self.axis)
    tape.watch([x1, x2])
    # Stitch back x for `call` so tape records correct grads
    x = tf.concat([x1, x2], axis=self.axis)
    dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
    y1, y2 = self.call(x, training=training, concat=False)
    x2_down = ops.downsample(
        x2, self.filters // 2, self.strides, axis=self.axis)

  grads_combined = tape.gradient(
      y2, [y1] + self.g.variables, output_gradients=[dy2])
  dy2_y1, dg = grads_combined[0], grads_combined[1:]
  dy1_plus = dy2_y1 + dy1

  grads_combined = tape.gradient(
      y1, [x1, x2] + self.f.variables, output_gradients=[dy1_plus])
  dx1, dx2, df = grads_combined[0], grads_combined[1], grads_combined[2:]
  dx2 += tape.gradient(x2_down, [x2], output_gradients=[dy2])[0]

  del tape

  grads = df + dg
  vars_ = self.f.variables + self.g.variables

  return tf.concat([dx1, dx2], axis=self.axis), grads, vars_

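# Illustration (not from the original code): why `backward_grads_and_vars`
# splits a `tf.stop_gradient` copy of x and re-concatenates it before calling
# `call`. The tape then treats the two halves as leaves, so gradients are
# reported per half and nothing flows back past the block input. A toy,
# TF 2.x-style sketch:
import tensorflow as tf

v = tf.Variable(tf.random.normal([2, 4]))
with tf.GradientTape(persistent=True) as tape:
  x = 2. * v                    # upstream computation we don't want to backprop into
  x_stop = tf.stop_gradient(x)  # treat the block input as a leaf
  x1, x2 = tf.split(x_stop, num_or_size_splits=2, axis=1)
  tape.watch([x1, x2])
  y = tf.reduce_sum(tf.concat([x1, x2], axis=1) ** 2)
dx1, dx2 = tape.gradient(y, [x1, x2])  # well-defined grads for each half
assert tape.gradient(y, v) is None     # path to v is cut by stop_gradient
del tape
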
def call(self, x, training=True, concat=True):
  """Apply residual block to inputs."""
  x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
  f_x2 = self.f.call(x2, training=training)
  # TODO(lxuechen): Replace with simpler downsampling
  x1_down = ops.downsample(
      x1, self.filters // 2, self.strides, axis=self.axis)
  x2_down = ops.downsample(
      x2, self.filters // 2, self.strides, axis=self.axis)
  y1 = f_x2 + x1_down
  g_y1 = self.g.call(y1, training=training)  # self.g(y1) gives pylint error
  y2 = g_y1 + x2_down
  if not concat:  # Concat option needed for correct backward grads
    return y1, y2

  return tf.concat([y1, y2], axis=self.axis)

def backward_grads_with_downsample(self, x, y, dy, training=True):
  """Manually compute backward gradients given input and output grads."""
  # Split from `backward_grads` for better readability
  x1, x2 = x
  y1, _ = y
  dy1, dy2 = dy

  with tf.GradientTape() as gtape:
    gtape.watch(y1)
    gy1 = self.g(y1, training=training)
  grads_combined = gtape.gradient(
      gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
  dg = grads_combined[1:]
  dz1 = dy1 + grads_combined[0]

  # dx1 needs one more step to backprop through downsample
  with tf.GradientTape() as x1tape:
    x1tape.watch(x1)
    z1 = ops.downsample(x1, self.filters // 2, self.strides, axis=self.axis)
  dx1 = x1tape.gradient(z1, x1, output_gradients=dz1)

  with tf.GradientTape() as ftape:
    ftape.watch(x2)
    fx2 = self.f(x2, training=training)
  grads_combined = ftape.gradient(
      fx2, [x2] + self.f.trainable_variables, output_gradients=dz1)
  dx2, df = grads_combined[0], grads_combined[1:]

  # dx2 needs one more step to backprop through downsample
  with tf.GradientTape() as x2tape:
    x2tape.watch(x2)
    z2 = ops.downsample(x2, self.filters // 2, self.strides, axis=self.axis)
  dx2 += x2tape.gradient(z2, x2, output_gradients=dy2)

  dx = dx1, dx2
  grads = df + dg

  return dx, grads

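# Illustration (not from the original code): checking the stepwise backward pass
# above against a single autodiff pass. Assumptions: TF 2.x eager, toy Dense
# layers standing in for self.f / self.g, and a simple differentiable stand-in
# for `ops.downsample`; none of these names come from the original module.
import tensorflow as tf

f = tf.keras.layers.Dense(4)    # stand-in for self.f
g = tf.keras.layers.Dense(4)    # stand-in for self.g
d = lambda t: 0.5 * t[:, ::2]   # toy differentiable "downsample" (8 -> 4 features)

x1 = tf.random.normal([2, 8])
x2 = tf.random.normal([2, 8])
dy1 = tf.random.normal([2, 4])
dy2 = tf.random.normal([2, 4])

# Reference: one tape over the whole forward pass
# y1 = f(x2) + d(x1), y2 = g(y1) + d(x2).
with tf.GradientTape() as tape:
  tape.watch([x1, x2])
  y1 = f(x2) + d(x1)
  y2 = g(y1) + d(x2)
ref = tape.gradient(
    [y1, y2], [x1, x2] + f.trainable_variables + g.trainable_variables,
    output_gradients=[dy1, dy2])

# Stepwise backward pass, mirroring backward_grads_with_downsample.
with tf.GradientTape() as gtape:
  gtape.watch(y1)
  gy1 = g(y1)
grads_combined = gtape.gradient(
    gy1, [y1] + g.trainable_variables, output_gradients=dy2)
dg = grads_combined[1:]
dz1 = dy1 + grads_combined[0]

with tf.GradientTape() as x1tape:
  x1tape.watch(x1)
  z1 = d(x1)
dx1 = x1tape.gradient(z1, x1, output_gradients=dz1)

with tf.GradientTape() as ftape:
  ftape.watch(x2)
  fx2 = f(x2)
grads_combined = ftape.gradient(
    fx2, [x2] + f.trainable_variables, output_gradients=dz1)
dx2, df = grads_combined[0], grads_combined[1:]

with tf.GradientTape() as x2tape:
  x2tape.watch(x2)
  z2 = d(x2)
dx2 += x2tape.gradient(z2, x2, output_gradients=dy2)

# Manual and autodiff gradients should agree.
for a, b in zip(ref, [dx1, dx2] + list(df) + list(dg)):
  tf.debugging.assert_near(a, b, rtol=1e-5, atol=1e-5)
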
def test_downsample(self):
  """Test the `downsample` function."""
  batch_size = 100

  # NHWC format
  x = tf.random_normal(shape=[batch_size, 32, 32, 3])
  # HW doesn't change but number of features increased
  y = ops.downsample(x, filters=5, strides=(1, 1), axis=3)
  self.assertEqual(y.shape, [batch_size, 32, 32, 5])
  # Number of features doesn't change but HW reduced
  y = ops.downsample(x, filters=3, strides=(2, 2), axis=3)
  self.assertEqual(y.shape, [batch_size, 16, 16, 3])
  # Number of features increased and HW reduced
  y = ops.downsample(x, filters=5, strides=(2, 2), axis=3)
  self.assertEqual(y.shape, [batch_size, 16, 16, 5])

  # Test gradient flow
  x = tf.random_normal(shape=[batch_size, 32, 32, 3])
  with tfe.GradientTape() as tape:
    tape.watch(x)
    y = ops.downsample(x, filters=3, strides=(1, 1))
  self.assertEqual(y.shape, x.shape)
  dy = tf.random_normal(shape=[batch_size, 32, 32, 3])
  grad, = tape.gradient(y, [x], output_gradients=[dy])
  self.assertEqual(grad.shape, x.shape)

  # Default NCHW format
  if tf.test.is_gpu_available():
    x = tf.random_normal(shape=[batch_size, 3, 32, 32])
    # HW doesn't change but number of features increased
    y = ops.downsample(x, filters=5, strides=(1, 1))
    self.assertEqual(y.shape, [batch_size, 5, 32, 32])
    # Number of features doesn't change but HW reduced
    y = ops.downsample(x, filters=3, strides=(2, 2))
    self.assertEqual(y.shape, [batch_size, 3, 16, 16])
    # Number of features increased and HW reduced
    y = ops.downsample(x, filters=5, strides=(2, 2))
    self.assertEqual(y.shape, [batch_size, 5, 16, 16])

    # Test gradient flow
    x = tf.random_normal(shape=[batch_size, 3, 32, 32])
    with tfe.GradientTape() as tape:
      tape.watch(x)
      y = ops.downsample(x, filters=3, strides=(1, 1))
    self.assertEqual(y.shape, x.shape)
    dy = tf.random_normal(shape=[batch_size, 3, 32, 32])
    grad, = tape.gradient(y, [x], output_gradients=[dy])
    self.assertEqual(grad.shape, x.shape)