Example #1
def transformer_revnet_encoder(encoder_input,
                               encoder_self_attention_bias,
                               hparams,
                               name="encoder"):
    """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
    def f(x, side_input):
        """f(x) for reversible layer, self-attention layer."""
        encoder_self_attention_bias = side_input[0]

        old_hid_size = hparams.hidden_size
        hparams.hidden_size = old_hid_size // 2

        with tf.variable_scope("self_attention"):
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams), None,
                encoder_self_attention_bias, hparams.attention_key_channels
                or hparams.hidden_size, hparams.attention_value_channels
                or hparams.hidden_size, hparams.hidden_size, hparams.num_heads,
                hparams.attention_dropout)
            y = common_layers.layer_postprocess(x, y, hparams)
        hparams.hidden_size = old_hid_size
        return y

    def g(x):
        """g(x) for reversible layer, feed-forward layer."""
        old_hid_size = hparams.hidden_size
        hparams.hidden_size = old_hid_size // 2

        with tf.variable_scope("ffn"):
            y = transformer.transformer_ffn_layer(
                common_layers.layer_preprocess(x, hparams), hparams)
            y = common_layers.layer_postprocess(x, y, hparams)
        hparams.hidden_size = old_hid_size
        return y

    x1, x2 = tf.split(encoder_input, 2, axis=-1)

    with tf.variable_scope(name):
        y1, y2 = rev_block.rev_block(
            x1,
            x2,
            f,
            g,
            num_layers=hparams.num_hidden_layers,
            f_side_input=[encoder_self_attention_bias],
            is_training=hparams.mode == tf.estimator.ModeKeys.TRAIN)
        y = tf.concat([y1, y2], axis=-1)

    return common_layers.layer_preprocess(y, hparams)
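
For reference, the additive coupling that rev_block builds each reversible layer around is also the reason f and g above are constructed at hparams.hidden_size // 2: the input is split into two halves, and each residual function maps one half to the other half's width. Below is a minimal NumPy sketch of that coupling and its inverse; the toy f, g, and shapes are made up for illustration and are not the attention/FFN layers above.

import numpy as np

rng = np.random.default_rng(0)
Wf = rng.normal(size=(4, 4))
Wg = rng.normal(size=(4, 4))
f = lambda h: np.tanh(h @ Wf)        # toy residual function, width 4 -> 4
g = lambda h: np.tanh(h @ Wg)

x = rng.normal(size=(2, 8))          # batch of 2, "hidden size" 8
x1, x2 = np.split(x, 2, axis=-1)     # two halves of width 4, as with tf.split above

# Forward pass of one reversible layer.
y1 = x1 + f(x2)
y2 = x2 + g(y1)

# Inverse: the inputs are recomputed exactly from the outputs, which is what
# lets rev_block avoid storing per-layer activations for backprop.
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)
assert np.allclose(x1, x1_rec) and np.allclose(x2, x2_rec)
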
Example #2
def unit(x1,
         x2,
         block_num,
         depth1,
         depth2,
         num_layers,
         dim='2d',
         first_batch_norm=True,
         stride=1,
         training=True):
    """Implements bottleneck RevNet unit from authors' RevNet-104 architecture.

  Args:
    x1: [N, H, W, C] tensor of network activations.
    x2: [N, H, W, C] tensor of network activations.
    block_num: integer ID of block
    depth1: First depth in bottleneck residual unit.
    depth2: Second depth in bottleneck residual unit.
    num_layers: Number of layers in the RevNet block.
    dim: '2d' if 2-dimensional, '3d' if 3-dimensional.
    first_batch_norm: Whether to keep the first batch norm layer or not.
      Typically used in the first RevNet block.
    stride: Stride for the residual function.
    training: True for train phase, False for eval phase.

  Returns:
    Two [N, H, W, C] output activation tensors.
  """
    scope_name = 'unit_%d' % block_num
    with tf.variable_scope(scope_name):
        # Manual implementation of downsampling
        with tf.variable_scope('downsampling'):
            with tf.variable_scope('x1'):
                hx1 = h(x1, depth2, dim=dim, layer_stride=stride)
                fx2 = f(x2,
                        depth1,
                        depth2,
                        dim=dim,
                        layer_stride=stride,
                        first_batch_norm=first_batch_norm,
                        training=training)
                x1 = hx1 + fx2
            with tf.variable_scope('x2'):
                hx2 = h(x2, depth2, dim=dim, layer_stride=stride)
                fx1 = f(x1, depth1, depth2, dim=dim, training=training)
                x2 = hx2 + fx1

        # Full block using memory-efficient rev_block implementation.
        with tf.variable_scope('full_block'):
            residual_func = lambda x: f(
                x, depth1, depth2, dim=dim, training=training)
            x1, x2 = rev_block.rev_block(x1,
                                         x2,
                                         residual_func,
                                         residual_func,
                                         num_layers=num_layers)
            return x1, x2
Example #3
def unit(x1, x2, block_num, depth, num_layers, dim='2d',
         bottleneck=True, first_batch_norm=True, stride=1, training=True):
  """Implements bottleneck RevNet unit from authors' RevNet architecture.

  Args:
    x1: [N, H, W, C] tensor of network activations.
    x2: [N, H, W, C] tensor of network activations.
    block_num: integer ID of block
    depth: First depth in bottleneck residual unit.
    num_layers: Number of layers in the RevNet block.
    dim: '2d' if 2-dimensional, '3d' if 3-dimensional.
    bottleneck: Should a bottleneck layer be used.
    first_batch_norm: Whether to keep the first batch norm layer or not.
      Typically used in the first RevNet block.
    stride: Stride for the residual function.
    training: True for train phase, False for eval phase.

  Returns:
    Two [N, H, W, C] output activation tensors.
  """
  scope_name = 'unit_%d' % block_num
  if bottleneck:
    depth1 = depth
    depth2 = depth * 4
  else:
    depth1 = depth2 = depth

  residual = functools.partial(f,
                               depth1=depth1, depth2=depth2, dim=dim,
                               training=training, bottleneck=bottleneck)

  with tf.variable_scope(scope_name):
    downsample = downsample_bottleneck if bottleneck else downsample_residual
    # Manual implementation of downsampling
    with tf.variable_scope('downsampling'):
      with tf.variable_scope('x1'):
        hx1 = downsample(x1, depth2, dim=dim, stride=stride)
        fx2 = residual(x2, stride=stride, first_batch_norm=first_batch_norm)
        x1 = hx1 + fx2
      with tf.variable_scope('x2'):
        hx2 = downsample(x2, depth2, dim=dim, stride=stride)
        fx1 = residual(x1)
        x2 = hx2 + fx1

    # Full block using memory-efficient rev_block implementation.
    with tf.variable_scope('full_block'):
      x1, x2 = rev_block.rev_block(x1, x2,
                                   residual,
                                   residual,
                                   num_layers=num_layers)
      return x1, x2
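
The two scopes above follow the usual RevNet split: the strided 'downsampling' step changes spatial resolution and channel count, so its inputs cannot be recovered from its outputs and it is written out (and backpropagated through) normally, while the stride-1, shape-preserving layers that follow go through the memory-efficient rev_block. A shape-only sketch of that bookkeeping, with hypothetical stand-ins for the downsample and residual functions:

import numpy as np

# Hypothetical shape-only stand-ins; not the real downsample_*/f ops.
def ds(x, depth2, stride=2):      # strided projection: shrinks H, W; widens C
  n, H, W, _ = x.shape
  return np.zeros((n, H // stride, W // stride, depth2), x.dtype)

def res(x):                       # stride-1 residual: preserves shape
  return np.zeros_like(x)

x1 = x2 = np.zeros((8, 32, 32, 64), np.float32)

# "downsampling": coupled update with a shape change -> not invertible,
# so these activations are kept for the backward pass.
x1 = ds(x1, 256) + ds(x2, 256)    # plays the role of hx1 + fx2 (strided)
x2 = ds(x2, 256) + res(x1)        # plays the role of hx2 + fx1

# "full_block": every layer maps [8, 16, 16, 256] -> [8, 16, 16, 256],
# which is the shape-preserving property rev_block needs in order to invert it.
assert x1.shape == x2.shape == (8, 16, 16, 256)
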
Example #5
  def testRevBlock(self):
    channels = 8
    num_layers = 4
    batch_size = 16
    tf.set_random_seed(1234)

    def f(x):
      return tf.layers.dense(x, channels // 2, use_bias=True)

    def g(x):
      return tf.layers.dense(x, channels // 2, use_bias=True)

    x = tf.random_uniform([batch_size, channels], dtype=tf.float32)

    with tf.variable_scope("defun") as vs:
      y_defun = rev_block.rev_block(x, f, g, num_layers=num_layers)
      fg_vars = vs.trainable_variables()

    num_vars = len(tf.global_variables())
    with tf.variable_scope(vs, reuse=True):
      y = rev_block.rev_block(x, f, g, num_layers=num_layers, is_training=False)
    # Ensure no new vars were created - full reuse
    assert len(tf.global_variables()) == num_vars

    loss_defun = tf.reduce_mean(y_defun + 10.)
    loss = tf.reduce_mean(y + 10.)

    grads_defun = tf.gradients(loss_defun, [x] + fg_vars)
    grads = tf.gradients(loss, [x] + fg_vars)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      y_val, yd_val, gd_val, g_val = sess.run([y, y_defun, grads_defun, grads])
      self.assertAllClose(y_val, yd_val)
      for g1, g2 in zip(gd_val, g_val):
        self.assertAllClose(g1, g2)
Example #6
def unit(x1, x2, block_num, depth1, depth2, num_layers, dim='2d',
         first_batch_norm=True, stride=1, training=True):
  """Implements bottleneck RevNet unit from authors' RevNet-104 architecture.

  Args:
    x1: [N, H, W, C] tensor of network activations.
    x2: [N, H, W, C] tensor of network activations.
    block_num: integer ID of block
    depth1: First depth in bottleneck residual unit.
    depth2: Second depth in bottleneck residual unit.
    num_layers: Number of layers in the RevNet block.
    dim: '2d' if 2-dimensional, '3d' if 3-dimensional.
    first_batch_norm: Whether to keep the first batch norm layer or not.
      Typically used in the first RevNet block.
    stride: Stride for the residual function.
    training: True for train phase, False for eval phase.

  Returns:
    Two [N, H, W, C] output activation tensors.
  """
  scope_name = 'unit_%d' % block_num
  with tf.variable_scope(scope_name):
    # Manual implementation of downsampling
    with tf.variable_scope('downsampling'):
      with tf.variable_scope('x1'):
        hx1 = h(x1, depth2, dim=dim, layer_stride=stride)
        fx2 = f(x2, depth1, depth2, dim=dim, layer_stride=stride,
                first_batch_norm=first_batch_norm, training=training)
        x1 = hx1 + fx2
      with tf.variable_scope('x2'):
        hx2 = h(x2, depth2, dim=dim, layer_stride=stride)
        fx1 = f(x1, depth1, depth2, dim=dim, training=training)
        x2 = hx2 + fx1

    # Full block using memory-efficient rev_block implementation.
    with tf.variable_scope('full_block'):
      residual_func = lambda x: f(x, depth1, depth2, dim=dim, training=training)
      x1, x2 = rev_block.rev_block(x1, x2,
                                   residual_func,
                                   residual_func,
                                   num_layers=num_layers)
      return x1, x2
Example #7
  def testSmoke(self):
    channels = 8
    num_layers = 4
    batch_size = 16
    use_defun = True
    tf.set_random_seed(1234)

    def f(x):
      return tf.layers.dense(x, channels // 2, use_bias=True)

    def g(x):
      return tf.layers.dense(x, channels // 2, use_bias=True)

    x = tf.random_uniform([batch_size, channels], dtype=tf.float32)
    y = rev_block.rev_block(
        x, f, g, num_layers=num_layers, is_training=use_defun)
    loss = tf.reduce_mean(y + 10.)
    grads = tf.gradients(loss, [x] + tf.global_variables())
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      _ = sess.run(grads)
Example #8
    def _test_rev_block(self,
                        x=None,
                        f=None,
                        g=None,
                        f_side_input=None,
                        g_side_input=None):
        tf.set_random_seed(1234)

        if f is None:

            def f(x):  # pylint: disable=function-redefined
                return tf.layers.dense(x, self.CHANNELS // 2, use_bias=True)

        if g is None:

            def g(x):  # pylint: disable=function-redefined
                return tf.layers.dense(x, self.CHANNELS // 2, use_bias=True)

        if f_side_input is None:
            f_side_input = []

        if g_side_input is None:
            g_side_input = []

        if x is None:
            x = tf.random_uniform([self.BATCH_SIZE, self.CHANNELS],
                                  dtype=tf.float32)
        x1, x2 = tf.split(x, 2, axis=-1)

        with tf.variable_scope("rev_test") as vs:
            y1_rev, y2_rev = rev_block.rev_block(x1,
                                                 x2,
                                                 f,
                                                 g,
                                                 f_side_input=f_side_input,
                                                 g_side_input=g_side_input,
                                                 num_layers=self.NUM_LAYERS)
            y_rev = tf.concat([y1_rev, y2_rev], axis=1)
            fg_vars = vs.trainable_variables()

        num_vars = len(tf.global_variables())
        with tf.variable_scope(vs, reuse=True):
            y1, y2 = rev_block.rev_block(x1,
                                         x2,
                                         f,
                                         g,
                                         f_side_input=f_side_input,
                                         g_side_input=g_side_input,
                                         num_layers=self.NUM_LAYERS,
                                         is_training=False)
            y = tf.concat([y1, y2], axis=1)
        # Ensure no new vars were created - full reuse
        assert len(tf.global_variables()) == num_vars

        loss_rev = tf.reduce_mean(y_rev + 10.)
        loss = tf.reduce_mean(y + 10.)

        wrt = [x] + f_side_input + g_side_input + fg_vars
        grads_rev = tf.gradients(loss_rev, wrt)
        grads = tf.gradients(loss, wrt)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            y_val, yd_val, gd_val, g_val = sess.run(
                [y, y_rev, grads_rev, grads])
            self.assertAllClose(y_val, yd_val)
            for g1, g2 in zip(gd_val, g_val):
                self.assertAllClose(g1, g2)
Example #9
def transformer_revnet_encoder(encoder_input,
                               encoder_self_attention_bias,
                               hparams,
                               name="encoder"):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """

  def f(x, side_input):
    """f(x) for reversible layer, self-attention layer."""
    encoder_self_attention_bias = side_input[0]

    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2

    with tf.variable_scope("self_attention"):
      y = common_attention.multihead_attention(
          common_layers.layer_preprocess(
              x, hparams), None, encoder_self_attention_bias,
          hparams.attention_key_channels or hparams.hidden_size,
          hparams.attention_value_channels or hparams.hidden_size,
          hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
      y = common_layers.layer_postprocess(x, y, hparams)
    hparams.hidden_size = old_hid_size
    return y

  def g(x):
    """g(x) for reversible layer, feed-forward layer."""
    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2

    with tf.variable_scope("ffn"):
      y = transformer.transformer_ffn_layer(
          common_layers.layer_preprocess(x, hparams), hparams)
      y = common_layers.layer_postprocess(x, y, hparams)
    hparams.hidden_size = old_hid_size
    return y

  x1, x2 = tf.split(encoder_input, 2, axis=-1)

  with tf.variable_scope(name):
    y1, y2 = rev_block.rev_block(
        x1,
        x2,
        f,
        g,
        num_layers=hparams.num_hidden_layers,
        f_side_input=[encoder_self_attention_bias],
        is_training=hparams.mode == tf.estimator.ModeKeys.TRAIN)
    y = tf.concat([y1, y2], axis=-1)

  return common_layers.layer_preprocess(y, hparams)
Example #10
  def _test_rev_block(self,
                      x=None,
                      f=None,
                      g=None,
                      f_side_input=None,
                      g_side_input=None):
    tf.set_random_seed(1234)

    if f is None:

      def f(x):  # pylint: disable=function-redefined
        return tf.layers.dense(x, self.CHANNELS // 2, use_bias=True)

    if g is None:

      def g(x):  # pylint: disable=function-redefined
        return tf.layers.dense(x, self.CHANNELS // 2, use_bias=True)

    if f_side_input is None:
      f_side_input = []

    if g_side_input is None:
      g_side_input = []

    if x is None:
      x = tf.random_uniform([self.BATCH_SIZE, self.CHANNELS], dtype=tf.float32)
    x1, x2 = tf.split(x, 2, axis=-1)

    with tf.variable_scope("rev_test") as vs:
      y1_rev, y2_rev = rev_block.rev_block(
          x1,
          x2,
          f,
          g,
          f_side_input=f_side_input,
          g_side_input=g_side_input,
          num_layers=self.NUM_LAYERS)
      y_rev = tf.concat([y1_rev, y2_rev], axis=1)
      fg_vars = vs.trainable_variables()

    num_vars = len(tf.global_variables())
    with tf.variable_scope(vs, reuse=True):
      y1, y2 = rev_block.rev_block(
          x1,
          x2,
          f,
          g,
          f_side_input=f_side_input,
          g_side_input=g_side_input,
          num_layers=self.NUM_LAYERS,
          is_training=False)
      y = tf.concat([y1, y2], axis=1)
    # Ensure no new vars were created - full reuse
    assert len(tf.global_variables()) == num_vars

    loss_rev = tf.reduce_mean(y_rev + 10.)
    loss = tf.reduce_mean(y + 10.)

    wrt = [x] + f_side_input + g_side_input + fg_vars
    grads_rev = tf.gradients(loss_rev, wrt)
    grads = tf.gradients(loss, wrt)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
      self.assertAllClose(y_val, yd_val)
      for g1, g2 in zip(gd_val, g_val):
        self.assertAllClose(g1, g2)
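
The side-input convention exercised by _test_rev_block is the same one transformer_revnet_encoder relies on: when f_side_input is non-empty, f is called as f(x, side_input) and the listed tensors are shared across all layers. Below is a minimal sketch of a custom f written in that style, assuming the same rev_block API as the examples above; the tensor names and sizes here are made up for illustration.

channels = 8
x = tf.random_uniform([16, channels], dtype=tf.float32)
bias_tensor = tf.zeros([1, channels // 2])   # stands in for a real attention bias

def f(x, side_input):
  bias = side_input[0]                       # the f_side_input tensors, shared by every layer
  return tf.layers.dense(x, channels // 2, use_bias=True) + bias

def g(x):
  return tf.layers.dense(x, channels // 2, use_bias=True)

x1, x2 = tf.split(x, 2, axis=-1)
y1, y2 = rev_block.rev_block(
    x1, x2, f, g,
    f_side_input=[bias_tensor],
    num_layers=4)
y = tf.concat([y1, y2], axis=-1)
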