def transformer_revnet_encoder(encoder_input,
                               encoder_self_attention_bias,
                               hparams,
                               name="encoder"):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """

  def f(x, side_input):
    """f(x) for reversible layer, self-attention layer."""
    encoder_self_attention_bias = side_input[0]

    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2

    with tf.variable_scope("self_attention"):
      y = common_attention.multihead_attention(
          common_layers.layer_preprocess(x, hparams), None,
          encoder_self_attention_bias,
          hparams.attention_key_channels or hparams.hidden_size,
          hparams.attention_value_channels or hparams.hidden_size,
          hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
      y = common_layers.layer_postprocess(x, y, hparams)
    hparams.hidden_size = old_hid_size
    return y

  def g(x):
    """g(x) for reversible layer, feed-forward layer."""
    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2

    with tf.variable_scope("ffn"):
      y = transformer.transformer_ffn_layer(
          common_layers.layer_preprocess(x, hparams), hparams)
      y = common_layers.layer_postprocess(x, y, hparams)
    hparams.hidden_size = old_hid_size
    return y

  x1, x2 = tf.split(encoder_input, 2, axis=-1)

  with tf.variable_scope(name):
    y1, y2 = rev_block.rev_block(
        x1,
        x2,
        f,
        g,
        num_layers=hparams.num_hidden_layers,
        f_side_input=[encoder_self_attention_bias],
        is_training=hparams.mode == tf.estimator.ModeKeys.TRAIN)
  y = tf.concat([y1, y2], axis=-1)

  return common_layers.layer_preprocess(y, hparams)
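# A hedged usage sketch for the encoder above (illustrative, not from the
# original source). It assumes `hparams` is a standard transformer hparams
# set whose `hidden_size` is even, since the input is split in half along
# the last axis before entering the reversible stack.
def example_transformer_revnet_encoder(hparams):
  batch_size, length = 2, 10
  # Dummy inputs: [batch, length, hidden_size].
  encoder_input = tf.zeros([batch_size, length, hparams.hidden_size])
  # All-zero bias, i.e. no padding masked out; this shape broadcasts over
  # the [batch, heads, length_q, length_kv] attention logits.
  bias = tf.zeros([1, 1, 1, length])
  return transformer_revnet_encoder(encoder_input, bias, hparams)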
def unit(x1, x2, block_num, depth1, depth2, num_layers, dim='2d',
         first_batch_norm=True, stride=1, training=True):
  """Implements bottleneck RevNet unit from authors' RevNet-104 architecture.

  Args:
    x1: [N, H, W, C] tensor of network activations.
    x2: [N, H, W, C] tensor of network activations.
    block_num: integer ID of block
    depth1: First depth in bottleneck residual unit.
    depth2: Second depth in bottleneck residual unit.
    num_layers: Number of layers in the RevNet block.
    dim: '2d' if 2-dimensional, '3d' if 3-dimensional.
    first_batch_norm: Whether to keep the first batch norm layer or not.
      Typically used in the first RevNet block.
    stride: Stride for the residual function.
    training: True for train phase, False for eval phase.

  Returns:
    Two [N, H, W, C] output activation tensors.
  """
  scope_name = 'unit_%d' % block_num
  with tf.variable_scope(scope_name):
    # Manual implementation of downsampling
    with tf.variable_scope('downsampling'):
      with tf.variable_scope('x1'):
        hx1 = h(x1, depth2, dim=dim, layer_stride=stride)
        fx2 = f(x2, depth1, depth2, dim=dim, layer_stride=stride,
                first_batch_norm=first_batch_norm, training=training)
        x1 = hx1 + fx2
      with tf.variable_scope('x2'):
        hx2 = h(x2, depth2, dim=dim, layer_stride=stride)
        fx1 = f(x1, depth1, depth2, dim=dim, training=training)
        x2 = hx2 + fx1

    # Full block using memory-efficient rev_block implementation.
    with tf.variable_scope('full_block'):
      residual_func = lambda x: f(x, depth1, depth2, dim=dim,
                                  training=training)
      x1, x2 = rev_block.rev_block(x1, x2,
                                   residual_func,
                                   residual_func,
                                   num_layers=num_layers)
      return x1, x2
def unit(x1, x2, block_num, depth, num_layers, dim='2d',
         bottleneck=True, first_batch_norm=True, stride=1, training=True):
  """Implements bottleneck RevNet unit from authors' RevNet architecture.

  Args:
    x1: [N, H, W, C] tensor of network activations.
    x2: [N, H, W, C] tensor of network activations.
    block_num: integer ID of block
    depth: First depth in bottleneck residual unit.
    num_layers: Number of layers in the RevNet block.
    dim: '2d' if 2-dimensional, '3d' if 3-dimensional.
    bottleneck: Should a bottleneck layer be used.
    first_batch_norm: Whether to keep the first batch norm layer or not.
      Typically used in the first RevNet block.
    stride: Stride for the residual function.
    training: True for train phase, False for eval phase.

  Returns:
    Two [N, H, W, C] output activation tensors.
  """
  scope_name = 'unit_%d' % block_num
  if bottleneck:
    depth1 = depth
    depth2 = depth * 4
  else:
    depth1 = depth2 = depth

  residual = functools.partial(f,
                               depth1=depth1, depth2=depth2, dim=dim,
                               training=training, bottleneck=bottleneck)

  with tf.variable_scope(scope_name):
    downsample = downsample_bottleneck if bottleneck else downsample_residual
    # Manual implementation of downsampling
    with tf.variable_scope('downsampling'):
      with tf.variable_scope('x1'):
        hx1 = downsample(x1, depth2, dim=dim, stride=stride)
        fx2 = residual(x2, stride=stride, first_batch_norm=first_batch_norm)
        x1 = hx1 + fx2
      with tf.variable_scope('x2'):
        hx2 = downsample(x2, depth2, dim=dim, stride=stride)
        fx1 = residual(x1)
        x2 = hx2 + fx1

    # Full block using memory-efficient rev_block implementation.
    with tf.variable_scope('full_block'):
      x1, x2 = rev_block.rev_block(x1, x2,
                                   residual,
                                   residual,
                                   num_layers=num_layers)
      return x1, x2
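# A hedged usage sketch (not part of the original file) chaining two RevNet
# stages with `unit`. It assumes the enclosing module defines the residual
# function `f` and the `downsample_bottleneck`/`downsample_residual` helpers
# that `unit` closes over; the depths and strides below are illustrative.
def example_revnet_stages(x1, x2, training):
  # First stage: halve the spatial resolution and set the working depth.
  x1, x2 = unit(x1, x2, block_num=1, depth=64, num_layers=3,
                first_batch_norm=False, stride=2, training=training)
  # Second stage: downsample again at a larger depth.
  x1, x2 = unit(x1, x2, block_num=2, depth=128, num_layers=3,
                stride=2, training=training)
  return x1, x2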
def testRevBlock(self):
  channels = 8
  num_layers = 4
  batch_size = 16
  tf.set_random_seed(1234)

  def f(x):
    return tf.layers.dense(x, channels // 2, use_bias=True)

  def g(x):
    return tf.layers.dense(x, channels // 2, use_bias=True)

  x = tf.random_uniform([batch_size, channels], dtype=tf.float32)

  with tf.variable_scope("defun") as vs:
    y_defun = rev_block.rev_block(x, f, g, num_layers=num_layers)
    fg_vars = vs.trainable_variables()

  num_vars = len(tf.global_variables())
  with tf.variable_scope(vs, reuse=True):
    y = rev_block.rev_block(x, f, g, num_layers=num_layers,
                            is_training=False)
  # Ensure no new vars were created - full reuse
  assert len(tf.global_variables()) == num_vars

  loss_defun = tf.reduce_mean(y_defun + 10.)
  loss = tf.reduce_mean(y + 10.)

  grads_defun = tf.gradients(loss_defun, [x] + fg_vars)
  grads = tf.gradients(loss, [x] + fg_vars)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    y_val, yd_val, gd_val, g_val = sess.run(
        [y, y_defun, grads_defun, grads])
    self.assertAllClose(y_val, yd_val)
    for g1, g2 in zip(gd_val, g_val):
      self.assertAllClose(g1, g2)
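# A hedged note on the pattern above: the first rev_block call builds the
# memory-efficient training graph, and reopening the same variable scope with
# reuse=True and is_training=False builds a plain reference graph over the
# identical variables, so outputs and gradients can be compared one-to-one.
# The same pattern in miniature (names here are illustrative):
def build_both_paths(x, f, g, num_layers):
  with tf.variable_scope("shared") as vs:
    y_train = rev_block.rev_block(x, f, g, num_layers=num_layers)
  with tf.variable_scope(vs, reuse=True):
    y_eval = rev_block.rev_block(x, f, g, num_layers=num_layers,
                                 is_training=False)
  return y_train, y_eval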
def testSmoke(self):
  channels = 8
  num_layers = 4
  batch_size = 16
  use_defun = True
  tf.set_random_seed(1234)

  def f(x):
    return tf.layers.dense(x, channels // 2, use_bias=True)

  def g(x):
    return tf.layers.dense(x, channels // 2, use_bias=True)

  x = tf.random_uniform([batch_size, channels], dtype=tf.float32)

  y = rev_block.rev_block(
      x, f, g, num_layers=num_layers, is_training=use_defun)
  loss = tf.reduce_mean(y + 10.)
  grads = tf.gradients(loss, [x] + tf.global_variables())

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    _ = sess.run(grads)
def _test_rev_block(self,
                    x=None,
                    f=None,
                    g=None,
                    f_side_input=None,
                    g_side_input=None):
  tf.set_random_seed(1234)

  if f is None:

    def f(x):  # pylint: disable=function-redefined
      return tf.layers.dense(x, self.CHANNELS // 2, use_bias=True)

  if g is None:

    def g(x):  # pylint: disable=function-redefined
      return tf.layers.dense(x, self.CHANNELS // 2, use_bias=True)

  if f_side_input is None:
    f_side_input = []

  if g_side_input is None:
    g_side_input = []

  if x is None:
    x = tf.random_uniform([self.BATCH_SIZE, self.CHANNELS], dtype=tf.float32)
  x1, x2 = tf.split(x, 2, axis=-1)

  with tf.variable_scope("rev_test") as vs:
    y1_rev, y2_rev = rev_block.rev_block(
        x1,
        x2,
        f,
        g,
        f_side_input=f_side_input,
        g_side_input=g_side_input,
        num_layers=self.NUM_LAYERS)
    y_rev = tf.concat([y1_rev, y2_rev], axis=1)
    fg_vars = vs.trainable_variables()

  num_vars = len(tf.global_variables())
  with tf.variable_scope(vs, reuse=True):
    y1, y2 = rev_block.rev_block(
        x1,
        x2,
        f,
        g,
        f_side_input=f_side_input,
        g_side_input=g_side_input,
        num_layers=self.NUM_LAYERS,
        is_training=False)
    y = tf.concat([y1, y2], axis=1)
  # Ensure no new vars were created - full reuse
  assert len(tf.global_variables()) == num_vars

  loss_rev = tf.reduce_mean(y_rev + 10.)
  loss = tf.reduce_mean(y + 10.)

  wrt = [x] + f_side_input + g_side_input + fg_vars
  grads_rev = tf.gradients(loss_rev, wrt)
  grads = tf.gradients(loss, wrt)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
    self.assertAllClose(y_val, yd_val)
    for g1, g2 in zip(gd_val, g_val):
      self.assertAllClose(g1, g2)
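# Why rev_block can avoid storing per-layer activations (an illustrative
# sketch, not from the original source). Each reversible layer computes
#   y1 = x1 + f(x2)
#   y2 = x2 + g(y1)
# so its inputs can be reconstructed exactly from its outputs during the
# backward pass instead of being kept in memory:
#   x2 = y2 - g(y1)
#   x1 = y1 - f(x2)
# A pure-Python round trip with stateless f and g:
def _rev_forward(x1, x2, f, g):
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2

def _rev_inverse(y1, y2, f, g):
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2

# Round trip: _rev_inverse(*_rev_forward(3.0, 4.0, abs, abs), abs, abs)
# returns (3.0, 4.0).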