def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
    """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name + "/layer_norm"):
        scale = mtf.get_variable(x.mesh,
                                 "layer_norm_scale",
                                 mtf.TensorShape([dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "layer_norm_bias",
                                mtf.TensorShape([dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)
        reduced_shape = x.shape - dim
        mean = mtf.reduce_mean(x, output_shape=reduced_shape)
        variance = mtf.reduce_mean(mtf.square(x - mean),
                                   output_shape=reduced_shape)
        norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
        return norm_x * scale + bias
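A minimal usage sketch for layer_norm, assuming the TF1-era imports these examples rely on (tensorflow as tf, the Mesh TensorFlow module as mtf); the mesh and dimension names are illustrative:

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
batch_dim = mtf.Dimension("batch", 8)
model_dim = mtf.Dimension("model", 512)
x = mtf.import_tf_tensor(mesh,
                         tf.random_normal([8, 512]),
                         mtf.Shape([batch_dim, model_dim]))
y = layer_norm(x, model_dim)  # y has the same shape as x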
Example #2
 def _embedding_and_softmax_vars(self, mesh):
   hparams = self._hparams
   targets_embedding_var = mtf.get_variable(
       mesh, "targets_embedding",
       mtf.Shape([self.targets_vocab_dim, self.model_dim]),
       initializer=tf.random_normal_initializer(),
       activation_dtype=self.activation_dtype)
   if self.has_input:
     if hparams.shared_embedding:
       inputs_embedding_var = targets_embedding_var
     else:
       inputs_embedding_var = mtf.get_variable(
           mesh, "inputs_embedding",
           mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
           initializer=tf.random_normal_initializer(),
           activation_dtype=self.activation_dtype)
   else:
     inputs_embedding_var = None
   if hparams.shared_embedding_and_softmax_weights:
     softmax_var = targets_embedding_var * (self.model_dim.size ** -0.5)
   else:
     softmax_var = mtf.get_variable(
         mesh,
         "softmax",
         mtf.Shape([self.targets_vocab_dim, self.model_dim]),
         initializer=tf.random_normal_initializer(
             stddev=self.model_dim.size**-0.5),
         activation_dtype=self.activation_dtype)
   positional_embedding_var = mtf.get_variable(
       mesh, "positional_embedding",
       mtf.Shape([self.max_length_dim, self.model_dim]),
       initializer=tf.random_normal_initializer(),
       activation_dtype=self.activation_dtype)
   return (inputs_embedding_var, targets_embedding_var,
           softmax_var, positional_embedding_var)
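A hedged sketch of how the returned variables are typically consumed downstream (target_ids, decoder_out, batch_dim and length_dim are hypothetical names, not part of the example): the embedding variable is indexed with a gather, and the softmax variable contracts the model dimension back into vocabulary logits.

# target_ids: [batch, length] int32; decoder_out: [batch, length, d_model]
emb = mtf.gather(targets_embedding_var, target_ids, self.targets_vocab_dim)
logits = mtf.einsum(
    [decoder_out, softmax_var],
    mtf.Shape([batch_dim, length_dim, self.targets_vocab_dim]))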
Example #3
    def create_positional_emb_2d(self, targets, max_length_dim, model_dim):
        """Learned 2d positional embedding for images."""
        mesh = targets.mesh
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        rows_dim = mtf.Dimension("rows", hparams.img_len)
        cols_dim = mtf.Dimension("cols",
                                 hparams.img_len * hparams.num_channels)

        positional_emb_rows_var = mtf.get_variable(
            mesh,
            "positional_emb_rows",
            mtf.Shape([max_length_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)
        positional_emb_cols_var = mtf.get_variable(
            mesh,
            "positional_emb_cols",
            mtf.Shape([max_length_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        targets_position_x = mtf.range(mesh, rows_dim, dtype=tf.int32)
        targets_position_y = mtf.range(mesh, cols_dim, dtype=tf.int32)
        position_x = mtf.broadcast(
            mtf.gather(positional_emb_rows_var, targets_position_x,
                       max_length_dim),
            mtf.Shape([rows_dim, cols_dim, model_dim]))

        position_y = mtf.broadcast(
            mtf.gather(positional_emb_cols_var, targets_position_y,
                       max_length_dim),
            mtf.Shape([rows_dim, cols_dim, model_dim]))
        return position_x + position_y
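The two gathers above produce [rows, model] and [cols, model] embeddings; broadcasting each to [rows, cols, model] and summing gives every (row, col) position a distinct embedding. A plain-NumPy analogue of that broadcast-and-add, purely illustrative:

import numpy as np
rows, cols, model = 4, 6, 8
emb_rows = np.random.randn(rows, model)
emb_cols = np.random.randn(cols, model)
# [rows, 1, model] + [1, cols, model] -> [rows, cols, model]
pos = emb_rows[:, None, :] + emb_cols[None, :, :]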
Example #4
 def testGraph(self):
     graph = mtf.Graph()
     self.assertLen(graph.operations, 0)
     self.assertLen(graph.tensors, 0)
     self.assertLen(graph.trainable_variables, 0)
     self.assertLen(graph.all_variables, 0)
     mesh = mtf.Mesh(graph, "mesh_test")
     _ = mtf.import_tf_tensor(mesh,
                              tf_tensor=tf.constant(0.),
                              shape=mtf.Shape([]))
     self.assertLen(graph.operations, 1)
     self.assertLen(graph.tensors, 1)
     self.assertLen(graph.trainable_variables, 0)
     self.assertLen(graph.all_variables, 0)
     _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
     self.assertLen(graph.operations, 2)
     self.assertLen(graph.tensors, 2)
     self.assertLen(graph.trainable_variables, 1)
     self.assertLen(graph.all_variables, 1)
     _ = mtf.get_variable(mesh,
                          "variable_1",
                          mtf.Shape([]),
                          trainable=False)
     self.assertLen(graph.operations, 3)
     self.assertLen(graph.tensors, 3)
     self.assertLen(graph.trainable_variables, 1)
     self.assertLen(graph.all_variables, 2)
Example #5
    def create_positional_emb_2d(self, targets):
        """Learned 2d positional embedding for images."""
        mesh = targets.mesh

        positional_emb_rows_var = mtf.get_variable(
            mesh,
            "positional_emb_rows",
            mtf.Shape([self.max_length_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=self.activation_type)
        positional_emb_cols_var = mtf.get_variable(
            mesh,
            "positional_emb_cols",
            mtf.Shape([self.max_length_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=self.activation_type)

        targets_position_x = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
        targets_position_y = mtf.range(mesh, self.cols_dim, dtype=tf.int32)
        position_x = mtf.broadcast(
            mtf.gather(positional_emb_rows_var, targets_position_x,
                       self.max_length_dim),
            mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))

        position_y = mtf.broadcast(
            mtf.gather(positional_emb_cols_var, targets_position_y,
                       self.max_length_dim),
            mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
        return position_x + position_y
Example #6
def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
    """Batch normalization.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name, default_name="batch_norm", values=[x]):
        batch_dim = x.shape.dims[0]
        reduced_shape = x.shape - batch_dim
        scale = mtf.get_variable(x.mesh,
                                 "batch_norm_scale",
                                 mtf.Shape([batch_dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "batch_norm_bias",
                                mtf.Shape([batch_dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)

        moving_mean = mtf.get_variable(
            x.mesh,
            "moving_mean",
            reduced_shape,
            initializer=tf.random_normal_initializer(stddev=1.0),
            activation_dtype=x.dtype,
            trainable=False)
        moving_variance = mtf.get_variable(x.mesh,
                                           "moving_variance",
                                           reduced_shape,
                                           initializer=tf.ones_initializer(),
                                           activation_dtype=x.dtype,
                                           trainable=False)

        # At training time, calculate mean and variance and normalize across batch
        # dim.
        if is_training:
            mean = mtf.reduce_mean(x, output_shape=reduced_shape)
            variance = mtf.reduce_mean(mtf.square(x - mean),
                                       output_shape=reduced_shape)
            norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

            # Update running mean and running variance.
            moving_mean = mtf.assign(
                moving_mean, momentum * moving_mean + (1 - momentum) * mean)
            moving_variance = mtf.assign(
                moving_variance,
                momentum * moving_variance + (1 - momentum) * variance)
        else:
            # At eval and test time, use the running mean and variance.
            norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)
        return norm_x * scale + bias
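A usage sketch, under the same illustrative graph/mesh assumptions as the layer_norm sketch above; is_training selects between batch statistics (which also update the moving averages) and the stored moving averages:

h = mtf.import_tf_tensor(mesh,
                         tf.random_normal([8, 512]),
                         mtf.Shape([batch_dim, model_dim]))
h_train = batch_norm(h, is_training=True, momentum=0.999, name="bn_train")
h_eval = batch_norm(h, is_training=False, momentum=0.999, name="bn_eval")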
Example #7
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a tf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image, [-1, 28, 28]),
                             mtf.Shape([batch_dim, rows_dim, cols_dim]))
    x = mtf.reshape(x, [batch_dim, rows_dim, cols_dim, one_channel_dim])

    # add some convolutional layers to demonstrate that convolution works.
    # TODO(noam): get spatially-partitioned convolution working.
    fh_dim = mtf.Dimension("fh", 3)
    fw_dim = mtf.Dimension("fw", 3)
    filters1_dim = mtf.Dimension("filters1", 32)
    filters2_dim = mtf.Dimension("filters2", 32)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    f1 = mtf.relu(mtf.conv2d(x, kernel1))
    f2 = mtf.relu(mtf.conv2d(f1, kernel2))
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    h1 = mtf_layers.dense(x,
                          hidden_dim1,
                          reduced_dims=[rows_dim, cols_dim],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf_layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf_layers.dense(h2, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch_dim]))
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Example #8
def dense(x,
          output_dim,
          reduced_dims=None,
          expert_dims=None,
          use_bias=True,
          activation=None,
          name=None):
    """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
    if expert_dims is None:
        expert_dims = []
    if reduced_dims is None:
        reduced_dims = x.shape.dims[-1:]
    w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
    output_shape = mtf.Shape(
        [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])

    with tf.variable_scope(name, default_name="dense"):
        stddev = mtf.list_product(d.size for d in reduced_dims)**-0.5
        w = mtf.get_variable(
            x.mesh,
            "kernel",
            w_shape,
            initializer=tf.random_normal_initializer(stddev=stddev),
            activation_dtype=x.dtype)
        y = mtf.einsum([x, w], output_shape)
        if use_bias:
            b = mtf.get_variable(x.mesh,
                                 "bias",
                                 mtf.Shape(expert_dims + [output_dim]),
                                 initializer=tf.zeros_initializer(),
                                 activation_dtype=x.dtype)
            y += b
        if activation is not None:
            y = activation(y)
        return y
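A short sketch of calling dense as defined above (dimension names illustrative, same assumed imports). By default only the last dimension is reduced; reduced_dims lets the caller contract several dimensions at once, as the mnist example above does with [rows_dim, cols_dim].

length_dim = mtf.Dimension("length", 128)
hidden_dim = mtf.Dimension("hidden", 1024)
t = mtf.import_tf_tensor(mesh,
                         tf.random_normal([8, 128, 512]),
                         mtf.Shape([batch_dim, length_dim, model_dim]))
u = dense(t, hidden_dim, activation=mtf.relu, name="proj")
# u: [batch, length, hidden]; the "model" dimension was reduced away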
Example #9
def multihead_attention_vars(
    mesh, heads, io_channels, kv_channels, activation_dtype):
  """Create Parameters for Multihead Attention.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)
  qk_stddev = (io_channels.size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_channels.size ** -0.5
  o_stddev = (io_channels.size * heads.size) ** -0.5
  def qkvo_initializer(shape,
                       dtype=None,
                       partition_info=None,
                       verify_shape=None):
    del partition_info, verify_shape
    return tf.random_normal(shape, dtype=dtype) * tf.reshape(
        [qk_stddev, qk_stddev, v_stddev, o_stddev], [4, 1, 1, 1])
  var = mtf.get_variable(
      mesh, "qkvo", mtf.Shape([qkvo, heads, io_channels, kv_channels]),
      initializer=qkvo_initializer, activation_dtype=activation_dtype)
  q_var, k_var, v_var, o_var = mtf.unstack(var, qkvo)
  return q_var, k_var, v_var, o_var
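A sketch of calling it (dimension sizes illustrative): the four projection matrices come back unstacked from the single fused "qkvo" variable, each with shape [heads, io_channels, kv_channels].

heads = mtf.Dimension("heads", 8)
io_channels = mtf.Dimension("d_model", 512)
kv_channels = mtf.Dimension("d_kv", 64)
q_var, k_var, v_var, o_var = multihead_attention_vars(
    mesh, heads, io_channels, kv_channels, activation_dtype=tf.float32)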
Example #10
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    wi, wo = mtf.unstack(w, io)
    h = mtf.relu(mtf.einsum([x, wi]))
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return mtf.einsum([h, wo])
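A usage sketch, reusing t and length_dim from the dense sketch above: broadcasting the dropout noise across the length dimension reuses one dropout mask for every position, which is what dropout_broadcast_dims is for.

d_ff = mtf.Dimension("d_ff", 2048)
out = dense_relu_dense(t, d_ff, dropout=0.1,
                       dropout_broadcast_dims=[length_dim])
# out has the same shape as t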
Example #11
    def test_variable_placer(self):
        sizes = [100, 0, 0, 0]
        device_list = ['cpu:0', 'cpu:1', 'cpu:2', 'cpu:3']

        with tf.Graph().as_default() as g:
            var_placer = mtf_utils.BalancedVariablePlacer(device_list, sizes)
            graph = mtf.Graph()
            mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

            hidden_dim = mtf.Dimension('hidden', 10)
            output_dim = mtf.Dimension('output_feature', 10)

            for i in range(5):
                # Each variable takes 400 Bytes, and will be placed from cpu:1.
                mtf.get_variable(mesh, 'w{}'.format(i),
                                 [hidden_dim, output_dim])

            for i in range(5):
                var = g.get_tensor_by_name('w{}:0'.format(i))
                device = (i + 1) % len(device_list)
                self.assertEqual('cpu:{}'.format(device), var.device)
Example #12
  def _decoder_layer_stack_incremental(self,
                                       x,
                                       step_num,
                                       encdec_tensors,
                                       self_attention_k,
                                       self_attention_v,
                                       encdec_attention_mask=None):
    """Decoder layer stack during inference.

    We are processing only one position at a time.

    The self-attention keys and values have already been computed for
    previous positions.  In addition to the decoder output, we need to
    produce the updated self-attention keys and values.

    If there is an encoder, then additional Tensors are supplied in
    encdec_tensors, which give us the keys and values for encoder-decoder
    attention as well as the weight matrices q_var and o_var.

    Args:
      x: a mtf.Tensor with shape [batch_dim, model_dim]
      step_num: an mtf integer Scalar
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v)
      self_attention_k: an optional list of num_layers Tensors each with shape
        [batch, heads, memory_length, kv_channels]
      self_attention_v: an optional list of num_layers Tensors each with shape
        [batch, heads, memory_length, kv_channels]
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.

    Returns:
      y: a mtf.Tensor with shape [batch_dim, model_dim]
      new_self_attention_k: a list of num_layers mtf.Tensors, with the same
        shapes as the elements of self_attention_k
      new_self_attention_v: a list of num_layers mtf.Tensors, with the same
        shapes as the elements of self_attention_v

    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    num_layers = hparams.num_decoder_layers
    num_layer_norms = num_layers * (2 if encdec_tensors is None else 3) + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    new_self_attention_k = []
    new_self_attention_v = []
    for layer in range(num_layers):
      with tf.variable_scope("layer_%d" % layer):
        # Self attention layer
        y, new_k, new_v = mtf_layers.multihead_self_attention_incremental(
            normalize(x),
            prev_k=self_attention_k[layer],
            prev_v=self_attention_v[layer],
            step_num=step_num,
            name="self_attention")
        new_self_attention_k.append(new_k)
        new_self_attention_v.append(new_v)
        x += y
        if encdec_tensors is not None:
          # Encoder-Decoder attention layer
          q_var, o_var, k, v = encdec_tensors[layer]
          x += mtf_layers.multihead_encdec_attention_incremental(
              normalize(x),
              q_var, o_var, k, v,
              encdec_attention_mask,
              name="encdec_attention")
        # ffn layer
        x += self._feedforward_layer(normalize(x), hparams)
    x = normalize(x)
    assert not layer_norm_vars
    return x, new_self_attention_k, new_self_attention_v
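The per-step cache update performed inside multihead_self_attention_incremental can be pictured with a toy NumPy analogue (illustrative only, not the mtf API): the current step's key/value is scattered into the cache at step_num via a one-hot mask, and attention then runs over all cached positions.

import numpy as np
def cache_write(cache_k, cache_v, new_k, new_v, step_num):
    # cache_*: [memory_length, kv_channels]; new_*: [kv_channels]
    mask = np.eye(cache_k.shape[0])[step_num][:, None]  # one-hot column
    return (cache_k * (1 - mask) + new_k * mask,
            cache_v * (1 - mask) + new_v * mask)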
Example #13
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis
    col_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis

  Returns:
    The output `Tensor` of the block.
  """
    shortcut = inputs

    filter_h_dim = mtf.Dimension("filter_height", 3)
    filter_w_dim = mtf.Dimension("filter_width", 3)
    one_h_dim = mtf.Dimension("filter_height", 1)
    one_w_dim = mtf.Dimension("filter_width", 1)

    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        kernel = mtf.get_variable(
            inputs.mesh, "kernel",
            mtf.Shape(
                [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
        shortcut = projection_shortcut(inputs, kernel)

    # First conv block
    filters1_dim = mtf.Dimension("filters1", filters)
    kernel1 = mtf.get_variable(
        inputs.mesh, "kernel1",
        mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel1,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block
    filters2_dim = mtf.Dimension("filters2", filters)
    kernel2 = mtf.get_variable(
        inputs.mesh, "kernel2",
        mtf.Shape([filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel2,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=row_blocks_dim,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training)

    # Third wide conv filter block
    filters3_dim = mtf.Dimension("filters3", filters)
    filters3_kernel = mtf.get_variable(
        inputs.mesh, "wide_kernel",
        mtf.Shape([one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    filters3_kernel,
                                    strides,
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training, relu=False)

    # TODO(nikip): Maybe add residual with a projection?
    return mtf.relu(inputs + mtf.rename_dimension(
        shortcut, shortcut.shape.dims[-1].name, inputs.shape.dims[-1].name))
Example #14
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        filter_h_dim = mtf.Dimension("filter_height", 7)
        filter_w_dim = mtf.Dimension("filter_width", 7)
        filters = mtf.Dimension("filters", hparams.filter_sizes[0])
        rows_dim = mtf.Dimension("rows_size", 32)
        cols_dim = mtf.Dimension("cols_size", 96)
        row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
        col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
        classes_dim = mtf.Dimension("classes", 10)
        one_channel_dim = mtf.Dimension("one_channel", 1)

        inputs = features["inputs"]
        x = mtf.import_tf_tensor(
            mesh,
            tf.reshape(inputs, [
                hparams.batch_size, hparams.row_blocks,
                hparams.rows_size // hparams.row_blocks, hparams.col_blocks,
                hparams.num_channels * hparams.cols_size // hparams.col_blocks,
                1
            ]),
            mtf.Shape([
                batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
                one_channel_dim
            ]))
        x = mtf.transpose(x, [
            batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
            one_channel_dim
        ])

        x = mtf.to_float(x)
        initial_filters = mtf.get_variable(
            mesh, "init_filters",
            mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
        x = mtf.conv2d_with_blocks(x,
                                   initial_filters,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=None,
                                   w_blocks_dim=col_blocks_dim)

        x = batch_norm_relu(x, is_training)

        # Conv blocks
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_layers):
            layer_name = "block_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Residual block layer
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[0],
                                blocks=hparams.layer_sizes[0],
                                strides=[1, 1, 1, 1],
                                is_training=is_training,
                                name="block_layer1",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[1],
                                blocks=hparams.layer_sizes[1],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer2",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[2],
                                blocks=hparams.layer_sizes[2],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer3",
                                row_blocks_dim=None,
                                col_blocks_dim=None)

        # Calculate the logits and loss.
        out = x
        outputs = mtf_layers.dense(out,
                                   hidden_dim,
                                   reduced_dims=out.shape.dims[-5:],
                                   activation=mtf.relu,
                                   name="dense")

        # We assume fixed vocab size for targets
        labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [hparams.batch_size]),
                                      mtf.Shape([batch_dim]))

        logits = mtf_layers.dense(outputs, classes_dim, name="logits")
        soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, classes_dim)

        # Reshape logits so it doesn't break inside t2t.
        logits = mtf.reshape(
            logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
        loss = mtf.reduce_mean(loss)
        return logits, loss
Example #15
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        # We assume fixed vocab size for targets
        targets_vocab_size = self._problem_hparams.target_modality._vocab_size  # pylint: disable=protected-access
        targets = tf.to_int32(features["targets"])

        # Image preprocessing, reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        model_dim = mtf.Dimension("d_model", hparams.hidden_size)
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        length_dim = mtf.Dimension("length", length)
        max_length_dim = mtf.Dimension("max_length", hparams.max_length)
        filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
        kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
        heads = mtf.Dimension("heads", hparams.num_heads)

        def import_to_batch_by_length(x, name):
            return mtf.import_tf_tensor(mesh,
                                        x,
                                        mtf.Shape([batch_dim, length_dim]),
                                        name=name)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(x,
                               keep_prob=1.0 -
                               hparams.layer_prepostprocess_dropout,
                               noise_shape=mtf.Shape([batch_dim, model_dim]))

        targets = import_to_batch_by_length(targets, "targets")
        shifted_targets = import_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        extra_losses = []

        # Create targets content and position embeddings.
        targets_vocab_size = 256 * hparams.num_channels
        targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
        outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([targets_vocab_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        x = mtf.gather(targets_embedding_var, shifted_targets,
                       targets_vocab_dim)
        # Add positional embeddings
        x += mtf.reshape(
            self.create_positional_emb_2d(targets, max_length_dim, model_dim),
            [length_dim, model_dim])

        # If conditional and input is given, add the input embedding to the target.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            vocab_size = hparams.num_classes
            inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = import_to_batch_by_length(inputs, "inputs")

            # Input embeddings
            inputs_embedding_var = mtf_layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([inputs_vocab_dim, model_dim]),
                activation_dtype=activation_dtype)
            inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                                    inputs_vocab_dim)
            x += inputs_emb

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_decoder_layers):
            layer_name = "decoder_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Self attention layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.masked_local_attention_1d(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_self_att"),
                        None,
                        kv_channels,
                        heads,
                        block_length=hparams.block_length,
                        name="self_att"))
                # ffn layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.dense_relu_dense(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_ffn"),
                        filter_dim,
                        hparams.dropout,
                        dropout_broadcast_dims=[length_dim]))

        x = mtf_layers.layer_norm(x,
                                  model_dim,
                                  name="decoder_final_layer_norm")

        # Calculate the logits and loss.
        logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
        soft_targets = mtf.one_hot(targets,
                                   outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, outputs_vocab_dim)

        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l
        return logits, loss
Example #16
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a tf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # add some convolutional layers to demonstrate that convolution works.
    fh_dim = mtf.Dimension("fh", 9)
    fw_dim = mtf.Dimension("fw", 9)
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    f1 = mtf.relu(
        mtf.conv2d_with_blocks(x,
                               kernel1,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))
    f2 = mtf.relu(
        mtf.conv2d_with_blocks(f1,
                               kernel2,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    h1 = mtf_layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf_layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf_layers.dense(h2, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Example #17
    def apply_grad(self, grad, var):
        # create slots
        factored_dims = self._factored_dims(var.shape)
        if factored_dims:
            d0, d1 = factored_dims
            vr_shape = var.shape - d0
            vc_shape = var.shape - d1
            vr = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vr",
                                  vr_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
            vc = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vc",
                                  vc_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
        else:
            v = mtf.get_variable(var.mesh,
                                 var.name + "_slot_v",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
        if self._beta1:
            m = mtf.get_variable(var.mesh,
                                 var.name + "_slot_m",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)

        with tf.variable_scope(var.name + "/adafactor"):
            grad_squared = mtf.square(grad) + self._epsilon1
            decay_rate = self._decay_rate
            old_val = var.value
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(
                    old_val) * self._learning_rate
            else:
                update_scale = self._learning_rate
            mixing_rate = 1.0 - decay_rate
            updates = []
            if factored_dims:
                grad_squared_row_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vr_shape)
                grad_squared_col_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vc_shape)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                vr_update = mtf.assign(vr, new_vr)
                vc_update = mtf.assign(vc, new_vc)
                updates.extend([vr_update, vc_update])
                long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
                r_factor = mtf.rsqrt(new_vr / long_term_mean)
                c_factor = mtf.rsqrt(new_vc)
                x = grad * r_factor * c_factor
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                v_update = mtf.assign(v, new_v)
                updates.append(v_update)
                x = grad * mtf.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = mtf.maximum(
                    1.0,
                    reduce_rms(x) / self._clipping_threshold)
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = self._beta1 * m.value + (1.0 -
                                                 self._beta1) * subtrahend
                subtrahend = new_m
                updates.append(mtf.assign(m, new_m))
            new_val = old_val - subtrahend
            var_update = mtf.assign(var, new_val)
            updates.append(var_update)
            return updates
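In the factored branch above, vr holds the squared gradients averaged over d0 and vc the average over d1; r_factor and c_factor rebuild the rank-1 second-moment estimate outer(row_mean, col_mean) / overall_mean (the long-term mean of vr is just the overall mean of the squared gradients). A plain-NumPy check of that identity, illustrative and independent of the optimizer code:

import numpy as np
V = np.square(np.random.randn(3, 4))   # stand-in for the squared gradients
approx = np.outer(V.mean(axis=1), V.mean(axis=0)) / V.mean()
factor = 1.0 / np.sqrt(approx)         # plays the role of r_factor * c_factor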
Example #18
  def _layer_stack(self,
                   x,
                   num_layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None):
    """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [batch_dim, length_dim, model_dim]
      num_layers: an integer
      encoder_output: an optional mtf.Tensor with shape
        [batch_dim, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: a list to be appended-to
    Returns:
      a mtf.Tensor with shape [batch_dim, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))
    num_layer_norms = num_layers * (2 if encoder_output is None else 3) + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    for layer in range(num_layers):
      with tf.variable_scope("layer_%d" % layer):
        # Self attention layer
        x += layer_prepostprocess_dropout(
            mtf_layers.multihead_attention(
                normalize(x), None,
                self_attention_mask, self.kv_dim, self.heads_dim,
                dropout=hparams.attention_dropout,
                dropout_broadcast_dims=[self.length_dim],
                name="self_attention"))
        if encoder_output is not None:
          # Encoder-Decoder attention layer
          x += layer_prepostprocess_dropout(
              mtf_layers.multihead_attention(
                  normalize(x), encoder_output,
                  encdec_attention_mask, self.kv_dim, self.heads_dim,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  name="encdec_attention"))
        # ffn layer
        x += layer_prepostprocess_dropout(
            self._feedforward_layer(normalize(x), losses=losses))
    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars
    return x
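Each call to normalize pops one scale vector, so the single fused layer_norm_scale variable supplies a distinct learned scale per normalization site; the final assert checks that exactly num_layer_norms sites consumed them. The normalization itself is rms-style (no mean subtraction); a minimal NumPy sketch of that computation, for reference:

import numpy as np
def rms_normalize(x, scale, epsilon=1e-6):
    # root-mean-square normalization over the trailing (model) axis
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return x * scale / np.sqrt(variance + epsilon)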