Example #1
def ff(inputs, num_units, scope="positionwise_feedforward"):
    '''position-wise feed forward net. See 3.3
    
    inputs: A 3d tensor with shape of [N, T, C].
    num_units: A list of two integers.
    scope: Optional scope for `variable_scope`.

    Returns:
      A 3d tensor with the same shape and dtype as inputs
    '''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Inner layer
        outputs = layers.masked_fully_connected(inputs, num_units[0])

        # Outer layer
        outputs = layers.masked_fully_connected(outputs,
                                                num_units[1],
                                                activation_fn=None)

        # Residual connection
        outputs += inputs

        # Normalize
        outputs = ln(outputs)

    return outputs
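The ln helper called above is not shown in this example; below is a minimal layer-normalization sketch consistent with how it is used here, an assumption rather than the original implementation.
def ln(inputs, epsilon=1e-8, scope="ln"):
    '''Layer normalization over the last axis (sketch, assumed behavior).'''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        params_shape = inputs.get_shape()[-1:]
        # Normalize each position independently over the channel axis.
        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
        beta = tf.get_variable("beta", params_shape,
                               initializer=tf.zeros_initializer())
        gamma = tf.get_variable("gamma", params_shape,
                                initializer=tf.ones_initializer())
        normalized = (inputs - mean) / ((variance + epsilon) ** 0.5)
        return gamma * normalized + beta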
Example #2
def multihead_attention(queries,
                        keys,
                        values,
                        num_heads=8,
                        dropout_rate=0,
                        training=True,
                        causality=False,
                        scope="multihead_attention"):
    '''Applies multihead attention. See 3.2.2
    queries: A 3d tensor with shape of [N, T_q, d_model].
    keys: A 3d tensor with shape of [N, T_k, d_model].
    values: A 3d tensor with shape of [N, T_k, d_model].
    num_heads: An int. Number of heads.
    dropout_rate: A floating point number.
    training: Boolean. Whether dropout is applied (i.e. training mode).
    causality: Boolean. If true, units that reference the future are masked.
    scope: Optional scope for `variable_scope`.
        
    Returns:
      A 3d tensor with shape of (N, T_q, d_model)
    '''
    d_model = queries.get_shape().as_list()[-1]
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Linear projections
        Q = layers.masked_fully_connected(
            queries, d_model, biases_initializer=None,
            activation_fn=None)  # (N, T_q, d_model)
        K = layers.masked_fully_connected(
            keys, d_model, biases_initializer=None,
            activation_fn=None)  # (N, T_k, d_model)
        V = layers.masked_fully_connected(
            values, d_model, biases_initializer=None,
            activation_fn=None)  # (N, T_k, d_model)

        # Split and concat
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2),
                       axis=0)  # (h*N, T_q, d_model/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2),
                       axis=0)  # (h*N, T_k, d_model/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2),
                       axis=0)  # (h*N, T_k, d_model/h)

        # Attention
        outputs = scaled_dot_product_attention(Q_, K_, V_, causality,
                                               dropout_rate, training)

        # Restore shape
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0),
                            axis=2)  # (N, T_q, d_model)

        # Residual connection
        outputs += queries

        # Normalize
        outputs = ln(outputs)

    return outputs
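scaled_dot_product_attention is another helper that is not shown here; the following is a minimal sketch matching the call above (Q, K, V, causality, dropout_rate, training), assumed rather than taken from the original module.
def scaled_dot_product_attention(Q, K, V, causality=False, dropout_rate=0.,
                                 training=True,
                                 scope="scaled_dot_product_attention"):
    '''Sketch: softmax(Q K^T / sqrt(d_k)) V with optional causal masking.'''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        d_k = Q.get_shape().as_list()[-1]
        outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1]))  # (h*N, T_q, T_k)
        outputs /= d_k ** 0.5
        if causality:
            # Lower-triangular mask: each position may only attend to the past.
            tril = tf.matrix_band_part(tf.ones_like(outputs[0, :, :]), -1, 0)
            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])
            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)
        outputs = tf.nn.softmax(outputs)  # attention weights over T_k
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
        return tf.matmul(outputs, V)      # (h*N, T_q, d_k)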
Example #3
def mnist_network_fc(input_batch, reuse=False, model_pruning=False):
    """Define a basic FC network."""
    regularizer = contrib_layers.l2_regularizer(scale=FLAGS.l2_scale)
    if model_pruning:
        y = layers.masked_fully_connected(inputs=input_batch[0],
                                          num_outputs=300,
                                          activation_fn=tf.nn.relu,
                                          weights_regularizer=regularizer,
                                          reuse=reuse,
                                          scope='layer1')
        y1 = layers.masked_fully_connected(inputs=y,
                                           num_outputs=100,
                                           activation_fn=tf.nn.relu,
                                           weights_regularizer=regularizer,
                                           reuse=reuse,
                                           scope='layer2')
        logits = layers.masked_fully_connected(inputs=y1,
                                               num_outputs=10,
                                               reuse=reuse,
                                               activation_fn=None,
                                               weights_regularizer=regularizer,
                                               scope='layer3')
    else:
        y = tf.layers.dense(inputs=input_batch[0],
                            units=300,
                            activation=tf.nn.relu,
                            kernel_regularizer=regularizer,
                            reuse=reuse,
                            name='layer1')
        y1 = tf.layers.dense(inputs=y,
                             units=100,
                             activation=tf.nn.relu,
                             kernel_regularizer=regularizer,
                             reuse=reuse,
                             name='layer2')
        logits = tf.layers.dense(inputs=y1,
                                 units=10,
                                 reuse=reuse,
                                 kernel_regularizer=regularizer,
                                 name='layer3')

    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
        labels=input_batch[1], logits=logits)

    cross_entropy += tf.losses.get_regularization_loss()

    predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(input_batch[1], predictions), tf.float32))

    return cross_entropy, accuracy
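A hypothetical wiring sketch (not part of the original file): the masks created when model_pruning=True only become sparse if a pruning mask-update op runs next to the train op. Here input_batch is assumed to be an (images, int32 labels) tuple of tensors and the hparam string is an illustrative value.
from tensorflow.contrib.model_pruning.python import pruning

loss, accuracy = mnist_network_fc(input_batch, model_pruning=True)
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

# Build the conditional mask-update op from the pruning hyperparameters.
pruning_hparams = pruning.get_pruning_hparams().parse("target_sparsity=0.9")
prune_op = pruning.Pruning(pruning_hparams,
                           global_step=global_step).conditional_mask_update_op()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(1000):
        sess.run([train_op, prune_op])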
Example #4
  def testSingleFCMaskAdded(self):
    input_depth, output_depth = 8, 32
    input_tensor = array_ops.ones((5, input_depth))
    layers.masked_fully_connected(input_tensor, output_depth)

    masks = ops.get_collection(core_layers.MASK_COLLECTION)
    self.assertEqual(len(masks), 1)
    self.assertListEqual(masks[0].get_shape().as_list(),
                         [input_depth, output_depth])

    masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
    self.assertEqual(len(masked_weight), 1)
    self.assertListEqual(masked_weight[0].get_shape().as_list(),
                         [input_depth, output_depth])
Example #5
    def testSingleFCMaskAdded(self):
        input_depth, output_depth = 8, 32
        input_tensor = array_ops.ones((5, input_depth))
        layers.masked_fully_connected(input_tensor, output_depth)

        masks = ops.get_collection(core_layers.MASK_COLLECTION)
        self.assertEqual(len(masks), 1)
        self.assertListEqual(masks[0].get_shape().as_list(),
                             [input_depth, output_depth])

        masked_weight = ops.get_collection(
            core_layers.MASKED_WEIGHT_COLLECTION)
        self.assertEqual(len(masked_weight), 1)
        self.assertListEqual(masked_weight[0].get_shape().as_list(),
                             [input_depth, output_depth])
Example #6
  def testMultipleConvMaskAdded(self):
    number_of_layers = 5

    base_depth = 4
    depth_step = 7

    input_tensor = array_ops.ones((8, base_depth))

    top_layer = input_tensor

    for ix in range(number_of_layers):
      top_layer = layers.masked_fully_connected(top_layer, base_depth +
                                                (ix + 1) * depth_step)

    masks = ops.get_collection(core_layers.MASK_COLLECTION)
    self.assertEqual(len(masks), number_of_layers)
    for ix in range(number_of_layers):
      self.assertListEqual(masks[ix].get_shape().as_list(), [
          base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step
      ])

    masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
    self.assertEqual(len(masked_weight), number_of_layers)
    for ix in range(number_of_layers):
      self.assertListEqual(masked_weight[ix].get_shape().as_list(), [
          base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step
      ])
Example #7
  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.random.uniform((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    global_step = tf.train.get_or_create_global_step()
    weight = pruning.get_weights()[0]
    # There is one masked layer to be trained.
    mask = pruning.get_masks()[0]
    # Around half of the values of the mask are set to zero by `mask_update`.
    mask_update = tf.assign(
        mask,
        tf.constant(
            np.random.choice([0, 1], size=(n_inp, n_out), p=[1./2, 1./2]),
            dtype=tf.float32))
    loss = tf.reduce_mean(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)

    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run([mask_update])

    return sess, train_op, mask, weight, global_step
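A hypothetical follow-up test using the helper above: SET-style updates drop and regrow the same number of connections, so the number of active mask entries should stay constant. This check is assumed, not part of the original suite.
  def testMaskNonzeroCountPreserved(self):
    # Hypothetical test: the nonzero count of the mask should not change
    # across SET mask updates (drop_fraction dropped, same amount regrown).
    sess, train_op, mask, _, _ = self._setup_graph(
        n_inp=10, n_out=10, drop_frac=0.3)
    nnz_before = np.count_nonzero(sess.run(mask))
    for _ in range(4):  # covers the update iterations between start and end.
      sess.run(train_op)
    nnz_after = np.count_nonzero(sess.run(mask))
    self.assertEqual(nnz_before, nnz_after)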
Example #8
  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2, momentum=0.5):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseMomentumOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac,
        momentum=momentum)
    x = tf.ones((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    # Multiply the output by a range of constants so that the masked weights
    # have constant but distinct gradients.
    y = y * tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape)
    loss = tf.reduce_sum(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    weight = pruning.get_weights()[0]
    masked_grad = sparse_optim._weight2masked_grads[weight.name]
    masked_grad_ema = sparse_optim._ema_grads.average(masked_grad)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

    return sess, train_op, masked_grad_ema
Example #9
  def _setup_graph(self, default_sparsity, mask_init_method,
                   custom_sparsity_map, n_inp=3, n_out=5):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(1e-3)
    sparse_optim = sparse_optimizers.SparseSnipOptimizer(
        optim, default_sparsity, mask_init_method,
        custom_sparsity_map=custom_sparsity_map)

    inp_values = np.arange(1, n_inp+1)
    scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5
    # The gradient of the kernel is the outer product of the input and the
    # output gradient. Since the loss is the sum of the scaled outputs, the
    # output gradient is equal to the scale vector.
    expected_grads = np.outer(inp_values, scale_vector_values)

    x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)

    y = y * scale_vector
    loss = tf.reduce_sum(y)

    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)

    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    mask = pruning.get_masks()[0]
    weights = pruning.get_weights()[0]
    return sess, train_op, expected_grads, sparse_optim, mask, weights
Example #10
    def _setup_graph(self,
                     n_inp,
                     n_out,
                     drop_frac,
                     start_iter=1,
                     end_iter=4,
                     freq_iter=2):
        """Setups a trivial training procedure for sparse training."""
        tf.reset_default_graph()
        optim = tf.train.GradientDescentOptimizer(1e-3)
        global_step = tf.train.get_or_create_global_step()
        sparse_optim = sparse_optimizers.SparseRigLOptimizer(
            optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
        x = tf.ones((1, n_inp))
        y = layers.masked_fully_connected(x, n_out, activation_fn=None)
        # Multiply the output by a range of constants so that the masked weights
        # have constant but distinct gradients. We also scale by global_step so
        # that the gradient grows linearly with time.
        scale_vector = (
            tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape) *
            tf.cast(global_step, dtype=y.dtype))
        y = y * scale_vector
        loss = tf.reduce_sum(y)
        global_step = tf.train.get_or_create_global_step()
        train_op = sparse_optim.minimize(loss, global_step)
        weight = pruning.get_weights()[0]
        expected_gradient = tf.broadcast_to(scale_vector, weight.shape)
        masked_grad = sparse_optim._weight2masked_grads[weight.name]

        # Init
        sess = tf.Session()
        init = tf.global_variables_initializer()
        sess.run(init)

        return sess, train_op, masked_grad, expected_gradient
Example #11
    def testMultipleConvMaskAdded(self):
        number_of_layers = 5

        base_depth = 4
        depth_step = 7

        input_tensor = array_ops.ones((8, base_depth))

        top_layer = input_tensor

        for ix in range(number_of_layers):
            top_layer = layers.masked_fully_connected(
                top_layer, base_depth + (ix + 1) * depth_step)

        masks = ops.get_collection(core_layers.MASK_COLLECTION)
        self.assertEqual(len(masks), number_of_layers)
        for ix in range(number_of_layers):
            self.assertListEqual(masks[ix].get_shape().as_list(), [
                base_depth + ix * depth_step, base_depth +
                (ix + 1) * depth_step
            ])

        masked_weight = ops.get_collection(
            core_layers.MASKED_WEIGHT_COLLECTION)
        self.assertEqual(len(masked_weight), number_of_layers)
        for ix in range(number_of_layers):
            self.assertListEqual(masked_weight[ix].get_shape().as_list(), [
                base_depth + ix * depth_step, base_depth +
                (ix + 1) * depth_step
            ])
Example #12
    def pruning_inference(self, inputs):
        net = layers.masked_conv2d(inputs, 64, 3)
        net = layers.masked_conv2d(net, 64, 3)
        net = tf.layers.max_pooling2d(net, 2, 2)

        net = layers.masked_conv2d(net, 128, 3)
        net = layers.masked_conv2d(net, 128, 3)
        net = tf.layers.max_pooling2d(net, 2, 2)

        net = layers.masked_conv2d(net, 256, 3)
        net = layers.masked_conv2d(net, 256, 3)
        net = layers.masked_conv2d(net, 256, 3)
        net = tf.layers.max_pooling2d(net, 2, 2)

        net = tf.layers.flatten(net)

        net = layers.masked_fully_connected(net, 1024)
        net = layers.masked_fully_connected(net, 1024)
        logits = layers.masked_fully_connected(net, self.num_classes, activation_fn=None)

        return tf.identity(logits, name='logits')
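A small hypothetical companion method (not in the original class) that reports how sparse each pruning mask created above currently is, using the same pruning module seen in the other examples on this page.
    def mask_sparsity(self):
        # Fraction of zeroed entries in every mask created by the masked_*
        # layers; `pruning` is assumed to be imported as in the other examples.
        masks = pruning.get_masks()
        return tf.stack([tf.nn.zero_fraction(m) for m in masks],
                        name='mask_sparsity')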
Example #13
    def _build_fully_connected_model(self, number_of_layers):
        base_depth = 4
        depth_step = 7

        input_tensor = array_ops.ones((8, base_depth))

        top_layer = input_tensor

        with variable_scope.variable_scope("fc_model"):
            for ix in range(number_of_layers):
                top_layer = layers.masked_fully_connected(
                    top_layer, base_depth + (ix + 1) * depth_step)

        return top_layer
Example #14
  def _build_fully_connected_model(self, number_of_layers):
    base_depth = 4
    depth_step = 7

    input_tensor = array_ops.ones((8, base_depth))

    top_layer = input_tensor

    with variable_scope.variable_scope("fc_model"):
      for ix in range(number_of_layers):
        top_layer = layers.masked_fully_connected(
            top_layer, base_depth + (ix + 1) * depth_step)

    return top_layer
Example #15
def dense(x,
          units,
          activation=None,
          use_bias=True,
          kernel_initializer="glorot_uniform",
          bias_initializer="zeros",
          sparsity_technique="variational_dropout",
          auxiliary_initializer=None,
          threshold=3.0,
          clip_alpha=None,
          training=True,
          dtype=tf.float32,
          name=None,
          initial_sparsity=None):
    """Matmul & bias add that supports broadcasting for batched gemm.

  Supports a constrained subset of the functionality provided by tf.layers.dense.

  Args:
    x: input tensor.
    units: number of units in the dense layer.
    activation: activation function to use in the layer.
    use_bias: whether or not to add a bias to the output.
    kernel_initializer: weight initializer for the layer.
    bias_initializer: weight initializer for the bias.
    sparsity_technique: sparsification technique to apply to the weights.
    auxiliary_initializer: initializer for the auxiliary variables used in
      variational dropout and l0 regularization.
    threshold: log-alpha threshold for variational dropout.
    clip_alpha: value used to clip the variational dropout log alpha values
      (None disables clipping).
    training: whether this run is training or evaluating the model.
    dtype: data type for the weights and computation.
    name: name for the layer.
    initial_sparsity: initial weight sparsity at the start of training.

  Returns:
    Tensor representing the output of the layer.
  """
    activation = activations.get(activation)
    kernel_initializer = initializers.get(kernel_initializer)
    bias_initializer = initializers.get(bias_initializer)

    if (sparsity_technique == "magnitude_pruning"
            or sparsity_technique == "random_pruning"):
        if initial_sparsity is not None:
            # If the initial sparsity value is passed in, use the sparse glorot
            # uniform initializer to account for the zero valued weights.
            kernel_initializer = common_init.SparseGlorotUniform(
                initial_sparsity, dtype=dtype)
            tf.logging.info(
                "Using sparse initialization with sparsity {} for variable {}".
                format(initial_sparsity,
                       tf.get_variable_scope().name))

        # If the sparsity technique is magnitude_pruning, or random_pruning
        # use the model_pruning masked_fully_connected layer
        #
        # masked_fully_connected doesn't take use_bias arg, pass None for the
        # bias initializer if we don't want a bias variable
        bias_initializer = bias_initializer if use_bias else None
        with tf.variable_scope(name, default_name="dense"):
            return pruning_layers.masked_fully_connected(
                inputs=x,
                num_outputs=units,
                activation_fn=activation,
                weights_initializer=kernel_initializer,
                biases_initializer=bias_initializer)
    if initial_sparsity is not None:
        raise ValueError("initial_sparsity only supported for mp & rp")

    # layer_name = "%s_{}" % name if name else "{}"

    input_shape = x.get_shape().as_list()
    if input_shape[-1] is None:
        raise ValueError("The last dimension of the inputs to `Dense` "
                         "should be defined. Found `None`.")

    with tf.variable_scope(name, default_name="dense") as vs:
        kernel = tf.get_variable("kernel",
                                 shape=[input_shape[-1], units],
                                 initializer=kernel_initializer,
                                 dtype=dtype,
                                 trainable=True)

        bias = None
        if use_bias:
            bias = tf.get_variable("bias",
                                   shape=[
                                       units,
                                   ],
                                   initializer=bias_initializer,
                                   dtype=dtype,
                                   trainable=True)

    # Compute the dense layer
    if sparsity_technique == "variational_dropout":
        log_sigma2_initializer = initializers.get(auxiliary_initializer)

        if not log_sigma2_initializer:
            log_sigma2_initializer = tf.constant_initializer(value=-10,
                                                             dtype=dtype)

        with tf.variable_scope(vs, auxiliary_name_scope=False) as vs1:
            with tf.name_scope(vs1.original_name_scope):
                log_sigma2 = tf.get_variable(
                    "log_sigma2",
                    shape=[input_shape[-1], units],
                    initializer=log_sigma2_initializer,
                    dtype=dtype,
                    trainable=True)

        variational_parameters = (kernel, log_sigma2)
        tf.add_to_collection(VARIATIONAL_DROPOUT_PARAMETERS,
                             variational_parameters)

        input_rank = x.get_shape().ndims
        if input_rank > 2:
            if training:
                outputs = vd.nn.broadcast_matmul_train(x,
                                                       variational_parameters,
                                                       clip_alpha=clip_alpha)
            else:
                outputs = vd.nn.broadcast_matmul_eval(x,
                                                      variational_parameters,
                                                      threshold)
        else:
            if training:
                outputs = vd.nn.matmul_train(x,
                                             variational_parameters,
                                             clip_alpha=clip_alpha)
            else:
                outputs = vd.nn.matmul_eval(x, variational_parameters,
                                            threshold)
    else:
        if sparsity_technique != "l0_regularization":
            raise ValueError(
                "Unsupported sparsity technique {}".format(sparsity_technique))
        log_alpha_initializer = initializers.get(auxiliary_initializer)

        if not log_alpha_initializer:
            # Default log alpha mean of 2.197 corresponds to \alpha = e^{2.197} ~= 9,
            # i.e. \alpha / (\alpha + 1) ~= 0.9.
            log_alpha_initializer = tf.random_normal_initializer(mean=2.197,
                                                                 stddev=0.01,
                                                                 dtype=dtype)

        with tf.variable_scope(vs, auxiliary_name_scope=False) as vs1:
            with tf.name_scope(vs1.original_name_scope):
                log_alpha = tf.get_variable("log_alpha",
                                            shape=[input_shape[-1], units],
                                            initializer=log_alpha_initializer,
                                            dtype=dtype,
                                            trainable=True)

        weight_parameters = (kernel, log_alpha)
        tf.add_to_collection(L0_REGULARIZATION_PARAMETERS, weight_parameters)

        input_rank = x.get_shape().ndims
        if input_rank > 2:
            if training:
                outputs = l0.nn.broadcast_matmul_train(x, weight_parameters)
            else:
                outputs = l0.nn.broadcast_matmul_eval(x, weight_parameters)
        else:
            if training:
                outputs = l0.nn.matmul_train(x, weight_parameters)
            else:
                outputs = l0.nn.matmul_eval(x, weight_parameters)

    # Handle the bias and activation
    if use_bias:
        outputs = tf.nn.bias_add(outputs, bias)
    if activation is not None:
        return activation(outputs)
    return outputs
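A hypothetical call sketch for the wrapper above, showing how the three sparsity paths are selected; `hidden` and the layer names are placeholders introduced for illustration.
hidden = tf.random.uniform((32, 128))

# Magnitude pruning path: builds a masked_fully_connected layer, optionally
# with a sparse initializer when initial_sparsity is given.
mp_out = dense(hidden, 256, activation="relu",
               sparsity_technique="magnitude_pruning", initial_sparsity=0.5,
               name="mp_layer")

# Variational dropout path: creates a log_sigma2 variable next to the kernel.
vd_out = dense(hidden, 256, activation="relu",
               sparsity_technique="variational_dropout", training=True,
               name="vd_layer")

# L0 regularization path: creates a log_alpha variable next to the kernel.
l0_out = dense(hidden, 256, activation="relu",
               sparsity_technique="l0_regularization", training=True,
               name="l0_layer")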
Example #16
def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
    """Multi-headed, multi-layer Transformer from "Attention is All You Need".

  This is almost an exact implementation of the original Transformer encoder.

  See the original paper:
  https://arxiv.org/abs/1706.03762

  Also see:
  https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    do_return_all_layers: Whether to also return all layers or just the final
      layer.

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
    hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError(
            "The width of the input tensor (%d) != hidden size (%d)" %
            (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = reshape_to_matrix(input_tensor)

    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope("layer_%d" % layer_idx):
            layer_input = prev_output

            with tf.variable_scope("attention"):
                attention_heads = []
                with tf.variable_scope("self"):
                    attention_head = attention_layer(
                        from_tensor=layer_input,
                        to_tensor=layer_input,
                        attention_mask=attention_mask,
                        num_attention_heads=num_attention_heads,
                        size_per_head=attention_head_size,
                        attention_probs_dropout_prob=
                        attention_probs_dropout_prob,
                        initializer_range=initializer_range,
                        do_return_2d_tensor=True,
                        batch_size=batch_size,
                        from_seq_length=seq_length,
                        to_seq_length=seq_length)
                    attention_heads.append(attention_head)

                attention_output = None
                if len(attention_heads) == 1:
                    attention_output = attention_heads[0]
                else:
                    # In the case where we have other sequences, we just concatenate
                    # them to the self-attention head before the projection.
                    attention_output = tf.concat(attention_heads, axis=-1)

                # Run a linear projection of `hidden_size` then add a residual
                # with `layer_input`.
                with tf.variable_scope("output"):
                    attention_output = masked_fully_connected(
                        attention_output,
                        hidden_size,
                        weights_initializer=create_initializer(
                            initializer_range))
                    attention_output = dropout(attention_output,
                                               hidden_dropout_prob)
                    attention_output = layer_norm(attention_output +
                                                  layer_input)

            # The activation is only applied to the "intermediate" hidden layer.
            with tf.variable_scope("intermediate"):
                intermediate_output = masked_fully_connected(
                    attention_output,
                    intermediate_size,
                    activation_fn=intermediate_act_fn,
                    weights_initializer=create_initializer(initializer_range))

            # Down-project back to `hidden_size` then add the residual.
            with tf.variable_scope("output"):
                layer_output = masked_fully_connected(
                    intermediate_output,
                    hidden_size,
                    weights_initializer=create_initializer(initializer_range))
                layer_output = dropout(layer_output, hidden_dropout_prob)
                layer_output = layer_norm(layer_output + attention_output)
                prev_output = layer_output
                all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output
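A hypothetical shape check for the encoder above; it assumes the module-level helpers referenced in the body (gelu, get_shape_list, reshape_to_matrix, and so on) are available.
embeddings = tf.random.uniform((2, 16, 128))
encoded = transformer_model(embeddings,
                            hidden_size=128,
                            num_hidden_layers=2,
                            num_attention_heads=4,
                            intermediate_size=512)
# `encoded` has shape [2, 16, 128]; with do_return_all_layers=True a list of
# per-layer tensors of the same shape would be returned instead.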
Example #17
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
    """Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on "Attention
  is all you Need". If `from_tensor` and `to_tensor` are the same, then
  this is self-attention. Each timestep in `from_tensor` attends to the
  corresponding sequence in `to_tensor`, and returns a fixed-width vector.

  This function first projects `from_tensor` into a "query" tensor and
  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
  of tensors of length `num_attention_heads`, where each tensor is of shape
  [batch_size, seq_length, size_per_head].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor and returned.

  In practice, the multi-headed attention is done with transposes and
  reshapes rather than with actual separate tensors.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions in
      the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    size_per_head: int. Size of each attention head.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
      * from_seq_length, num_attention_heads * size_per_head]. If False, the
      output will be of shape [batch_size, from_seq_length, num_attention_heads
      * size_per_head].
    batch_size: (Optional) int. If the input is 2D, this might be the batch size
      of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.

  Returns:
    float Tensor of shape [batch_size, from_seq_length,
      num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
      true, this will be of shape [batch_size * from_seq_length,
      num_attention_heads * size_per_head]).

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """
    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])

        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None
                or to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = masked_fully_connected(
        from_tensor_2d,
        num_attention_heads * size_per_head,
        activation_fn=query_act,
        scope="query",
        weights_initializer=create_initializer(initializer_range))

    # `key_layer` = [B*T, N*H]
    key_layer = masked_fully_connected(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation_fn=key_act,
        scope="key",
        weights_initializer=create_initializer(initializer_range))

    # `value_layer` = [B*T, N*H]
    value_layer = masked_fully_connected(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation_fn=value_act,
        scope="value",
        weights_initializer=create_initializer(initializer_range))

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size,
                                     num_attention_heads, to_seq_length,
                                     size_per_head)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # `attention_mask` = [B, 1, F, T]
        attention_mask = tf.expand_dims(attention_mask, axis=[1])

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
        value_layer,
        [batch_size, to_seq_length, num_attention_heads, size_per_head])

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(context_layer, [
            batch_size * from_seq_length, num_attention_heads * size_per_head
        ])
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
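A hypothetical self-attention call illustrating the output shapes documented above; the input tensor is a placeholder introduced for illustration.
x = tf.random.uniform((2, 7, 64))          # [batch, seq_length, width]
context = attention_layer(x, x,
                          num_attention_heads=4,
                          size_per_head=16)
# `context` has shape [2, 7, 64] (= num_attention_heads * size_per_head);
# with do_return_2d_tensor=True it would be [2 * 7, 64] instead.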
Example #18
def sparse_fully_connected(x,
                           units,
                           activation=None,
                           use_bias=True,
                           kernel_initializer=None,
                           kernel_regularizer=None,
                           bias_initializer=init_ops.zeros_initializer(),
                           biases_regularizer=None,
                           sparsity_technique='baseline',
                           log_sigma2_initializer=None,
                           log_alpha_initializer=None,
                           threshold=3.0,
                           clip_alpha=None,
                           is_training=False,
                           name=None):
    """Constructs sparse_fully_connected with any desired pruning method.

  Args:
    x: Input, float32 tensor.
    units: Int representing size of output tensor.
    activation: If None, a linear activation is used.
    use_bias: Boolean specifying whether bias vector should be used.
    kernel_initializer: Initializer for the convolution weights.
    kernel_regularizer: Regularization method for the convolution weights.
    bias_initializer: Initializer of the bias vector.
    biases_regularizer: Optional regularizer for the bias vector.
    sparsity_technique: Method used to introduce sparsity. ['baseline',
      'threshold', 'variational_dropout', 'l0_regularization']
    log_sigma2_initializer: Specified initializer of the log_sigma2 term used
      in variational dropout.
    log_alpha_initializer: Specified initializer of the log_alpha term used
      in l0 regularization.
    threshold: Threshold for masking variational dropout log alpha at test time.
    clip_alpha: Int that specifies the range for clipping variational dropout
      log alpha values.
    is_training: Boolean specifying whether it is training or eval.
    name: String specifying the name scope of the layer in the network.

  Returns:
    Output: activations.

  Raises:
    ValueError: If the last dimension of the input is undefined or the
      sparsity technique is unsupported.
  """

    layer_variable_getter = variable_getter({
        'bias': 'biases',
        'kernel': 'weights',
    })

    with tf.variable_scope(name,
                           'Dense', [x],
                           custom_getter=layer_variable_getter) as sc:

        input_shape = x.get_shape().as_list()
        if input_shape[-1] is None:
            raise ValueError('The last dimension of the inputs to `Dense` '
                             'should be defined. Found `None`.')

        pruning_methods = ['threshold']

        if sparsity_technique in pruning_methods:
            return layers.masked_fully_connected(
                inputs=x,
                num_outputs=units,
                activation_fn=activation,
                weights_initializer=kernel_initializer,
                weights_regularizer=kernel_regularizer,
                biases_initializer=bias_initializer,
                biases_regularizer=biases_regularizer,
                outputs_collections=None,
                trainable=True,
                scope=sc)

        elif sparsity_technique == 'variational_dropout':
            vd_fc = vd.layers.FullyConnected(
                num_outputs=units,
                activation=activation,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_initializer=bias_initializer,
                bias_regularizer=biases_regularizer,
                log_sigma2_initializer=log_sigma2_initializer,
                is_training=is_training,
                use_bias=use_bias,
                clip_alpha=clip_alpha,
                threshold=threshold,
                trainable=True,
                name=sc)
            return vd_fc.apply(x)
        elif sparsity_technique == 'l0_regularization':
            l0_fc = l0.layers.FullyConnected(
                num_outputs=units,
                activation=activation,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_initializer=bias_initializer,
                bias_regularizer=biases_regularizer,
                log_alpha_initializer=log_alpha_initializer,
                is_training=is_training,
                use_bias=use_bias,
                trainable=True,
                name=sc)
            return l0_fc.apply(x)
        elif sparsity_technique == 'baseline':
            return tf.layers.dense(inputs=x,
                                   units=units,
                                   activation=activation,
                                   use_bias=use_bias,
                                   kernel_initializer=kernel_initializer,
                                   kernel_regularizer=kernel_regularizer,
                                   bias_initializer=bias_initializer,
                                   bias_regularizer=biases_regularizer,
                                   name=name)
        else:
            raise ValueError(
                'Unsupported sparsity technique {}'.format(sparsity_technique))
Example #19
label_dict = {
 0: 'T-shirt/top',
 1: 'Trouser',
 2: 'Pullover',
 3: 'Dress',
 4: 'Coat',
 5: 'Sandal',
 6: 'Shirt',
 7: 'Sneaker',
 8: 'Bag',
 9: 'Ankle boot'
}

# Define Placeholders
images = tf.placeholder(tf.float32, [None, 784])
labels = tf.placeholder(tf.float32, [None, 10])

# Define the model
layer1 = layers.masked_fully_connected(images, 128)
layer2 = layers.masked_fully_connected(layer1, 128)
logits = layers.masked_fully_connected(layer2, len(label_dict))


def loss_fun():
    """
    Loss function, softmax cross entropy is used
    """
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels))
    return loss


def train(loss, global_step, lr=1e-3):
    """
    Training op. Make sure that the same global step is used for pruning!
Example #20
def sparse_fully_connected(x,
                           units,
                           activation=None,
                           use_bias=True,
                           kernel_initializer=None,
                           kernel_regularizer=None,
                           bias_initializer=init_ops.zeros_initializer(),
                           biases_regularizer=None,
                           sparsity_technique='baseline',
                           name=None):
  """Constructs sparse_fully_connected with any desired pruning method.

  Args:
    x: Input, float32 tensor.
    units: Int representing size of output tensor.
    activation: If None, a linear activation is used.
    use_bias: Boolean specifying whether bias vector should be used.
    kernel_initializer: Initializer for the convolution weights.
    kernel_regularizer: Regularization method for the convolution weights.
    bias_initializer: Initializer of the bias vector.
    biases_regularizer: Optional regularizer for the bias vector.
    sparsity_technique: Method used to introduce sparsity. ['baseline',
      'threshold']
    name: String specifying the name scope of the layer in the network.

  Returns:
    Output: activations.

  Raises:
    ValueError: If the last dimension of the input is undefined or the
      sparsity technique is unsupported.
  """

  layer_variable_getter = variable_getter({
      'bias': 'biases',
      'kernel': 'weights',
  })

  with tf.variable_scope(
      name, 'Dense', [x], custom_getter=layer_variable_getter) as sc:

    input_shape = x.get_shape().as_list()
    if input_shape[-1] is None:
      raise ValueError('The last dimension of the inputs to `Dense` '
                       'should be defined. Found `None`.')

    pruning_methods = ['threshold']

    if sparsity_technique in pruning_methods:
      return layers.masked_fully_connected(
          inputs=x,
          num_outputs=units,
          activation_fn=activation,
          weights_initializer=kernel_initializer,
          weights_regularizer=kernel_regularizer,
          biases_initializer=bias_initializer,
          biases_regularizer=biases_regularizer,
          outputs_collections=None,
          trainable=True,
          scope=sc)

    elif sparsity_technique == 'baseline':
      return tf.layers.dense(
          inputs=x,
          units=units,
          activation=activation,
          use_bias=use_bias,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_initializer=bias_initializer,
          bias_regularizer=biases_regularizer,
          name=name)
    else:
      raise ValueError(
          'Unsupported sparsity technique {}'.format(sparsity_technique))
Example #21
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data

epochs = 200
batch_size = 40000
model_path_unpruned = "Model_Saves/Unpruned.ckpt"
model_path_pruned = "Model_Saves/Pruned.ckpt"

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batches = int(len(mnist.train.images) / batch_size)

image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = layers.masked_fully_connected(layer2, 10)

global_step = tf.train.get_or_create_global_step()
reset_global_step_op = tf.assign(global_step, 0)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

pruning_hparams = pruning.get_pruning_hparams()
print("Pruning Hyperparameters:", pruning_hparams)
Example #22
    def model_bulid(self, height, width, channel, classes):
        x = tf.placeholder(dtype=tf.float32,
                           shape=[None, height, width, channel])
        y = tf.placeholder(dtype=tf.float32, shape=[None, classes])

        # conv 1: if the image is Nx465x128x1, apply (conv 5x5 32, pool/2)
        conv1_1 = tf.nn.relu(
            self.conv_layer(x,
                            ksize=[5, 5, channel, 32],
                            stride=[1, 1, 1, 1],
                            padding="SAME",
                            name="conv1_1"))  # Nx465x128x1 ==>   Nx465x128x32
        pool1_1 = self.pool_layer(conv1_1,
                                  ksize=[1, 2, 2, 1],
                                  stride=[1, 2, 2, 1],
                                  name="pool1_1")  # N*232x64x32

        # conv 2,(conv 5x5 32)=>(conv 5x5 64, pool/2)
        conv2_1 = tf.nn.relu(
            self.conv_layer(pool1_1,
                            ksize=[5, 5, 32, 64],
                            stride=[1, 1, 1, 1],
                            padding="SAME",
                            name="conv2_1"))
        pool2_1 = self.pool_layer(conv2_1,
                                  ksize=[1, 2, 2, 1],
                                  stride=[1, 2, 2, 1],
                                  name="pool2_1")  # Nx116x32x128

        # Flatten
        ft = self.flatten(pool2_1)

        # Dense layers (fc 200 => fc 100 => fc classes), built with masked layers for pruning
        fc_layer1 = layers.masked_fully_connected(ft, 200)
        fc_layer2 = layers.masked_fully_connected(fc_layer1, 100)
        prediction = layers.masked_fully_connected(fc_layer2, 10)

        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction,
                                                       labels=y))
        #  original Dense layer
        # fc1 = self.fc_layer(ft,fc_dims=100,name="fc1")
        # finaloutput = self.finlaout_layer(fc1,fc_dims=10,name="final")

        #  pruning op
        global_step = tf.train.get_or_create_global_step()
        reset_global_step_op = tf.assign(global_step, 0)
        # Get, Print, and Edit Pruning Hyperparameters
        pruning_hparams = pruning.get_pruning_hparams()
        print("Pruning Hyper parameters:", pruning_hparams)
        # Change hyperparameters to meet our needs
        pruning_hparams.begin_pruning_step = 0
        pruning_hparams.end_pruning_step = 250
        pruning_hparams.pruning_frequency = 1
        pruning_hparams.sparsity_function_end_step = 250
        pruning_hparams.target_sparsity = .9
        # Create a pruning object using the pruning specification, sparsity seems to have priority over the hparam
        p = pruning.Pruning(pruning_hparams, global_step=global_step)
        prune_op = p.conditional_mask_update_op()

        # optimize
        LEARNING_RATE_BASE = 0.001
        LEARNING_RATE_DECAY = 0.9
        LEARNING_RATE_STEP = 300
        global_steps = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                                   global_steps,
                                                   LEARNING_RATE_STEP,
                                                   LEARNING_RATE_DECAY,
                                                   staircase=True)
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            optimize = tf.train.AdamOptimizer(
                learning_rate=learning_rate).minimize(loss, global_step)

        # prediction
        prediction_label = prediction
        correct_prediction = tf.equal(tf.argmax(prediction_label, 1),
                                      tf.argmax(y, 1))
        accurary = tf.reduce_mean(tf.cast(correct_prediction,
                                          dtype=tf.float32))
        correct_times_in_batch = tf.reduce_mean(
            tf.cast(correct_prediction, dtype=tf.int32))

        return dict(x=x,
                    y=y,
                    optimize=optimize,
                    correct_prediction=prediction_label,
                    correct_times_in_batch=correct_times_in_batch,
                    cost=loss,
                    accurary=accurary,
                    prune_op=prune_op)
Example #23
    def __init__(self,
                 config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 use_one_hot_embeddings=False,
                 scope=None):
        """Constructor for BertModel.

    Args:
      config: `BertConfig` instance.
      is_training: bool. true for training model, false for eval model. Controls
        whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        input_shape = get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)

        if token_type_ids is None:
            token_type_ids = tf.zeros(shape=[batch_size, seq_length],
                                      dtype=tf.int32)

        with tf.variable_scope(scope, default_name="bert"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                (self.embedding_output,
                 self.embedding_table) = embedding_lookup(
                     input_ids=input_ids,
                     vocab_size=config.vocab_size,
                     embedding_size=config.hidden_size,
                     initializer_range=config.initializer_range,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=use_one_hot_embeddings)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output = embedding_postprocessor(
                    input_tensor=self.embedding_output,
                    use_token_type=True,
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob)

            with tf.variable_scope("encoder"):
                # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
                # mask of shape [batch_size, seq_length, seq_length] which is used
                # for the attention scores.
                attention_mask = create_attention_mask_from_input_mask(
                    input_ids, input_mask)

                # Run the stacked transformer.
                # `sequence_output` shape = [batch_size, seq_length, hidden_size].
                self.all_encoder_layers = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=attention_mask,
                    hidden_size=config.hidden_size,
                    num_hidden_layers=config.num_hidden_layers,
                    num_attention_heads=config.num_attention_heads,
                    intermediate_size=config.intermediate_size,
                    intermediate_act_fn=get_activation(config.hidden_act),
                    hidden_dropout_prob=config.hidden_dropout_prob,
                    attention_probs_dropout_prob=config.
                    attention_probs_dropout_prob,
                    initializer_range=config.initializer_range,
                    do_return_all_layers=True)

            self.sequence_output = self.all_encoder_layers[-1]
            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_size, seq_length, hidden_size] to a tensor of shape
            # [batch_size, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained
                first_token_tensor = tf.squeeze(self.sequence_output[:,
                                                                     0:1, :],
                                                axis=1)
                self.pooled_output = masked_fully_connected(
                    first_token_tensor,
                    config.hidden_size,
                    activation_fn=tf.tanh,
                    weights_initializer=create_initializer(
                        config.initializer_range))