def benchmark_model(mesh):
    """Round-trip a random 3D volume through a forward and an inverse FFT.

    A batch of uniformly distributed random cubes is created on ``mesh``,
    transformed with ``mpm.fft3d`` onto the ("tnx", "tny", "tnz") dimensions
    and brought back with ``mpm.ifft3d`` onto ("nx", "ny", "nz").

    Args:
      mesh: the mtf.Mesh the tensors are laid out on.

    Returns:
      A scalar mtf.Tensor: the maximum absolute difference between the
      original field and its float32 reconstruction.
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    size = FLAGS.cube_size
    real_dims = [mtf.Dimension(label, size) for label in ("nx", "ny", "nz")]
    fourier_dims = [mtf.Dimension(label, size) for label in ("tnx", "tny", "tnz")]

    noise = mtf.random_uniform(mesh, [batch_dim] + real_dims)
    spectrum = mpm.fft3d(mtf.cast(noise, tf.complex64), fourier_dims)
    recovered = mtf.cast(mpm.ifft3d(spectrum, real_dims), tf.float32)

    return mtf.reduce_max(mtf.abs(noise - recovered))
Example #2
0
def benchmark_model(mesh):
  """Repeatedly round-trip a random 3D volume through forward/inverse FFTs.

  A batch of normally distributed random cubes is created on ``mesh`` and
  passed FLAGS.n_ffts times through ``mpm.fft3d`` followed by ``mpm.ifft3d``,
  all within the same graph.

  Args:
    mesh: the mtf.Mesh the tensors are laid out on.

  Returns:
    A scalar mtf.Tensor: the sum over all iterations of the maximum absolute
    reconstruction error after each round trip, plus the error of the final
    reconstruction added once more.
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  size = FLAGS.cube_size
  real_dims = [mtf.Dimension(label, size) for label in ("nx", "ny", "nz")]
  fourier_dims = [mtf.Dimension(label, size) for label in ("tnx", "tny", "tnz")]

  reference = mtf.random_normal(mesh, [batch_dim] + real_dims)
  current = mtf.cast(reference, tf.complex64)
  total_err = 0
  for _ in range(FLAGS.n_ffts):
    spectrum = mpm.fft3d(current, fourier_dims)
    # NOTE(review): "* 1" kept as-is; presumably it forces an extra op between
    # the two transforms — confirm intent.
    current = mpm.ifft3d(spectrum * 1, real_dims)
    step_err = mtf.abs(mtf.cast(current, tf.float32) - reference)
    total_err += mtf.reduce_max(step_err)

  final = mtf.cast(current, tf.float32)
  total_err += mtf.reduce_max(mtf.abs(final - reference))
  return total_err
Example #3
0
  def compute_mask(self, context, memory_position):
    """Build the additive attention bias for this layer.

    A -1e9 penalty is produced for memory positions whose offset from
    ``context.position`` falls outside [min_relative_position,
    max_relative_position] (each bound optional), and for pairs of positions
    whose sequence ids differ when ``context.sequence_id`` is a Tensor over
    the length dimension.

    Args:
      context: a transformer.Context
      memory_position: an int32 tensor containing memory_length dimension.
    Returns:
      a Tensor (the sum of all penalties) or None when nothing is masked
    """
    def penalty(illegal):
      return mtf.cast(illegal, context.activation_dtype) * -1e9

    biases = []
    lo = self.min_relative_position(context)
    hi = self.max_relative_position(context)
    if lo is not None or hi is not None:
      offset = memory_position - context.position
      if lo is not None:
        biases.append(penalty(mtf.less(offset, lo)))
      if hi is not None:
        biases.append(penalty(mtf.greater(offset, hi)))

    seq_id = context.sequence_id
    if (seq_id is not None and isinstance(seq_id, mtf.Tensor)
        and context.length_dim in seq_id.shape):
      renamed = self.rename_length_to_memory_length(seq_id, context)
      biases.append(penalty(mtf.not_equal(seq_id, renamed)))

    if not biases:
      return None
    return mtf.add_n(biases)
Example #4
0
 def get_project_to_cluster_length(self, cluster_mask, dtype):
     """Returns projection from length dim to the shorter cluster length dim.

     Each in-cluster token at sequence position p is mapped to slot
     ``(number of cluster tokens up to and including p) - 1`` of the cluster
     length dimension; rows of tokens outside the cluster are zeroed.
     """
     length_dim = cluster_mask.shape.get_dim_by_name("length")
     cluster_len_dim = self.get_cluster_length_dim(length_dim)
     keep = mtf.cast(cluster_mask, dtype)
     # Running count of cluster tokens, shifted to be 0-based.
     slot = mtf.cumsum(mtf.cast(cluster_mask, tf.int32), length_dim) - 1
     projection = mtf.one_hot(slot, output_dim=cluster_len_dim, dtype=dtype)
     return keep * projection
Example #5
0
def deep(x, mask, float16=None):
    """Mask the input and run it through the "deep" MLP tower.

    ``x`` and ``mask`` are combined with an einsum whose output keeps
    ``x``'s dims. The result is always cast to float16 and fed through five
    dense layers with float16 variables. The first four layers have 1024,
    512, 256 and 128 units and use ReLU; the last is a linear layer with one
    unit. The first layer contracts the last two dims of its input.

    Args:
      x: input mtf.Tensor.
      mask: mtf.Tensor combined with ``x`` by the einsum above.
      float16: when truthy the output is left in float16; otherwise it is
        cast back to float32.

    Returns:
      mtf.Tensor whose last dimension is ``dense_dim4`` (size 1).
    """
    def log_tensor(t):
        logger.debug("[output tensor] (name,shape):({},{})".format(
            t.name, t.shape))

    x = mtf.einsum([x, mask], output_shape=x.shape.dims, name='deep_mul')
    log_tensor(x)

    # Following the MindSpore reference, the dense layers below run in fp16.
    x = mtf.cast(x, dtype=tf.float16)

    fp16_vars = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
    hidden_sizes = (1024, 512, 256, 128)
    for idx, units in enumerate(hidden_sizes):
        # Only the first layer contracts explicit dims (the last two of x).
        extra = {"reduced_dims": x.shape.dims[-2:]} if idx == 0 else {}
        x = mtf.layers.dense(x,
                             mtf.Dimension(name='dense_dim%d' % idx, size=units),
                             name="deep-dense-%d" % idx,
                             activation=mtf.relu,
                             variable_dtype=fp16_vars,
                             **extra)
        log_tensor(x)

    out_idx = len(hidden_sizes)
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim%d' % out_idx, size=1),
                         name="deep-dense-%d" % out_idx,
                         variable_dtype=fp16_vars)
    log_tensor(x)

    if not float16:
        x = mtf.cast(x, dtype=tf.float32)
    return x
Example #6
0
def model_backbone(features, labels, mesh):
    """Build the model graph and, when labels are given, its training loss.

    Args:
      features: a pair ``(id_hldr, wt_hldr)`` of tf.Tensors. Each one is
        reshaped to [args_opt.batch_size, 39]; presumably these are the
        feature ids and the feature weights.
      labels: a tf.Tensor reshaped to [args_opt.batch_size], or None (e.g.
        at inference time).
      mesh: a mtf.Mesh

    Returns:
      A tuple ``(logits, loss)``.
        logits: the float32 mtf.Tensor produced by
          ``network[args_opt.model]``.
        loss: a scalar mtf.Tensor, or None when ``labels`` is None.
    """
    id_hldr, wt_hldr = features

    batch_dim = mtf.Dimension("batch", args_opt.batch_size)
    field_dim = mtf.Dimension("field", size=39)
    vocab_dim = mtf.Dimension("vocab_size", 200000)
    embed_dim = mtf.Dimension("embed_size", 80)
    outdim = mtf.Dimension("outdim", 1)
    id_hldr = mtf.import_tf_tensor(
        mesh, tf.reshape(id_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    wt_hldr = mtf.import_tf_tensor(
        mesh, tf.reshape(wt_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    if args_opt.fp16:
        float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
        # Only the weights are cast to fp16; the ids keep their dtype
        # (presumably because they are used as lookup indices).
        wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
    else:
        float16 = None

    logits, embedding_table = network[args_opt.model](id_hldr,
                                                      wt_hldr,
                                                      vocab_dim,
                                                      embed_dim,
                                                      outdim,
                                                      float16=float16)
    logits = mtf.cast(logits, dtype=tf.float32)
    embedding_table = mtf.cast(embedding_table, dtype=tf.float32)
    if labels is None:
        # Without labels there is no loss. Return None explicitly rather than
        # adding two None "losses" together, which would raise TypeError.
        return logits, None

    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [args_opt.batch_size]),
        mtf.Shape([batch_dim]))
    cross_entropy = mtf.layers.sigmoid_cross_entropy_with_logits(
        logits, labels)
    wide_loss = mtf.reduce_mean(cross_entropy)
    # L2 penalty on the embedding table: half the mean square, weighted 8e-5.
    l2_penalty = mtf.reduce_mean(mtf.square(embedding_table)) / 2
    deep_loss = wide_loss + 8e-5 * l2_penalty
    # NOTE(review): deep_loss already includes the mean cross-entropy, so the
    # returned total counts it twice. This is kept unchanged to preserve the
    # training behaviour; confirm whether it is intended.
    return logits, wide_loss + deep_loss
Example #7
0
    def call(self, context, x, losses=None):
        """Call the layer.

        Multi-head self-attention. Queries come from ``x``. Keys and values
        come from ``x`` with its length dimension renamed to
        ``memory_length``, or, in incremental mode, from the cached states
        with the entry at ``context.position`` overwritten. Additive -1e9
        biases are applied for autoregressive masking and between positions
        with different ``context.sequence_id``.

        Args:
          context: a transformer.Context
          x: an mtf.Tensor containing ``context.model_dim``.
          losses: not used by this method.

        Returns:
          an mtf.Tensor with the same shape as ``x``.
        """
        wq, wk, wv, wo = mtf.layers.multihead_attention_params(
            context.mesh, self.heads_dim, context.model_dim, self.kv_dim,
            context.variable_dtype)
        # The memory (key/value) axis has the same size as the query length axis.
        memory_length = mtf.Dimension("memory_length", context.length_dim.size)
        q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            m = x
        else:
            # Rename the length axis so keys/values get an axis distinct from
            # the query length axis.
            m = mtf.rename_dimension(x, context.length_dim.name,
                                     "memory_length")
        k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
        v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            # Overwrite the cached key/value at context.position with this
            # step's projection; other memory positions keep their old values.
            old_k, old_v = context.get_states(2)
            one_hot = mtf.one_hot(context.position,
                                  memory_length,
                                  dtype=context.activation_dtype)
            inv_one_hot = 1.0 - one_hot
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot
        if context.mode == "incremental" or context.mode == "first_part":
            # Record keys/values as new states (presumably read back via
            # get_states on the next incremental step).
            context.record_new_states([k, v])
        masks = []
        if context.autoregressive:
            # -1e9 bias on memory positions after context.position.
            masks.append(
                mtf.cast(
                    mtf.less(
                        context.position,
                        mtf.range(context.mesh, memory_length,
                                  dtype=tf.int32)), context.activation_dtype) *
                -1e9)
        if (context.sequence_id is not None
                and isinstance(context.sequence_id, mtf.Tensor)
                and context.length_dim in context.sequence_id.shape):
            # -1e9 bias between positions whose sequence ids differ.
            masks.append(
                mtf.cast(
                    mtf.not_equal(
                        context.sequence_id,
                        mtf.layers.rename_length_to_memory_length(
                            context.sequence_id)), context.activation_dtype) *
                -1e9)
        # Sum of all biases, or None when no masking applies.
        mask = mtf.add_n(masks) if masks else None

        # Dropout on the attention weights only while training.
        o = mtf.layers.dot_product_attention_v2(
            q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
            self.dropout_rate if context.train else 0.0, [context.length_dim])
        # Project the heads back to the model dimension (output shape == x.shape).
        return mtf.einsum([o, wo],
                          x.shape,
                          reduced_dims=[self.heads_dim, self.kv_dim])
    def add_position_timing_signal_func(self, context, x, step):
        """Add n-dimensional embedding as the position (horizontal) timing signal.

        Args:
          context: mtf context
          x: a tensor with shape [batch, length, depth]
          step: step

        Returns:
          With ``add_or_concat_timing_signal == "add"``, a Tensor with the
          same shape as x. With ``"concat"``, x with the broadcast signal
          concatenated along its last dimension.

        Raises:
          ValueError: if ``self.add_or_concat_timing_signal`` or
            ``self.position_start_index`` holds an unsupported value. These
            previously surfaced as an unbound-local error.
        """
        mode = self.add_or_concat_timing_signal
        if mode not in ("add", "concat"):
            raise ValueError(
                "Unknown add_or_concat_timing_signal: %r" % (mode,))

        if not self.position_start_index:
            index = 0
        elif self.position_start_index == "random":
            # Shift all positions randomly
            # TODO(dehghani): What would be reasonable for max number of shift?
            index = mtf.random_uniform(context.mesh, [],
                                       maxval=x.shape.dims[1].size,
                                       dtype=tf.int32)
        elif self.position_start_index == "step":
            # Shift positions based on the step
            if self.recurrence_type == "act":
                num_steps = self.act_max_steps
            else:
                num_steps = self.num_rec_steps
            index = mtf.cast(x.shape.dims[1].size * step / num_steps,
                             dtype=tf.int32)
        else:
            raise ValueError("Unknown position_start_index: %r" %
                             (self.position_start_index,))

        length = context.length_dim
        channels = context.model.model_dim
        signal = self.get_timing_signal_1d(context,
                                           length,
                                           channels,
                                           start_index=index)

        if mode == "add":
            return x + mtf.cast(signal, x.dtype)

        # "concat": this path was marked "Unimplemented" upstream, so treat it
        # as untested.
        batch_dim = x.shape.dims[0]
        out_shape = mtf.Shape([batch_dim] + signal.shape.dims[1:])
        signal_tiled = mtf.broadcast(signal, out_shape)
        return mtf.concat(
            (x, signal_tiled),
            concat_dim_name=signal_tiled.dimension_names[-1])
Example #9
0
def sample_categorical(x, dim=None):
    """Draw one categorical sample along ``dim`` by inverse-CDF sampling.

    The cumulative sum of ``x`` over ``dim`` is compared with a uniform draw
    in [0, 1). The sampled index is the first position whose CDF exceeds the
    draw. It is computed as the number of positions whose CDF is still <= the
    draw, so it does not depend on how ``mtf.argmax`` breaks ties.
    ``x`` is assumed to hold non-negative probabilities that sum to 1 along
    ``dim``; TODO confirm with callers.

    Args:
      x: mtf.Tensor of probabilities.
      dim: the mtf.Dimension to sample over; defaults to the last dimension.

    Returns:
      int32 mtf.Tensor of sampled indices with ``dim`` removed.
    """
    dim = x.shape[-1] if dim is None else dim

    cdf = mtf.cumsum(x, dim)
    # Draw in x's dtype so the comparison below does not mix dtypes.
    rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0,
                                      maxval=1, dtype=x.dtype)
    exceeds = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
    # For a non-decreasing CDF, the count of positions not exceeding the draw
    # is exactly the index of the first position that does.
    index = mtf.reduce_sum(1 - exceeds, reduced_dim=dim)
    # Float round-off can leave cdf[-1] slightly below the draw; clamp so the
    # result is always a valid index.
    return mtf.minimum(index, dim.size - 1)
Example #10
0
def mnist_model(image, labels, mesh):
    """Build the VGG classifier graph and, when labels are given, its loss.

    Args:
      image: tf.Tensor reshaped to [FLAGS.batch_size, image_height,
        image_width, num_channels] (module-level settings).
      labels: a tf.Tensor with FLAGS.batch_size int class ids, or None.
      mesh: a mtf.Mesh

    Returns:
      logits: the VGG output cast to float32 (presumably
        [batch, classesnum]).
      loss: scalar mean softmax cross-entropy, or None when labels is None.
    """
    batch = FLAGS.batch_size
    batch_dim = mtf.Dimension("batch", batch)
    image_shape = mtf.Shape([
        batch_dim,
        mtf.Dimension("rows_size", image_height),
        mtf.Dimension("cols_size", image_width),
        mtf.Dimension("image_channel", num_channels),
    ])
    classes_dim = mtf.Dimension(name='classesnum', size=classesnum)

    pixels = tf.reshape(image, [batch, image_height, image_width, num_channels])
    x = mtf.import_tf_tensor(mesh, pixels, image_shape)
    logits = mtf.cast(VGG(x, classes_dim=classes_dim, depth=depth),
                      dtype=tf.float32)

    loss = None
    if labels is not None:
        mtf_labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [batch]), mtf.Shape([batch_dim]))
        per_example = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(per_example)
    return logits, loss
Example #11
0
def c2r3d(cfield, dims, norm=None, dtype=tf.float32, name=None):
    """
  Converts a complex Fourier domain field to a real field

  Parameters:
  -----------
  cfield: tensor (batch_size, nc, nc, nc)
    Complex 3D Fourier-space field

  dims: list of 3 mtf.Dimension
    Real-space dimensions of the output, passed to ``ifft3d``

  norm: float or scalar tensor, optional
    Normalization factor the transformed field is multiplied by. Defaults to
    the number of grid cells (nx * ny * nz) of the last three dims of cfield.

  dtype: tf.dtype
    Type of output tensor

  name: str, optional
    Currently unused; kept for interface compatibility.

  Return:
  -------
  rfield: tensor (batch_size, nc, nc, nc)
    Real valued field
  """
    x_dim, y_dim, z_dim = cfield.shape[-3:]
    if norm is None:
        # Build the default normalisation in the output dtype so the product
        # below does not mix dtypes when dtype is not tf.float32.
        norm = mtf.constant(cfield.mesh, x_dim.size * y_dim.size * z_dim.size,
                            dtype=dtype)
    rfield = mtf.cast(mesh_ops.ifft3d(cfield, dims), dtype) * norm
    return rfield
Example #12
0
def r2c3d(rfield, k_dims, norm=None, dtype=tf.complex64):
    """
  Converts a real field to its complex Fourier Transform

  Parameters:
  -----------
  rfield: tensor (batch_size, nc, nc, nc)
    Input 3D real field

  k_dims: list of 3 mtf.Dimension
    Fourier-space dimensions of the output, passed to ``fft3d``

  norm: float or scalar tensor, optional
    Normalization factor the real field is divided by before the transform.
    Defaults to the number of grid cells (nx * ny * nz) of the last three
    dims of rfield.

  dtype: tf.dtype
    Complex type the normalized field is cast to before the transform

  Return:
  -------
  cfield: tensor (batch_size, nc, nc, nc)
    Complex field
  """
    x_dim, y_dim, z_dim = rfield.shape[-3:]
    if norm is None:
        # Build the default normalisation in rfield's own dtype so the
        # division below does not mix dtypes (e.g. for a float64 field).
        norm = mtf.constant(rfield.mesh, x_dim.size * y_dim.size * z_dim.size,
                            dtype=rfield.dtype)
    cfield = mesh_ops.fft3d(mtf.cast(rfield / norm, dtype), k_dims)
    return cfield
Example #13
0
    def forward(self, features, return_loss=True, return_logits=False):
        """Run the model on ``features["tokens"]``.

        Args:
          features: dict holding the input ids under "tokens".
          return_loss: when false, only the logits are returned.
          return_logits: when true (and return_loss is true), the logits are
            also returned, cast to the master (checkpoint) dtype.

        Returns:
          ``logits``; or ``(loss, loss_batch)``; or
          ``(loss, loss_batch, logits)``.
        """
        token_ids = features["tokens"]
        embedded = self.embedding(token_ids, "embedding")
        hidden = self.positional_embedding(embedded, "positional_embedding")

        attn_mask = self.get_attn_mask(hidden.mesh, hidden.shape[1],
                                       self.dimensions["memory_len_dim"])
        logits = self.to_logits(self.transformer(hidden, mask=attn_mask))
        if not return_loss:
            return logits

        # Targets: the inputs padded by one eos position (pad spec [0, 1])
        # and then shifted by one via gathering indices 1..len.
        padded = pad(token_ids, [0, 1],
                     dim_name="total_seq_dim",
                     pad_value=self.eos_token_id)
        seq_dim = padded.shape[1]
        shifted_idx = mtf.range(padded.mesh,
                                mtf.Dimension("range", seq_dim.size - 1),
                                tf.int32,
                                name="labels_indices") + 1
        targets = mtf.gather(padded, shifted_idx, dim=seq_dim)
        targets = mtf.rename_dimension(targets, "range", "total_seq_dim")

        loss, loss_batch = self._loss(logits, targets)
        if not return_logits:
            return loss, loss_batch
        # Cast back to checkpoint dtype
        return loss, loss_batch, mtf.cast(logits,
                                          self.variable_dtype.master_dtype)
Example #14
0
 def to_logits(self, x):
     """Layer-norm ``x``, project it onto the vocabulary, return float32 logits."""
     with tf.variable_scope("to_logits"):
         normed = self.layer_norm(x)
         vocab_dim = self.dimensions["final_vocab_dim"]
         projected = self.linear(normed, vocab_dim, name="linear_out")
         # Go to full precision for the logits
         result = mtf.cast(projected, tf.float32)
     return result
Example #15
0
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorflow.

  Runs ``features`` through FLAGS.num_hidden_layers + 1 bias-free dense
  layers. The last layer projects back to the ``io`` dimension. The output
  is scored against the input with a mean squared error.

  Args:
    features: a tf.Tensor imported with mtf.Shape([batch, io]).
    mesh: a mtf.Mesh

  Returns:
    A tuple ``(y, loss)``: the output mtf.Tensor and the scalar MSE loss.
  """
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  x = mtf.cast(x, activation_dtype)
  h = x
  num_layers = FLAGS.num_hidden_layers + 1
  # ``range`` instead of the Python-2-only ``xrange`` so this runs on Python 3.
  for lnum in range(1, num_layers + 1):
    if lnum == num_layers:
      # output layer
      dim = io_dim
    elif lnum % 2 == 0:
      # Hidden dims alternate names, presumably so that a layer's input and
      # output dimensions never share a name.
      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
    else:
      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
    h = mtf.layers.dense(
        h, dim,
        use_bias=False,
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        name='layer_%d' % lnum)
  y = h

  loss = mtf.reduce_mean(mtf.square(y - x))
  return y, loss
Example #16
0
    def sample(self, features, mesh):
        """Sample token ids autoregressively from the model.

        The prefix the samples are forced to begin with comes from
        ``features["inputs"]``, falling back to ``features["targets"]``. If
        neither is present, an all-zero int32 prefix of shape
        ``self.batch_dims + [self.length_dim]`` is used.

        Args:
          features: the input features, imported via ``self._import_feature``.
          mesh: a mtf.Mesh

        Returns:
          the output of ``model.sample_autoregressive``, passed through
          ``self.combine_batch_dims``.

        Raises:
          ValueError: if the model is not autoregressive.
          NotImplementedError: if ``hparams.beam_size > 1``.
        """
        hparams = self._hparams
        model = self.model()

        def import_feature(key):
            return self._import_feature(features, mesh, key)

        if not self.autoregressive:
            raise ValueError(
                "Don't know how to sample from non-autoregressive unitransformer"
            )

        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = import_feature("inputs")
        if partial_targets is None:
            partial_targets = import_feature("targets")
        # Test for None explicitly instead of relying on the truthiness of an
        # mtf.Tensor.
        if partial_targets is not None:
            # Replace id 1 (presumably EOS) with 0; TODO confirm.
            partial_targets *= mtf.cast(mtf.not_equal(partial_targets, 1),
                                        partial_targets.dtype)
        else:
            ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
            partial_targets = mtf.constant(mesh,
                                           0,
                                           ids_shape,
                                           dtype=tf.int32)
        if hparams.beam_size > 1:
            raise NotImplementedError(
                "Beam search not implemented for unitransformer.")
        ret = model.sample_autoregressive(
            partial_targets,
            temperature=hparams.sampling_temp,
            variable_dtype=self.variable_dtype)
        return self.combine_batch_dims(ret)
 def call(self, context, x):
     """Call the layer stack.

     Steps:
       1. Apply dropout.
       2. Optionally run vanilla transformer layers.
       3. Run the recurrence selected by ``self.recurrence_type`` ("act",
          "basic" or "highway").
       4. Optionally run vanilla transformer layers again.
       5. Apply a final layer norm and dropout.
     When ``context.sequence_id`` is an mtf.Tensor, positions where it is 0
     are zeroed in the output. The input and the final output are appended
     to ``context.layer_outputs``.

     Args:
       context: a transformer.Context
       x: input mtf.Tensor.

     Returns:
       an mtf.Tensor.
     """
     if isinstance(context.sequence_id, mtf.Tensor):
         # We use this mask to zero out the padding regions at each layer.
         # This "fixes" a bug where extreme values leak from the padding into the
         # non-padding regions.
         # TODO(noam): understand this better and make a more principled fix.
         mask = mtf.cast(mtf.not_equal(context.sequence_id, 0),
                         context.activation_dtype)
     else:
         mask = None
     x = self._dropout(context, x)
     context.layer_outputs.append(x)
     if self.mix_with_transformer_before_ut:
         for _ in range(self.num_vanilla_transformer_layers):
             x = self.vanilla_transformer_layer(context, x, mask)
     # Call a ACT layer
     # NOTE(review): any other recurrence_type leaves x unchanged; confirm intent.
     if self.recurrence_type == "act":
         x = self.act_layer(context, x, mask)
     elif self.recurrence_type == "basic":
         x = self.ut_basic(context, x, mask)
     elif self.recurrence_type == "highway":
         layer_inputs = (x, x, x)
         x = self.ut_highway(context, layer_inputs, mask)
     if self.mix_with_transformer_after_ut:
         for _ in range(self.num_vanilla_transformer_layers):
             x = self.vanilla_transformer_layer(context, x, mask)
     x = self._layer_norm(context, x, name="final_layer_norm")
     x = self._dropout(context, x)
     # Compare against None: mask is an mtf.Tensor, whose truthiness must not
     # be relied upon.
     if mask is not None:
         x *= mask
     context.layer_outputs.append(x)
     return x
    def add_step_timing_signal_func(self, context, x, step):
        """Add n-dimensional embedding as the step (vertical) timing signal.

        Args:
          context: mtf context
          x: a tensor with shape [batch, length, depth]
          step: step

        Returns:
          With ``add_or_concat_timing_signal == "add"``, a Tensor with the
          same shape as x. With ``"concat"``, x with the broadcast signal
          concatenated along its last dimension.

        Raises:
          ValueError: if ``self.add_or_concat_timing_signal`` or
            ``self.step_timing_signal_type`` holds an unsupported value.
            These previously surfaced as an unbound-local error.
        """
        if self.recurrence_type == "act":
            num_steps = self.act_max_steps
        else:
            num_steps = self.num_rec_steps
        channels = x.shape.dims[-1]

        mode = self.add_or_concat_timing_signal
        if mode not in ("add", "concat"):
            raise ValueError(
                "Unknown add_or_concat_timing_signal: %r" % (mode,))

        if self.step_timing_signal_type == "learned":
            signal = self.get_layer_timing_signal_learned_1d(
                context, channels, step, num_steps)
        elif self.step_timing_signal_type == "sinusoid":
            signal = self.get_layer_timing_signal_sinusoid_1d(
                context, channels, step, num_steps)
        else:
            raise ValueError("Unknown step_timing_signal_type: %r" %
                             (self.step_timing_signal_type,))

        if mode == "add":
            return x + mtf.cast(signal, x.dtype)

        batch_dim = x.shape.dims[0]
        out_shape = mtf.Shape([batch_dim] + x.shape.dims[1:])
        signal_tiled = mtf.broadcast(signal, out_shape)
        return mtf.concat(
            (x, signal_tiled),
            concat_dim_name=signal_tiled.dimension_names[-1])
Example #19
0
 def call(self, context, x):
     """Call the layer stack.

     Applies dropout, then runs each layer in ``self._layers`` as a pre-norm
     residual block (``x += dropout(layer(layer_norm(x)))``), then a final
     layer norm and dropout. When ``context.sequence_id`` is an mtf.Tensor,
     positions where it is 0 are zeroed in each layer-norm input and in the
     output. If ``context.layer_outputs`` is not None, the input, every
     intermediate layer output and the final output are appended to it.

     Args:
       context: a transformer.Context
       x: input mtf.Tensor.

     Returns:
       an mtf.Tensor.

     Raises:
       ValueError: if a layer returns a tensor whose shape differs from x.
     """
     if isinstance(context.sequence_id, mtf.Tensor):
         # We use this mask to zero out the padding regions at each layer.
         # This "fixes" a bug where extreme values leak from the padding into the
         # non-padding regions.
         # TODO(noam): understand this better and make a more principled fix.
         mask = mtf.cast(mtf.not_equal(context.sequence_id, 0),
                         context.activation_dtype)
     else:
         mask = None
     x = self._dropout(context, x)
     if context.layer_outputs is not None:
         context.layer_outputs.append(x)
     last_lnum = len(self._layers) - 1
     for lnum, layer in enumerate(self._layers):
         with tf.variable_scope("layer_%03d" % lnum):
             # Compare against None: mask is an mtf.Tensor, whose truthiness
             # must not be relied upon.
             norm_x = self._layer_norm(
                 context, (x * mask) if mask is not None else x)
             with tf.variable_scope(layer.__class__.__name__):
                 y = layer.call(context, norm_x)
                 if y.shape != x.shape:
                     raise ValueError(
                         "Layer %s returned misshaped output x=%s y=%s" %
                         (layer.__class__.__name__, x, y))
             x += self._dropout(context, y)
         # The last layer's output is recorded after the final norm below.
         if context.layer_outputs is not None and lnum != last_lnum:
             context.layer_outputs.append(x)
         context.layer_index += 1
     x = self._layer_norm(context, x, name="final_layer_norm")
     x = self._dropout(context, x)
     if mask is not None:
         x *= mask
     if context.layer_outputs is not None:
         context.layer_outputs.append(x)
     return x
Example #20
0
 def call(self, context, x, losses=None):
     """Call the layer.

     Local (windowed) masked self-attention.

     Modes:
       "incremental": attends with the cached keys/values via
         masked_local_attention_1d_incremental and records the updated
         states.
       Otherwise: uses masked_local_attention_1d.
       "first_part" (in addition): builds window-sized key/value states from
         the last ``self.window_size`` positions before
         ``context.initial_position`` and appends them to
         ``context.new_states``.

     Args:
       context: a transformer.Context
       x: input mtf.Tensor.
       losses: not used by this method.

     Returns:
       the attention output ``y``.
     """
     params = mtf.layers.multihead_attention_params(context.mesh,
                                                    self.heads_dim,
                                                    context.model_dim,
                                                    self.kv_dim,
                                                    context.variable_dtype)
     if context.mode == "incremental":
         prev_k, prev_v = context.get_states(2)
         y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
             x, prev_k, prev_v, context.position, params=params)
         context.record_new_states([new_k, new_v])
         return y
     else:
         # masked_local_attention_1d fills kv with the keys and values it used.
         kv = []
         y = mtf.layers.masked_local_attention_1d(x,
                                                  self.kv_dim,
                                                  self.heads_dim,
                                                  self.window_size,
                                                  params=params,
                                                  return_kv=kv)
         if context.mode == "first_part":
             k = kv[0]
             v = kv[1]
             window_dim = mtf.Dimension("window", self.window_size)
             mesh = k.mesh
             window_pos = mtf.range(mesh, window_dim, tf.int32)
             pos = mtf.range(mesh, context.length_dim, tf.int32)
             # select_recent[w, p] == 1 iff p % window_size == w and
             # initial_position - window_size <= p < initial_position,
             # i.e. it picks, for each window slot, the most recent position
             # that maps to it.
             select_recent = mtf.cast(
                 mtf.equal(window_pos, mtf.mod(pos, self.window_size)),
                 k.dtype)
             select_recent *= mtf.cast(
                 mtf.less(pos, context.initial_position), k.dtype)
             select_recent *= mtf.cast(
                 mtf.greater_equal(
                     pos, context.initial_position - self.window_size),
                 k.dtype)
             # Replace the last two dims of k/v with [window, kv].
             state_shape = k.shape.dims[:-2] + [window_dim, self.kv_dim]
             k_state = mtf.einsum([k, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             v_state = mtf.einsum([v, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             context.new_states.extend([k_state, v_state])
         return y
Example #21
0
 def get_attn_mask(self, mesh, nd, ns):
     """Return the cached causal attention bias over dims [nd, ns].

     Query position i is offset by ``ns.size - nd.size`` so that queries line
     up with the end of the memory. Entries where the query position is
     before the memory position j get -1e10; all others get 0. The bias is
     built on the first call and stored in ``self.attn_mask``.

     NOTE(review): later calls return the cached bias even if ``nd``/``ns``
     differ; confirm callers always pass the same dimensions.
     """
     if exists(self.attn_mask):
         return self.attn_mask
     query_pos = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
     key_pos = mtf.range(mesh, ns, tf.int32)
     query_pos = mtf.broadcast(query_pos, [nd, ns])
     key_pos = mtf.broadcast(key_pos, [nd, ns])
     is_future = mtf.less(query_pos, key_pos)
     self.attn_mask = mtf.cast(
         is_future, self.variable_dtype.activation_dtype) * -1e10
     return self.attn_mask
Example #22
0
 def nonpadding(self):
   """Tensor with zeros in padding positions and ones elsewhere.

   Returns None when there is no sequence_id and the constant 1 when
   sequence_id equals 1. Otherwise returns a mask of
   ``sequence_id != 0`` cast to ``self.activation_dtype``.
   """
   seq_id = self.sequence_id
   if seq_id is None:
     return None
   if seq_id == 1:
     return 1
   not_padding = mtf.not_equal(seq_id, 0)
   return mtf.cast(not_padding, self.activation_dtype)
Example #23
0
    def compute_loss(self, decoder: transformer.Unitransformer,
                     hidden: mtf.Tensor, targets: mtf.Tensor,
                     context: transformer.Context) -> mtf.Tensor:
        """Returns the loss without computing a softmax over the entire vocab.

        The result is the sum of:
          * one loss per tail cluster, computed only on the tokens that fall
            in that cluster, and
          * the head cluster's loss, where every token in tail cluster i is
            replaced by the node id ``self._head_cluster.end_token_id + i``.

        Args:
          decoder: passed through to each cluster's ``compute_loss``.
          hidden: hidden states; the dimension named "length" is looked up
            when a cluster's length_projection_factor != 1.
          targets: target ids, with the same "length" lookup.
          context: passed through to each cluster's ``compute_loss``.

        Returns:
          the summed loss.
        """
        loss = 0
        tail_cluster_masks = []
        for cluster in self._tail_clusters:
            cluster_mask = cluster.get_cluster_mask(targets)
            tail_cluster_masks.append(cluster_mask)

            if cluster.length_projection_factor == 1:
                # Keep the full length; zero out the targets and hidden states
                # of tokens outside this cluster.
                targets_in_cluster = mtf.where(cluster_mask, targets, 0)
                hidden_in_cluster = mtf.where(cluster_mask, hidden, 0)
            else:
                # TODO(mmatena): Unfold the batch dim to get a super long sequence dim
                # to reduce the risk of overflowing the projection.
                # Pack the in-cluster tokens into the shorter cluster-length
                # dimension with a one-hot projection matrix.
                proj_to_cluster_len = cluster.get_project_to_cluster_length(
                    cluster_mask, dtype=targets.dtype)
                targets_in_cluster = mtf.einsum(
                    [proj_to_cluster_len, targets],
                    reduced_dims=[targets.shape.get_dim_by_name("length")])
                hidden_in_cluster = mtf.einsum(
                    [mtf.cast(proj_to_cluster_len, hidden.dtype), hidden],
                    reduced_dims=[hidden.shape.get_dim_by_name("length")])

            loss += cluster.compute_loss(decoder, hidden_in_cluster,
                                         targets_in_cluster, context)

        # For each token: end_token_id + i if it is in tail cluster i, else 0.
        # NOTE(review): assumes the cluster masks are disjoint (overlaps would
        # add ids together) and that end_token_id + i is never 0, since the
        # bool cast below treats 0 as "not in any tail cluster".
        tail_clusters_dim = mtf.Dimension("tail_clusters",
                                          len(tail_cluster_masks))
        tail_node_targets = mtf.reduce_sum(
            mtf.stack([(self._head_cluster.end_token_id + i) *
                       mtf.cast(mask, targets.dtype)
                       for i, mask in enumerate(tail_cluster_masks)],
                      tail_clusters_dim.name),
            reduced_dim=tail_clusters_dim)
        # Head targets: the tail-cluster node id where one applies, otherwise
        # the original target.
        head_targets = mtf.where(mtf.cast(tail_node_targets, tf.bool),
                                 tail_node_targets, targets)
        loss += self._head_cluster.compute_loss(decoder, hidden, head_targets,
                                                context)

        return loss
Example #24
0
def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
    """Return an additive causal attention bias over dims [nd, ns].

    Unlike an older variant that masked the QK product directly, this
    returns a bias that the mtf attention code adds to the attention
    logits. Information flows from source positions (``ns``) to destination
    positions (``nd``). The destination index is offset by
    ``ns.size - nd.size`` so that destinations line up with the end of the
    sources.

    Args:
      mesh: the mtf.Mesh the tensors are built on.
      nd: mtf.Dimension of destination (query) positions.
      ns: mtf.Dimension of source (memory) positions.
      variable_dtype: object whose ``activation_dtype`` sets the bias dtype.

    Returns:
      mtf.Tensor with dims [nd, ns]: -1e10 where the source position comes
      after the (offset) destination position, 0 elsewhere.
    """
    dst = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
    src = mtf.range(mesh, ns, tf.int32)
    grid = [nd, ns]
    dst, src = (mtf.broadcast(t, grid) for t in (dst, src))
    bias_dtype = variable_dtype.activation_dtype
    return mtf.cast(mtf.less(dst, src), bias_dtype) * -1e10
Example #25
0
    def compute_mask(self, context, memory_position):
        """Compute attention mask.

        A -1e9 penalty is produced for memory positions whose offset from
        ``context.position`` lies outside the optional
        [min_relative_position, max_relative_position] range. A second
        penalty covers pairs of positions with different segment ids. The
        segment ids come from ``context.subsequence_id`` when it is a Tensor,
        otherwise from ``context.sequence_id``, and are only used if they
        span ``context.length_dim``.

        Args:
          context: a transformer.Context
          memory_position: an int32 tensor containing memory_length dimension.
        Returns:
          a Tensor (sum of all penalties) or None
        """
        def to_bias(illegal):
            return mtf.cast(illegal, context.activation_dtype) * -1e9

        biases = []
        lo = self.min_relative_position(context)
        hi = self.max_relative_position(context)
        if lo is not None or hi is not None:
            offset = memory_position - context.position
            if lo is not None:
                biases.append(to_bias(mtf.less(offset, lo)))
            if hi is not None:
                biases.append(to_bias(mtf.greater(offset, hi)))

        # Subsequence id should only be set if we are in the decoder and have
        # multiple targets per input. It lets each sub-target attend only to
        # itself.
        if isinstance(context.subsequence_id, mtf.Tensor):
            segment_ids = context.subsequence_id
        elif isinstance(context.sequence_id, mtf.Tensor):
            segment_ids = context.sequence_id
        else:
            segment_ids = None

        if segment_ids is not None and context.length_dim in segment_ids.shape:
            as_memory = self.rename_length_to_memory_length(segment_ids, context)
            biases.append(to_bias(mtf.not_equal(segment_ids, as_memory)))

        if not biases:
            return None
        return mtf.add_n(biases)
def visibility_mask_to_attention_bias(visible, dtype):
  """Turn a boolean visibility mask into an additive attention bias.

  Entries where `visible` is False become -1e9 and entries where it is
  True become 0, so adding the result to attention logits suppresses the
  hidden positions.

  Args:
    visible: a boolean Tensor
    dtype: a dtype
  Returns:
    a Tensor with the given dtype and the same shape as "visible"
  """
  hidden = mtf.logical_not(visible)
  hidden_as_number = mtf.cast(hidden, dtype)
  return hidden_as_number * -1e9
Example #27
0
def resnet34(x, classes_dim, float16=None, batch_norm=False):
    """Build a ResNet-34 style network on top of ``backbone``.

    Uses stage depths [3, 4, 6, 3], channels [64, 128, 256, 512], strides
    [1, 2, 2, 2] and the (BasicBlockWithDown, BasicBlock) block pair.

    Args:
      x: input mtf.Tensor.
      classes_dim: output classes dimension, forwarded to ``backbone``.
      float16: if truthy, ``x`` is cast to tf.float16 before building the
        network; the flag is also forwarded to ``backbone``.
      batch_norm: forwarded to ``backbone``.

    Returns:
      the tensor produced by ``backbone`` for this configuration.
    """
    net_input = mtf.cast(x, dtype=tf.float16) if float16 else x
    logger.debug("[input tensor] (name,shape):({},{})".format(
        net_input.name, net_input.shape))
    return backbone(
        net_input,
        layerlist=[3, 4, 6, 3],
        chalist=[64, 128, 256, 512],
        strilist=[1, 2, 2, 2],
        classes_dim=classes_dim,
        # Presumably: downsampling block variant + plain block; see backbone.
        blocklist=[BasicBlockWithDown, BasicBlock],
        float16=float16,
        batch_norm=batch_norm,
    )
Example #28
0
def resnet152(x, classes_dim, float16=None, batch_norm=False):
    """Build a ResNet-152 style network on top of ``backbone``.

    Uses stage depths [3, 8, 36, 3], channels [256, 512, 1024, 2048],
    strides [1, 2, 2, 2] and the (ResidualBlockWithDown, ResidualBlock)
    block pair.

    Args:
      x: input mtf.Tensor.
      classes_dim: output classes dimension, forwarded to ``backbone``.
      float16: if truthy, ``x`` is cast to tf.float16 before building the
        network; the flag is also forwarded to ``backbone``.
      batch_norm: forwarded to ``backbone``.

    Returns:
      the tensor produced by ``backbone`` for this configuration.
    """
    net_input = mtf.cast(x, dtype=tf.float16) if float16 else x
    logger.debug("[input tensor] (name,shape):({},{})".format(
        net_input.name, net_input.shape))
    return backbone(
        net_input,
        layerlist=[3, 8, 36, 3],
        chalist=[256, 512, 1024, 2048],
        strilist=[1, 2, 2, 2],
        classes_dim=classes_dim,
        # Presumably: downsampling block variant + plain block; see backbone.
        blocklist=[ResidualBlockWithDown, ResidualBlock],
        float16=float16,
        batch_norm=batch_norm,
    )
Example #29
0
    def _loss(self, logits, labels):
        """Compute the scalar training loss and the per-example losses.

        The per-example loss comes from ``self.loss_fn`` with z_loss disabled.
        It is averaged, divided by ``params["num_microbatches"]`` (default 1),
        and cast to ``self.variable_dtype.slice_dtype``.

        Returns:
          (loss, loss_batch): the scalar loss and the unreduced loss tensor.
          loss_batch must be returned because the metric fns consume it.
        """
        vocab = logits.shape[-1]
        with tf.variable_scope("loss_final"):
            per_example = self.loss_fn(
                logits=logits, targets=labels, vocab_dim=vocab, z_loss=0.0)

        with tf.variable_scope("reduce_mean_final"):
            mean_loss = mtf.reduce_mean(per_example)

        # Each microbatch contributes a share of the total loss.
        microbatches = self.params.get("num_microbatches", 1)
        mean_loss = mean_loss / microbatches
        return mtf.cast(mean_loss, self.variable_dtype.slice_dtype), per_example
Example #30
0
    def moe(self, x, layout, mesh_shape, input_mask, is_training):
        """Mixture of experts layer.

    TODO(noam): clean up the mixture-of-experts code in Transformer.

    Builds a top-2 gated MoE layer via ``moe.transformer_moe_layer_v1``.
    The auxiliary load-balancing loss it returns is appended to
    ``self._extra_losses``.

    Args:
      x: layer input
      layout: a mtf.LayoutRules
      mesh_shape: a mtf.Shape
      input_mask: a mtf.Tensor, or None when there is no padding mask
      is_training: a boolean
    Returns:
      a mtf.Tensor (the layer output)
    """
        hparams = moe.HParams(
            moe_gating="top_2",
            moe_num_experts=self.config.moe_num_experts,
            moe_loss_coef=1e-3,
            moe_hidden_size=self.config.moe_intermediate_size,
            moe_group_size=2048,
            moe_capacity_factor_train=1.25,
            moe_capacity_factor_eval=8.0,
            moe_use_second_place_loss=False,
            moe_second_policy_train="random",
            moe_second_policy_eval="random",
            moe_second_threshold_train=0.2,
            moe_second_threshold_eval=0.2,
            moe_dropout_rate=0.0,
            moe_use_experts_attention=False,
            moe_min_expert_capacity=4)
        # Test for presence explicitly: the truthiness of a tensor object is
        # not a meaningful "mask supplied" check.
        nonpadding = (mtf.cast(input_mask, tf.float32)
                      if input_mask is not None else None)
        layer_output, loss = moe.transformer_moe_layer_v1(
            inputs=x,
            output_dim=self.model_dim,
            hparams=hparams,
            train=is_training,
            variable_dtype=tf.float32,
            layout=layout,
            mesh_shape=mesh_shape,
            nonpadding=nonpadding,
            activation=get_activation(
                self.config.feedforward_intermediate_act))
        # The auxiliary MoE loss is collected here; callers presumably add
        # self._extra_losses into the total loss.
        self._extra_losses.append(loss)
        return layer_output
Example #31
0
 def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
   """Return a noised version of `targets` according to `noising_spec`.

   Supported `noising_spec["type"]` values:
     * "mask": each token is replaced with 0 with probability
       `noising_spec["prob"]`.
     * "random_zipfian": each token is replaced, with probability
       `noising_spec["prob"]`, by a token id sampled with probability
       proportional to 1/(id + 10).
     * "transformer": a small MtfTransformer, built from `self._hparams`
       overridden by `noising_spec["overrides"]`, is trained alongside this
       model and its samples are returned. Only supported in TRAIN mode.

   Args:
     targets: a mtf.Tensor of target token ids (presumably integer-typed;
       confirm against callers).
     noising_spec: a dict with a "type" key. "mask" and "random_zipfian"
       also read "prob"; "transformer" reads "overrides".
     losses: a list that receives the noiser's training loss. Required when
       noising_spec["type"] == "transformer".
   Returns:
     a mtf.Tensor of noised target ids.
   Raises:
     NotImplementedError: for "transformer" noising outside TRAIN mode.
     ValueError: for "transformer" noising without a `losses` list, or for
       an unknown noising type.
   """
   if noising_spec["type"] == "mask":
     # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
     return targets * mtf.cast(
         mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                     noising_spec["prob"]), targets.dtype)
   elif noising_spec["type"] == "random_zipfian":
     # Replace a randomly-chosen noising_spec["prob"] of input tokens.
     # Rather than drawing the replacement tokens uniformly, we sample from
     #   a distribution favoring lower token-ids, assuming that the ids have
     #   been assigned in frequency order.  The probability of choosing an
     #   id is proportional to 1/(id+10)
     logits = mtf.log(1.0 / (mtf.range(
         targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
     logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
     r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
     use_noise = mtf.less(
         mtf.random_uniform(targets.mesh, targets.shape), noising_spec["prob"])
     return mtf.where(use_noise, r, targets)
   elif noising_spec["type"] == "transformer":
     # Train a small transformer to fill in masked out values, then
     # sample from it.
     hparams = self._hparams
     if hparams.mode != tf.estimator.ModeKeys.TRAIN:
       raise NotImplementedError("Not implemented")
     if losses is None:
       # The noiser is trained jointly with this model, so its loss must be
       # handed back to the caller. Fail before building any variables
       # instead of crashing on losses.append() afterwards.
       raise ValueError(
           "noising_spec type 'transformer' requires a `losses` list to "
           "collect the noiser's training loss")
     noiser_hparams = copy.copy(self._hparams)
     noiser_hparams.del_hparam("mode")
     noiser_hparams.override_from_dict(noising_spec["overrides"])
     with tf.variable_scope("noiser"):
       noiser = MtfTransformer(
           noiser_hparams,
           mode=hparams.mode,
           problem_hparams=self._problem_hparams)
       logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
           self._original_features, targets.mesh)
       samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
     losses.append(loss)
     return samples
   else:
     raise ValueError("unknown noising spec %s" % noising_spec)
Example #32
0
  def _sample(self, features, mesh):
    """Decode output token ids from the model.

    When `hparams.beam_size == 1`, decoding uses
    `mtf.beam_search.greedy_decode` with temperature 0.0 for
    `sampling_method == "argmax"`, else `hparams.sampling_temp`. Otherwise
    it uses `mtf.beam_search.beam_search`.

    Args:
      features: a dict of tf.Tensors. For `transformer_type == "encdec"`,
        `features["inputs"]` is required. For "decoder", `features["inputs"]`
        (or, if absent, `features["targets"]`) is optional. When present, it
        is passed as `forced_ids` to greedy decoding so outputs begin with
        that prefix. It is not used by the beam-search path.
      mesh: the mtf.Mesh on which the decoding graph is built.

    Returns:
      an mtf.Tensor of decoded ids: the result of `greedy_decode` when
      beam_size == 1, otherwise beam index 0 of the `beam_search` output
      (presumably the top-scoring beam; confirm against mtf.beam_search).

    Raises:
      ValueError: if `hparams.transformer_type` is neither "encdec" nor
        "decoder".
    """
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
      inputs = features["inputs"]
      # Squeeze away axis 2 until inputs is rank 2: (batch, length).
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      # Pad up to the fixed (batch_size, max_length) used by the mesh graph.
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      # Encoder-decoder keys/values depend only on the encoder output, so
      # compute them once here for every "enc_att" decoder layer instead of
      # at every decoding step. Other layer types get None.
      encdec_tensors = []
      for layer_num, layer_type in enumerate(hparams.decoder_layers):
        if layer_type == "enc_att":
          with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
            q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                mesh, self.heads_dim, self.model_dim,
                self.kv_dim, self.master_dtype, self.slice_dtype,
                self.activation_dtype)
            k = mtf.einsum(
                [encoder_output, k_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
            v = mtf.einsum(
                [encoder_output, v_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
          encdec_tensors.append((q_var, o_var, k, v))
        else:
          encdec_tensors.append(None)
      partial_targets = None
    elif hparams.transformer_type == "decoder":
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)
    else:
      # The message names the hparam actually being checked.
      raise ValueError(
          "hparams.transformer_type = %s not yet supported"
          % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
      ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [self.heads_dim,
                                  local_attention_window, self.kv_dim])
    else:
      # Beam search adds a beam dimension after the batch dimensions.
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [beam_dim, self.heads_dim,
                                  local_attention_window, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    # Each "att"/"local_att" decoder layer carries two zero-initialized
    # cache tensors (presumably keys and values) through decoding.
    initial_states = []
    for layer in hparams.decoder_layers:
      if layer == "att":
        initial_states.extend(
            [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
      elif layer == "local_att":
        initial_states.extend(
            [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      # Embed the id at position step_num - 1 plus the positional embedding
      # for step_num.
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_states = self._layer_stack(
            x,
            hparams.decoder_layers,
            encdec_attention_mask=encoder_attention_mask,
            step_num=step_num,
            encdec_tensors=encdec_tensors,
            states=states)
      logits = mtf.matmul(x, softmax_var)
      return logits, new_states

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf.beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if hparams.transformer_type == "encdec":
        # Bound the decode length by the longest non-padding input:
        # length * decode_length_multiplier + decode_length_constant.
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf.beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu,
          dtype=self.activation_dtype)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)