Example #1
def mnist_model(image, labels, mesh):
	"""The model.
	Args:
		image: tf.Tensor with shape [batch, 28*28]
		labels: a tf.Tensor with shape [batch] and dtype tf.int32
		mesh: a mtf.Mesh
	Returns:
		logits: a mtf.Tensor with shape [batch, 10]
		loss: a mtf.Tensor with shape []
	"""
	batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
	rows_dim = mtf.Dimension("rows_size", image_height)
	cols_dim = mtf.Dimension("cols_size", image_width)
	channel_dim = mtf.Dimension("image_channel", num_channels)
	classes_dim = mtf.Dimension(name='classesnum', size=classesnum)
	x = mtf.import_tf_tensor(
		mesh, tf.reshape(image, [FLAGS.batch_size, image_height, image_width, num_channels]),
		mtf.Shape(
			[batch_dim, rows_dim, cols_dim, channel_dim]))
	# x = mtf.transpose(x, [batch_dim, rows_dim, cols_dim, channel_dim])
	# print(x.shape)
	logits = VGG(x, classes_dim=classes_dim, depth=depth)
	logits = mtf.cast(logits, dtype=tf.float32)

	if labels is None:
		loss = None
	else:
		labels = mtf.import_tf_tensor(
			mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
		loss = mtf.layers.softmax_cross_entropy_with_logits(
			logits, mtf.one_hot(labels, classes_dim), classes_dim)
		loss = mtf.reduce_mean(loss)
	return logits, loss
Example #2
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: a tf.Tensor with shape [batch, 28, 28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """

    # image is a tf.Tensor with shape [batch, 28, 28] and dtype tf.float32
    # labels is a tf.Tensor with shape [batch] and dtype tf.int32
    batch_dim = mtf.Dimension("batch", 100)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    hidden_dim = mtf.Dimension("hidden", 1024)
    classes_dim = mtf.Dimension("classes", 10)
    images = mtf.import_tf_tensor(mesh,
                                  image,
                                  shape=[batch_dim, rows_dim, cols_dim])
    labels = mtf.import_tf_tensor(mesh, labels, [batch_dim])
    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
    w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
    # einsum is a generalization of matrix multiplication (see numpy.einsum)
    hidden = mtf.relu(
        mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
    logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])
    loss = mtf.reduce_mean(
        mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim))

    return logits, loss
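For context, a minimal sketch of how a model function like this is typically lowered and run with Mesh TensorFlow; the mesh shape, layout rules, and device list below are assumptions, not part of the example:

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
logits, loss = mnist_model(tf_images, tf_labels, mesh)

mesh_shape = mtf.convert_to_shape("all:2")               # 2-way data parallelism (assumed)
layout_rules = mtf.convert_to_layout_rules("batch:all")  # split the batch dim across "all"
devices = ["gpu:0", "gpu:1"]                             # hypothetical device list
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mesh_shape, layout_rules, devices)
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)             # back to an ordinary tf.Tensor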
Example #3
 def call(self, context, x, losses=None):
   """Call the layer."""
   memory_length = self.memory_length(context)
   q = self.compute_q(context, x)
   if context.mode == "incremental":
     m = x
   else:
     m = mtf.replace_dimensions(x, context.length_dim, memory_length)
   k = self.compute_k(context, m)
   v = self.compute_v(context, m)
   if context.mode == "incremental":
     one_hot = mtf.one_hot(
         context.position, memory_length, dtype=context.activation_dtype)
     inv_one_hot = 1.0 - one_hot
     old_k, old_v = context.get_states(2)
     k = old_k * inv_one_hot + k * one_hot
     v = old_v * inv_one_hot + v * one_hot
     memory_position = mtf.range(context.mesh, memory_length, tf.int32)
   else:
     memory_position = self.rename_length_to_memory_length(
         context.position, context)
   if context.mode == "incremental" or context.mode == "first_part":
     context.record_new_states([k, v])
   bias = self.compute_bias(context, memory_position, x,
                            self.softmax_heads_dims, q)
   return self.attention_internal(context, x, m, q, k, v, memory_length, bias)
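The incremental branch above overwrites exactly one slot of the cached keys and values without a scatter op. A minimal NumPy sketch of the same one-hot update, with hypothetical shapes:

import numpy as np

memory_length, d_kv = 4, 8
old_k = np.zeros([memory_length, d_kv])    # cached keys from earlier steps
new_k = np.ones([d_kv])                    # key computed at the current step
position = 2

one_hot = np.eye(memory_length)[position]  # 1.0 only at the current position
inv_one_hot = 1.0 - one_hot
# Write new_k into slot `position`; every other slot keeps its old value.
k = old_k * inv_one_hot[:, None] + new_k * one_hot[:, None]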
Example #4
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image,
                                              [FLAGS.batch_size, 28, 28]),
                             [batch_dim, rows_dim, cols_dim])

    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, classes_dim])
    b1 = mtf.get_variable(mesh, "b1", [classes_dim])

    logits = mtf.relu(mtf.einsum([x, w1], [batch_dim, classes_dim]) + b1)

    if labels is None:
        loss = None
    else:
        y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                                 [batch_dim])
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Example #5
 def body_fn(position, ids, *states):
   """One step in the decode loop."""
   context_incremental = Context(
       mesh=inputs.mesh,
       batch_dims=batch_dims,
       length_dim=length_dim,
       model_dim=self.model_dim,
       variable_dtype=variable_dtype,
       mode="incremental",
       autoregressive=self.autoregressive,
       position=position,
       states=states,
       new_states=[],
       sequence_id=sequence_id,
       encoder_output=encoder_output,
       encoder_sequence_id=encoder_sequence_id,
       constant_states=constant_states,
       shared_params=shared_params,
       layout=self.layout,
       mesh_shape=self.mesh_shape,
       encoder_layer_outputs=encoder_layer_outputs)
   inputs_this_step = mtf.gather(ids, position - 1, length_dim)
   with tf.variable_scope(self.name, reuse=True):
     logits = self._call_internal(context_incremental, inputs_this_step)
   ids_this_step = mtf.sample_with_temperature(
       logits, self.output_vocab_dim, temperature)
   new_position = position + 1
   new_ids = ids + ids_this_step * mtf.one_hot(
       position, length_dim, dtype=tf.int32)
   return [new_position, new_ids] + context_incremental.new_states
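body_fn above computes a single iteration of an autoregressive decode; it is normally driven by mtf.while_loop. A sketch of that wiring, in which cond_fn and the initial loop values are assumptions:

def cond_fn(position, ids, *unused_states):
    """Stop once every sequence has reached the end of the length dim."""
    past_end = mtf.greater_equal(position, length_dim.size)
    return mtf.logical_not(mtf.reduce_all(past_end))

while_loop_inputs = [initial_position, initial_ids] + initial_states
final_position, final_ids = mtf.while_loop(
    cond_fn, body_fn, while_loop_inputs)[:2]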
Example #6
 def get_project_to_cluster_length(self, cluster_mask, dtype):
     """Returns projection from length dim to the shorter cluster length dim."""
     seq_length_dim = cluster_mask.shape.get_dim_by_name("length")
     cluster_length_dim = self.get_cluster_length_dim(seq_length_dim)
     return mtf.cast(cluster_mask, dtype) * mtf.one_hot(
         mtf.cumsum(mtf.cast(cluster_mask, tf.int32), seq_length_dim) - 1,
         output_dim=cluster_length_dim,
         dtype=dtype)
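The returned tensor acts as a gather matrix: contracting the original length dim against it packs the positions selected by cluster_mask into the shorter cluster length dim. A hypothetical use, where layer and x are stand-ins:

projection = layer.get_project_to_cluster_length(cluster_mask, dtype=x.dtype)
clustered_x = mtf.einsum(
    [x, projection], reduced_dims=[x.shape.get_dim_by_name("length")])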
Example #7
  def get_log_softmax_prefix(self, log_softmax, end_index):
    """Returns first end_index entries in log_softmax along the vocab dim."""
    prefix_dim = mtf.Dimension(self._vocab_dim.name, end_index)

    indices = mtf.mtf_range(
        log_softmax.mesh, dim=self._vocab_dim, dtype=tf.int32)
    prefix_indices = mtf.where(mtf.less(indices, end_index), indices, -1)
    projection = mtf.one_hot(
        prefix_indices, prefix_dim, dtype=log_softmax.dtype)

    return mtf.einsum([log_softmax, projection], reduced_dims=[self._vocab_dim])
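This works because tf.one_hot, which mtf.one_hot is lowered to, maps an out-of-range index such as the -1 used above to an all-zero row, so every position at or beyond end_index drops out of the einsum. A quick check in plain TensorFlow (eager mode):

import tensorflow as tf

print(tf.one_hot([0, 1, -1], depth=2))
# [[1. 0.]
#  [0. 1.]
#  [0. 0.]]  <- the out-of-range index yields an all-zero row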
Example #8
 def call(self, context, x, losses=None):
   """Call the layer."""
   params = self.make_params(context)
   q = params.compute_q(x)
   memory_length = self.memory_length(context)
   if context.mode == "incremental":
     m = x
   else:
     m = mtf.replace_dimensions(x, context.length_dim, memory_length)
   if self.shared_kv:
     kv = params.compute_kv(m)
   else:
     k = params.compute_k(m)
     v = params.compute_v(m)
   if context.mode == "incremental":
     one_hot = mtf.one_hot(
         context.position, memory_length, dtype=context.activation_dtype)
     inv_one_hot = 1.0 - one_hot
     if self.shared_kv:
       old_kv = context.get_states(1)
       kv = old_kv * inv_one_hot + kv * one_hot
     else:
       old_k, old_v = context.get_states(2)
       k = old_k * inv_one_hot + k * one_hot
       v = old_v * inv_one_hot + v * one_hot
     memory_position = mtf.range(context.mesh, memory_length, tf.int32)
   else:
     memory_position = self.rename_length_to_memory_length(
         context.position, context)
   if context.mode == "incremental" or context.mode == "first_part":
     context.record_new_states([kv] if self.shared_kv else [k, v])
   if self.shared_kv:
     k = kv
     v = kv
   if self.attention_func == "hybrid":
     o = attention.hybrid_attention(
         q, k, v, context,
         memory_length,
         self.kv_dim,
         self.kv_dim,
         self.compute_bias(
             context, memory_position, x, params.query_heads_dims),
         **self.attention_kwargs_from_context(context))
   else:
     o = attention.attention(
         q, k, v,
         memory_length,
         self.kv_dim,
         self.kv_dim,
         self.compute_bias(
             context, memory_position, x, params.query_heads_dims),
         **self.attention_kwargs_from_context(context))
   return params.compute_output(o, output_shape=x.shape)
Example #9
def mnist_model(image, labels, mesh, hs_t):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh
    hs_t: a mtf.Tensor with shape [batch, hidden_1]
  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
    hs_t: an updated mtf.Tensor
  """
    input_num = 28
    timesteps_num = 28
    classes_num = 10

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    input_dim = mtf.Dimension("input", input_num)
    timesteps_dim = mtf.Dimension("timesteps", timesteps_num)
    classes_dim = mtf.Dimension("classes", classes_num)
    hidden_dim_1 = mtf.Dimension("hidden_1", FLAGS.hidden_size)
    hidden_dim_2 = mtf.Dimension("hidden_2", FLAGS.hidden_size)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image,
                                              [FLAGS.batch_size, 28, 28]),
                             [batch_dim, timesteps_dim, input_dim])
    hs_t = mtf.import_tf_tensor(mesh, hs_t, [batch_dim, hidden_dim_1])

    Wxh = mtf.get_variable(mesh, "Wxh", [input_dim, hidden_dim_2])
    Whh = mtf.get_variable(mesh, "Whh", [hidden_dim_1, hidden_dim_2])
    Why = mtf.get_variable(mesh, "Why", [hidden_dim_2, classes_dim])
    bh = mtf.get_variable(mesh, "bh", [hidden_dim_2])
    by = mtf.get_variable(mesh, "by", [classes_dim])

    x_list = mtf.unstack(x, timesteps_dim)

    for xs_t in x_list:
        hs_t = mtf.tanh(
            mtf.einsum([xs_t, Wxh], [batch_dim, hidden_dim_2]) +
            mtf.einsum([hs_t, Whh], [batch_dim, hidden_dim_2]) + bh)
    logits = mtf.einsum([hs_t, Why], [batch_dim, classes_dim]) + by

    if labels is None:
        loss = None
    else:
        y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                                 [batch_dim])
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss, hs_t
Example #10
    def call(self, context, x, losses=None):
        """Call the layer."""
        wq, wk, wv, wo = mtf.layers.multihead_attention_params(
            context.mesh, self.heads_dim, context.model_dim, self.kv_dim,
            context.variable_dtype)
        memory_length = mtf.Dimension("memory_length", context.length_dim.size)
        q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            m = x
        else:
            m = mtf.rename_dimension(x, context.length_dim.name,
                                     "memory_length")
        k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
        v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            old_k, old_v = context.get_states(2)
            one_hot = mtf.one_hot(context.position,
                                  memory_length,
                                  dtype=context.activation_dtype)
            inv_one_hot = 1.0 - one_hot
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot
        if context.mode == "incremental" or context.mode == "first_part":
            context.record_new_states([k, v])
        masks = []
        if context.autoregressive:
            masks.append(
                mtf.cast(
                    mtf.less(
                        context.position,
                        mtf.range(context.mesh, memory_length,
                                  dtype=tf.int32)), context.activation_dtype) *
                -1e9)
        if (context.sequence_id is not None
                and isinstance(context.sequence_id, mtf.Tensor)
                and context.length_dim in context.sequence_id.shape):
            masks.append(
                mtf.cast(
                    mtf.not_equal(
                        context.sequence_id,
                        mtf.layers.rename_length_to_memory_length(
                            context.sequence_id)), context.activation_dtype) *
                -1e9)
        mask = mtf.add_n(masks) if masks else None

        o = mtf.layers.dot_product_attention_v2(
            q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
            self.dropout_rate if context.train else 0.0, [context.length_dim])
        return mtf.einsum([o, wo],
                          x.shape,
                          reduced_dims=[self.heads_dim, self.kv_dim])
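A NumPy analogy of the autoregressive mask built above, with hypothetical sizes: a query at position p gets a large negative bias at every memory position after p, so attention to those positions is effectively zeroed out by the softmax.

import numpy as np

position = 1
memory_positions = np.arange(4)
mask = (position < memory_positions).astype(np.float32) * -1e9
# positions 0 and 1 stay at zero; positions 2 and 3 get -1e9 and are blocked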
Example #11
    def compute_loss(self, decoder, hidden, targets, context):
        """Computes the loss during training."""
        logits = self._embedding.hidden_to_logits(hidden, context=context)
        soft_targets = mtf.one_hot(targets - self._start_token_id,
                                   self._vocab_dim,
                                   dtype=context.activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self._vocab_dim, z_loss=decoder.z_loss)

        padding_mask = mtf.layers.weights_nonzero(
            targets, dtype=context.activation_dtype)

        return (mtf.reduce_sum(loss * padding_mask) /
                decoder.loss_denominator(targets, context.num_microbatches))
Example #12
 def call(self, context, x, losses=None):
   """Call the layer."""
   memory_length = self.memory_length(context)
   q = self.compute_q(context, x)
   if context.mode == "incremental":
     m = x
   else:
     m = mtf.replace_dimensions(x, context.length_dim, memory_length)
   if context.mode == "incremental":
     one_hot = mtf.one_hot(
         context.position, memory_length, dtype=context.activation_dtype)
     inv_one_hot = 1.0 - one_hot
     old_m, = context.get_states(1)
     m = old_m * inv_one_hot + one_hot * m
     memory_position = mtf.range(context.mesh, memory_length, tf.int32)
   else:
     memory_position = self.rename_length_to_memory_length(
         context.position, context)
   if context.mode == "incremental" or context.mode == "first_part":
     context.record_new_states([m])
   bias = self.compute_bias(context, memory_position, x, self.heads_dims, q)
   return self.attention_internal(context, q, m, memory_length, bias)
Example #13
File: utils.py Project: zxhjiutian/gpt-neo
def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
    if targets.dtype.is_integer:
        # hard targets
        if set(targets.shape.dims) != set(logits.shape.dims).difference([vocab_dim]):
            raise ValueError(
                "entmax_cross_entropy_with_logits with hard targets "
                "dims in targets=%s should be dims in logits=%s other than "
                "vocab_dim=%s" % (targets, logits, vocab_dim))
        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
    elif set(targets.shape.dims) != set(logits.shape.dims):
        raise ValueError(
            "entmax_cross_entropy_with_logits with soft targets "
            "dims in targets=%s should be dims in logits=%s" % (targets, logits))

    if vocab_dim not in logits.shape.dims:
        raise ValueError("vocab_dim must be in logits.shape.dims")

    log_entmax = mtf.log(entmax(logits, dim=vocab_dim))

    loss = mtf.negative(
        mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))

    return loss
Example #14
def model_backbone(image, labels, mesh):
	"""The model.
	Args:
		image: tf.Tensor with shape [batch, 32*32]
		labels: a tf.Tensor with shape [batch] and dtype tf.int32
		mesh: a mtf.Mesh
	Returns:
		logits: a mtf.Tensor with shape [batch, 10]
		loss: a mtf.Tensor with shape []
	"""
	batch_dim = mtf.Dimension("batch", args_opt.batch_size)
	rows_dim = mtf.Dimension("rows_size", 224)
	cols_dim = mtf.Dimension("cols_size", 224)
	channel_dim = mtf.Dimension("image_channel", 3)
	classes_dim = mtf.Dimension(name='classesnum', size=args_opt.class_num)
	x = mtf.import_tf_tensor(
		mesh, tf.reshape(image, [args_opt.batch_size, 224, 224, 3]),
		mtf.Shape(
			[batch_dim, rows_dim, cols_dim, channel_dim]))
	if args_opt.fp16:
		float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
	else:
		float16 = None

	logits = network[args_opt.model](
		x, classes_dim=classes_dim, float16=float16,
		batch_norm=False if 'vgg' in args_opt.model else True)
	logits = mtf.cast(logits, dtype=tf.float32)

	if labels is None:
		loss = None
	else:
		labels = mtf.import_tf_tensor(
			mesh, tf.reshape(labels, [args_opt.batch_size]), mtf.Shape([batch_dim]))
		loss = mtf.layers.softmax_cross_entropy_with_logits(
			logits, mtf.one_hot(labels, classes_dim), classes_dim)
		loss = mtf.reduce_mean(loss)
	return logits, loss
Example #15
 def compute_loss(logits, positions):
   one_hot_positions = mtf.one_hot(positions, output_dim=seq_dim)
   log_probs = mtf.log_softmax(logits, seq_dim)
   loss = -mtf.reduce_mean(
       mtf.reduce_sum(one_hot_positions * log_probs, reduced_dim=seq_dim))
   return loss
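A NumPy sanity check of the pattern above: summing one-hot times log-softmax picks out the log probability of the target position, i.e. the standard negative log-likelihood.

import numpy as np

logits = np.array([1.0, 2.0, 3.0])
log_probs = logits - np.log(np.sum(np.exp(logits)))  # log-softmax
target = 2
one_hot = np.eye(3)[target]
assert np.isclose(-np.sum(one_hot * log_probs), -log_probs[target])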
Example #16
def _top_2_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
                  expert_capacity_dim,
                  hparams,
                  train,
                  variable_dtype,
                  importance=None,
                  name="top_2_gating"):
    """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float

  The returned forward assignment is a tensor used to map (via einsum) from the
  inputs to the expert_inputs.  Likewise, the returned combine_tensor is
  used to map (via einsum) from the expert outputs to the outputs.  Both the
  forward and backward assignments are mostly zeros.  The shapes of the tensors
  are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to
  up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
  one expert.  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get importance
  1.0.  Inputs to the second-choice expert group get importance 0.5.
  Inputs that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]
    name: an optional string

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
    group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

    raw_gates = mtf.layers.dense(inputs,
                                 experts_dim,
                                 use_bias=False,
                                 expert_dims=outer_expert_dims,
                                 variable_dtype=variable_dtype,
                                 name=name)
    raw_gates = mtf.softmax(raw_gates, experts_dim)

    # The internals of this function run in float32.
    #   bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    density_1_proxy = raw_gates
    if importance is not None:
        mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
        mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                      reduced_dim=group_size_dim)
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized,
                                          reduced_dim=group_size_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    if train:
        policy = hparams.moe_second_policy_train
        threshold = hparams.moe_second_threshold_train
    else:
        policy = hparams.moe_second_policy_eval
        threshold = hparams.moe_second_threshold_eval
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_size_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # [batch, group, experts, expert_capacity]
    combine_tensor = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)

    dispatch_tensor = mtf.cast(mtf.cast(combine_tensor, tf.bool),
                               combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
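A sketch of how the returned tensors are used, following the shapes in the docstring; batch_dims, the per-expert computation, and the dim variables here are assumptions:

# Dispatch: route each position to its expert slot(s).
expert_inputs = mtf.einsum(
    [inputs, dispatch_tensor],
    output_shape=batch_dims + [experts_dim, expert_capacity_dim, input_dim])
# ... run each expert on its slice of expert_inputs to get expert_outputs ...
# Combine: weight each expert output by its gate and map it back to positions.
outputs = mtf.einsum(
    [expert_outputs, combine_tensor],
    output_shape=batch_dims + [group_size_dim, output_dim])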
Example #17
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # add some convolutional layers to demonstrate that convolution works.
    fh_dim = mtf.Dimension("fh", 9)
    fw_dim = mtf.Dimension("fw", 9)
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    f1 = mtf.relu(
        mtf.conv2d_with_blocks(x,
                               kernel1,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))
    f2 = mtf.relu(
        mtf.conv2d_with_blocks(f1,
                               kernel2,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    #hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu,
                          name="hidden1")
    #h2 = mtf.layers.dense(
    #    h1, hidden_dim2,
    #    activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Example #18
def Alexnet(img, labels, num_nodes, num_gpus, args):
    num_classes = 1000
    keep_prob = 0.5
    learning_rate = 0.01
    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        img, labels, num_nodes, num_gpus, args)
    RenameFC = lambda x: mt.rename_dimension(x, x.shape[-1].name,
                                             utils.RandName())

    strategy = args.strategy
    if strategy == 0:
        fc6_units = mtf.Dimension(utils.RandName(), 4096)
        fc7_units = mtf.Dimension(utils.RandName(), 4096)
        fc8_units = mtf.Dimension(utils.RandName(), num_classes)

    elif strategy == 1:
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    elif strategy == 2:
        num_classes = utils.RoundUp(num_classes, num_gpus)
        fc6_units = mtf.Dimension('axis0', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis0', num_classes)

    elif strategy == 3:
        num_classes = utils.RoundUp(num_classes, num_gpus // 2)
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis1', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    with tf.variable_scope('alexnet'):
        # Conv1 + ReLU + maxpool1
        conv1 = mt.Conv2d(mtf_img,
                          GetFilterShape(mtf_img, (11, 11, 3, 96)), (4, 4),
                          'VALID',
                          activation=mtf.relu,
                          name='conv1')
        pool1 = mt.MaxPool(conv1, (3, 3), (2, 2), 'VALID', name='pool1')

        # Conv2 + ReLU + maxpool2
        conv2 = mt.Conv2d(pool1,
                          GetFilterShape(pool1, (5, 5, 96, 256)), (1, 1),
                          'SAME',
                          activation=mtf.relu,
                          name='conv2')
        pool2 = mt.MaxPool(conv2, (3, 3), (2, 2), name='pool2')

        # Conv3 + ReLU
        conv3 = mt.Conv2d(pool2,
                          GetFilterShape(pool2, (3, 3, 256, 384)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv3')

        # Conv4 + ReLU
        conv4 = mt.Conv2d(conv3,
                          GetFilterShape(conv3, (3, 3, 384, 384)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv4')

        # Conv5 + ReLU + maxpool5
        conv5 = mt.Conv2d(conv4,
                          GetFilterShape(conv4, (3, 3, 384, 256)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv5')
        pool5 = mt.MaxPool(conv5, (3, 3), (2, 2), name='pool5')

        # Rename dims
        if strategy == 1:
            k_dim = mtf.Dimension(utils.RandName(),
                                  utils.Prod(pool5.shape.to_integer_list[1:]))
            pool5 = mtf.reshape(pool5, mtf.Shape([pool5.shape[0], k_dim]))
            pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1],
                                                   (utils.RandName(), 'axis0'))

        elif strategy == 2:
            pool5 = mt.rename_dimension(pool5, pool5.shape[0].name,
                                        utils.RandName())

        elif strategy == 3:
            assert pool5.shape[0].name == 'axis0'
            #dim_names = pool5.shape.rename_dimension('axis0', utils.RandName())
            #pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1], dim_names)
            pool5 = ReplaceMeshWithConcatSplit(pool5, meshes[1])

        # FC + ReLU + dropout
        fc_activation = lambda x: mtf.dropout(mtf.relu(x), keep_prob)
        fc6 = mtf.layers.dense(pool5,
                               fc6_units,
                               activation=fc_activation,
                               reduced_dims=pool5.shape[1:],
                               name='fc6')
        if strategy in (2, 3):
            fc6 = RenameFC(fc6)

        fc7 = mtf.layers.dense(fc6,
                               fc7_units,
                               activation=fc_activation,
                               reduced_dims=fc6.shape.dims[-1:],
                               name='fc7')
        if strategy in (2, 3):
            fc7 = RenameFC(fc7)

        fc8 = mtf.layers.dense(fc7,
                               fc8_units,
                               reduced_dims=fc7.shape.dims[-1:],
                               name='fc8')
        fc8 = mtf.dropout(fc8, keep_prob)

        if strategy == 1:
            assert fc8.shape[-1].name == 'axis1'
            fc8 = ReplaceMeshWithDuplicates(fc8, meshes[2])

    with tf.variable_scope('loss'):
        if fc8.shape[0] != mtf_labels.shape[0]:
            fc8 = mt.rename_dimension(fc8, fc8.shape[0].name,
                                      mtf_labels.shape[0].name)
        one_hot_labels = mtf.one_hot(mtf_labels, fc8.shape[-1])
        mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            fc8, one_hot_labels, fc8.shape[-1])
        mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    return graph, mesh_to_impl, mtf_loss
Example #19
    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        nonlocal sampling_keep_top_k

        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        # By default (-2), keep only the top 10% of the vocab
        if sampling_keep_top_k == -2:
            sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits,
                n=sampling_keep_top_k,
                reduced_dim=other_features["vocab_dim"])
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, other_features["vocab_dim"], temperature)

        if slow_sampling:
            ids_this_step = mtf.shift(ids_this_step,
                                      offset=1,
                                      dim=length_dim,
                                      wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, batch_dims)

        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret
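A NumPy analogy of the thresholding step above: every logit at or below the reference value is pushed down to -1e6, so sampling effectively ignores it.

import numpy as np

logits = np.array([2.0, 5.0, 3.0, 1.0])
k_largest = 2.0  # stand-in for the value returned by mtf.nth_largest_element
filtered = np.where(logits <= k_largest, -1e6, logits)
# filtered == [-1e6, 5.0, 3.0, -1e6]; only logits above the threshold survive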
Example #20
    def _mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        hparams = self._hparams
        targets = tf.to_int32(features["targets"])
        if len(targets.get_shape()) > 2:
            tf.logging.info("targets = %s" % targets)
            targets = tf.squeeze(targets, [2, 3])
        # pad targets to max_length
        def pad_to_max_length(x):
            extra_length = hparams.max_length - tf.shape(x)[1]
            x = tf.pad(x, [[0, 0], [0, extra_length]])
            x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
            return x

        targets = pad_to_max_length(targets)
        for key in [
                "targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"
        ]:
            if key in features:
                features[key] = pad_to_max_length(features[key])
        shifted_targets = common_layers.shift_right_2d(targets)

        targets = self._import_to_batch_by_length(targets, "targets", mesh,
                                                  hparams)
        shifted_targets = self._import_to_batch_by_length(
            shifted_targets, "shifted_targets", mesh, hparams)

        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = self._import_to_batch_by_length(
                features["targets_segmentation"], "targets_segmentation", mesh,
                hparams)
            targets_position = self._import_to_batch_by_length(
                features["targets_position"], "targets_position", mesh,
                hparams)
            decoder_self_attention_mask = (
                mtf.layers.attention_mask_autoregressive(
                    targets_position, dtype=self.activation_dtype) +
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, dtype=self.activation_dtype))
        else:
            targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

        extra_losses = []
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "decoder":
            encoder_output = None
            encoder_decoder_attention_mask = None
        else:
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = pad_to_max_length(inputs)
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            if "inputs_segmentation" in features:
                # "Packed" dataset - keep the examples from seeing each other.
                inputs_segmentation = self._import_to_batch_by_length(
                    features["inputs_segmentation"], "inputs_segmentation",
                    mesh, hparams)
                inputs_position = self._import_to_batch_by_length(
                    features["inputs_position"], "inputs_position", mesh,
                    hparams)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        inputs_segmentation, dtype=self.activation_dtype))
            else:
                inputs_position = mtf.range(mesh,
                                            self.length_dim,
                                            dtype=tf.int32)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_ignore_padding(
                        inputs, dtype=self.activation_dtype))

            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.gather(positional_embedding_var, inputs_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_self_attention_mask,
                    losses=extra_losses)

        if hparams.transformer_type == "encdec":
            if "inputs_segmentation" in features:
                encoder_decoder_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        targets_segmentation,
                        inputs_segmentation,
                        dtype=self.activation_dtype))
            else:
                encoder_decoder_attention_mask = encoder_self_attention_mask
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)

        if hparams.transformer_type != "encoder":
            # DECODER
            x = (mtf.gather(targets_embedding_var, shifted_targets,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, targets_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("decoder"):
                x = self._layer_stack(
                    x,
                    hparams.decoder_layers,
                    encoder_output=encoder_output,
                    self_attention_mask=decoder_self_attention_mask,
                    encdec_attention_mask=encoder_decoder_attention_mask,
                    losses=extra_losses)
        logits = mtf.matmul(x, softmax_var)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
        off_value = hparams.label_smoothing / self._targets_vocab_size
        on_value = 1.0 - hparams.label_smoothing + off_value
        soft_targets = mtf.one_hot(targets,
                                   self.targets_vocab_dim,
                                   on_value=on_value,
                                   off_value=off_value,
                                   dtype=self.activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.targets_vocab_dim)
        weights = mtf.layers.weights_nonzero(targets,
                                             dtype=self.activation_dtype)
        loss = mtf.reduce_mean(loss * weights)
        for l in extra_losses:
            loss += l
        logits = mtf.to_float(logits)
        # combine batch dims
        if len(self.batch_dims) > 1:
            combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                               mtf.Shape(self.batch_dims).size)
            logits = mtf.reshape(logits,
                                 [combined_batch_dim] + logits.shape.dims[-2:])
        return logits, loss
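The label-smoothing arithmetic above in plain numbers, with hypothetical values: for a target vocab of 10 and label_smoothing = 0.1, each wrong class receives a small slice of probability mass and the correct class keeps the rest.

label_smoothing, vocab_size = 0.1, 10
off_value = label_smoothing / vocab_size      # 0.01 on each wrong class
on_value = 1.0 - label_smoothing + off_value  # 0.91 on the correct class
assert abs(on_value + (vocab_size - 1) * off_value - 1.0) < 1e-9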
Example #21
File: gpt2.py Project: doinker/GPTNeo
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None):
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":

                # TODO: pass in fake context
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])
                else:
                    broadcasted_bias = None  # no mask bias supplied

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
Example #22
  def _call_internal(self, context, inputs, targets=None):
    """Compute logits based on inputs (all positions in parallel).

    Also updates context if applicable.

    Args:
      context: a Context
      inputs: a Tensor
      targets: an optional Tensor

    Returns:
      logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim]
    """
    mesh = inputs.mesh
    if "embedding" in context.shared_params:
      embedding_weights = context.shared_params["embedding"]
    else:
      embedding_weights = mtf.layers.embedding_weights(
          mesh, self.input_vocab_dim, self.model_dim, context.variable_dtype,
          name="embedding")
    x = mtf.gather(embedding_weights, inputs, self.input_vocab_dim)
    if "positional_embedding" in context.shared_params:
      pos_emb_var = context.shared_params["positional_embedding"]
    else:
      pos_emb_var = mtf.layers.embedding_weights(
          mesh, self.max_length_dim, self.model_dim, context.variable_dtype,
          "positional_embedding")
    if context.position_is_default:
      pos_emb = mtf.rename_dimension(
          mtf.slice(pos_emb_var, 0, context.length_dim.size,
                    self.max_length_dim.name),
          self.max_length_dim.name, context.length_dim.name)
    else:
      pos_emb = mtf.gather(
          pos_emb_var, context.position, self.max_length_dim,
          output_shape=x.shape)
    x += pos_emb
    x = self.layer_stack.call(context, x)
    if self.output_vocab_dim is None:
      return x
    if self.shared_embedding_and_softmax_weights:
      logits = mtf.einsum(
          [x * (self.model_dim.size ** -0.5), embedding_weights],
          reduced_dims=[self.model_dim])
    else:
      logits = mtf.layers.dense(
          x, self.output_vocab_dim, use_bias=False,
          variable_dtype=context.variable_dtype,
          name="logits")
    if targets is not None and context.losses is not None:
      off_value = self.label_smoothing / self.output_vocab_dim.size
      on_value = 1.0 - self.label_smoothing + off_value
      soft_targets = mtf.one_hot(
          targets, self.output_vocab_dim,
          dtype=context.activation_dtype,
          on_value=on_value,
          off_value=off_value)
      loss = mtf.layers.softmax_cross_entropy_with_logits(
          logits, soft_targets, self.output_vocab_dim,
          z_loss=self.z_loss if context.train else 0.0)
      weights = mtf.layers.weights_nonzero(
          targets, dtype=context.activation_dtype)
      loss = mtf.reduce_mean(loss * weights)
      context.losses.append(loss)
    return logits
Example #23
def Inception(img, labels, num_nodes, num_gpus, args):
    num_classes = 1000
    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        args, img, labels, num_nodes, num_gpus)

    strategy = args.strategy
    with tf.variable_scope('inception'):
        conv1a = BasicConv(mtf_img, (3, 3, 3, 32), stride=2, name='conv1a')
        conv2a = BasicConv(conv1a, (3, 3, 32, 32), name='conv2a')
        conv2b = BasicConv(conv2a, (3, 3, 32, 64),
                           padding='SAME',
                           name='conv2b')
        pool = MaxPool(conv2b, (3, 3), stride=2, name='pool1')
        conv3b = BasicConv(pool, (1, 1, 64, 80), name='conv3b')
        conv4a = BasicConv(conv3b, (3, 3, 80, 192), name='conv4a')
        pool = MaxPool(conv4a, (3, 3), stride=2, name='pool2')

        mixed5b = InceptionA(pool, 192, 32, name='mixed5b')
        mixed5c = InceptionA(mixed5b, 256, 64, name='mixed5c')
        mixed5d = InceptionA(mixed5c, 288, 64, name='mixed5d')

        mixed6a = InceptionB(mixed5d, 288, name='mixed6a')

        mixed6b = InceptionC(mixed6a, 768, 128, name='mixed6b')
        mixed6c = InceptionC(mixed6b, 768, 160, name='mixed6c')
        mixed6d = InceptionC(mixed6c, 768, 160, name='mixed6d')
        mixed6e = InceptionC(mixed6d, 768, 192, name='mixed6e')

        mixed7a = InceptionD(mixed6e, 768, name='mixed7a')

        mixed7b = InceptionE(mixed7a, 1280, strategy, meshes, name='mixed7b')
        mixed7c = InceptionE(mixed7b, 2048, strategy, name='mixed7c')

        mean = mtf.reduce_mean(mixed7c,
                               output_shape=mtf.Shape(
                                   [mixed7c.shape[0], mixed7c.shape[-1]]))

        assert mean.shape[0].name == 'axis0' \
                and not mean.shape[1].name.startswith('axis')
        if strategy == 1:
            shape = mean.shape
            shape = shape.rename_dimension(shape[0].name,
                                           mtf_labels.shape[0].name)
            shape = shape.rename_dimension(shape[1].name, 'axis0')
            with tf.variable_scope('reshape_mean'):
                mean = mtf.reshape(mean, shape)
            dim_name = 'axis1'

        elif strategy == 2:
            num_classes = utils.RoundUp(num_classes, num_gpus)
            mean = mt.rename_dimension(mean, 'axis0', mtf_labels.shape[0].name)
            dim_name = 'axis0'

        elif strategy == 3:
            num_classes = utils.RoundUp(num_classes, num_gpus)
            assert mean.shape[0].name == 'axis0'
            dim_names = mean.shape.rename_dimension('axis0',
                                                    mtf_labels.shape[0].name)
            mean = ReplaceMeshWithConcatSplit(mean, meshes[1], dim_names)
            dim_name = 'axis1'

        else:
            dim_name = utils.RandName()
        fc = mtf.layers.dense(mean,
                              mtf.Dimension(dim_name, num_classes),
                              reduced_dims=mean.shape[-1:])

        with tf.variable_scope('loss'):
            assert mtf_labels.mesh == fc.mesh
            assert mtf_labels.shape[0] == fc.shape[0]
            one_hot_labels = mtf.one_hot(mtf_labels, fc.shape[-1])
            cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
                fc, one_hot_labels, fc.shape[-1])
            loss = mtf.reduce_mean(cross_ent)

        return graph, mesh_to_impl, loss
Example #24
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.activation_type

        # We assume fixed vocab size for targets
        targets = tf.to_int32(features["targets"])

        # Image preprocessing, reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)

        def import_to_batch_by_length(x, name):
            return mtf.import_tf_tensor(mesh,
                                        x,
                                        mtf.Shape([batch_dim,
                                                   self.length_dim]),
                                        name=name)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape([batch_dim, self.model_dim]))

        targets = import_to_batch_by_length(targets, "targets")
        shifted_targets = import_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        extra_losses = []

        # Create targets content and position embeddings.
        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([self.targets_vocab_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        x = mtf.gather(targets_embedding_var, shifted_targets,
                       self.targets_vocab_dim)
        # Add positional embeddings
        x += mtf.reshape(self.create_positional_emb_2d(targets),
                         [self.length_dim, self.model_dim])

        # If conditional and input is given, add the input embedding to the target.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = import_to_batch_by_length(inputs, "inputs")

            # Input embeddings
            inputs_embedding_var = mtf.layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
                activation_dtype=activation_dtype)
            inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                                    self.inputs_vocab_dim)
            x += inputs_emb

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_decoder_layers):
            layer_name = "decoder_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Self attention layer
                x += layer_prepostprocess_dropout(
                    mtf.layers.masked_local_attention_1d(
                        mtf.layers.layer_norm(x,
                                              self.model_dim,
                                              name="layer_norm_att"),
                        None,
                        self.kv_dim,
                        self.heads_dim,
                        block_length=hparams.block_length,
                        name="self_att"))
                # ffn layer
                x += layer_prepostprocess_dropout(
                    mtf.layers.dense_relu_dense(
                        mtf.layers.layer_norm(x,
                                              self.model_dim,
                                              name="layer_norm_ffn"),
                        self.feedforward_dim,
                        hparams.dropout,
                        dropout_broadcast_dims=[self.length_dim]))

        x = mtf.layers.layer_norm(x, self.model_dim, name="final_layer_norm")

        # Calculate the logits and loss.
        logits = mtf.layers.dense(x, self.outputs_vocab_dim, name="logits")
        soft_targets = mtf.one_hot(targets,
                                   self.outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.outputs_vocab_dim)
        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l

        # Reshape logits to original target shape.
        logits = mtf.reshape(
            logits,
            mtf.Shape([
                batch_dim, self.rows_dim, self.orig_cols_dim,
                self.channels_dim, self.outputs_vocab_dim
            ]))

        return logits, loss
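The loss in this example pairs mtf.one_hot targets with mtf.layers.softmax_cross_entropy_with_logits. As a plain-numpy illustration (not mtf code) of the identity being computed:

import numpy as np

# With a one-hot target, softmax cross-entropy reduces to
# -log(softmax(logits)[target]); the values here are illustrative only.
logits = np.array([2.0, 0.5, -1.0])
target = 0
probs = np.exp(logits - logits.max())
probs /= probs.sum()
one_hot = np.eye(len(logits))[target]
loss = -(one_hot * np.log(probs)).sum()
assert np.isclose(loss, -np.log(probs[target]))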
Example #25
  def mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.activation_type

    # We assume fixed vocab size for targets
    targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    length = hparams.img_len*hparams.img_len*hparams.num_channels
    targets = tf.reshape(targets, [hparams.batch_size, length])
    shifted_targets = common_layers.shift_right_2d(targets)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)

    def import_to_batch_by_length(x, name):
      return mtf.import_tf_tensor(
          mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name)

    targets = import_to_batch_by_length(targets, "targets")
    shifted_targets = import_to_batch_by_length(
        shifted_targets, "shifted_targets")

    extra_losses = []

    # Create targets content and position embeddings.
    # Create embedding var for targets and positions and do a gather.
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)

    x = mtf.gather(targets_embedding_var,
                   shifted_targets, self.targets_vocab_dim)

    # Add positional embeddings
    x += mtf.reshape(self.create_positional_emb_2d(targets),
                     [self.length_dim, self.model_dim])

    # If conditional and input is given, add the input embedding to the target.
    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = import_to_batch_by_length(inputs, "inputs")

      # Input embeddings
      inputs_embedding_var = mtf.layers.embedding(
          mesh, "input_embedding",
          mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
          activation_dtype=activation_dtype)
      inputs_emb = mtf.gather(
          inputs_embedding_var, inputs, self.inputs_vocab_dim)
      x += inputs_emb

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    if hparams.attention_type == "local1d_spatial":
      decoder_output = local_attention1d_spatial_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local2d_spatial":
      decoder_output = local_attention2d_spatial_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local1d":
      decoder_output = local_attention1d_masked_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    else:
      raise ValueError("Invalid attention type.")

    # Calculate the logits and loss.
    logits = mtf.layers.dense(
        decoder_output, self.outputs_vocab_dim, name="logits")
    # Need a reshape for logits
    logits = mtf.reshape(
        logits, mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.outputs_vocab_dim)
    loss = mtf.reduce_mean(loss)
    for l in extra_losses:
      loss += l

    # Reshape logits to original target shape.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, self.rows_dim, self.orig_cols_dim,
                   self.channels_dim, self.outputs_vocab_dim]))

    return logits, loss
Example #26
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.activation_type

        # We assume fixed vocab size for targets
        targets = tf.to_int32(features["targets"])

        # Image preprocessing, reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)

        def import_to_batch_by_length(x, name):
            return mtf.import_tf_tensor(mesh,
                                        x,
                                        mtf.Shape([batch_dim,
                                                   self.length_dim]),
                                        name=name)

        targets = import_to_batch_by_length(targets, "targets")
        shifted_targets = import_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        extra_losses = []

        # Create targets content and position embeddings.
        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([self.targets_vocab_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        x = mtf.gather(targets_embedding_var, shifted_targets,
                       self.targets_vocab_dim)

        # Add positional embeddings
        x += mtf.reshape(self.create_positional_emb_2d(targets),
                         [self.length_dim, self.model_dim])

        # If conditional and input is given, add the input embedding to the target.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = import_to_batch_by_length(inputs, "inputs")

            # Input embeddings
            inputs_embedding_var = mtf.layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
                activation_dtype=activation_dtype)
            inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                                    self.inputs_vocab_dim)
            x += inputs_emb

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        if hparams.attention_type == "local1d_spatial":
            decoder_output = local_attention1d_spatial_decoder(
                x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
        elif hparams.attention_type == "local2d_spatial":
            decoder_output = local_attention2d_spatial_decoder(
                x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
        elif hparams.attention_type == "local1d":
            decoder_output = local_attention1d_masked_decoder(
                x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
        else:
            raise ValueError("Invalid attention type.")

        # Calculate the logits and loss.
        logits = mtf.layers.dense(decoder_output,
                                  self.outputs_vocab_dim,
                                  name="logits")
        # Need a reshape for logits
        logits = mtf.reshape(
            logits,
            mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
        soft_targets = mtf.one_hot(targets,
                                   self.outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.outputs_vocab_dim)
        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l

        # Reshape logits to original target shape.
        logits = mtf.reshape(
            logits,
            mtf.Shape([
                batch_dim, self.rows_dim, self.orig_cols_dim,
                self.channels_dim, self.outputs_vocab_dim
            ]))

        return logits, loss
Example #27
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
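The on_value/off_value pair above implements standard label smoothing; a quick check of the arithmetic with illustrative numbers (not taken from the example):

# Label-smoothing sanity check with placeholder values.
label_smoothing = 0.1
vocab_size = 10
off_value = label_smoothing / vocab_size      # 0.01
on_value = 1.0 - label_smoothing + off_value  # 0.91
# Each smoothed target distribution still sums to 1.
assert abs(on_value + (vocab_size - 1) * off_value - 1.0) < 1e-9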
Example #28
 def call(self, context, x, losses=None):
     """Call the layer."""
     params = self.make_params(context)
     if self.share_qk_rep:
         q, k = params.mdha_shared_qk(x, context)
     else:
         q = params.mdha_q(x, context)
     memory_length = self.memory_length(context)
     if context.mode == "incremental":
         m = x
     else:
         if self.share_qk_rep:
             k = mtf.replace_dimensions(k, context.length_dim,
                                        memory_length)
         m = mtf.replace_dimensions(x, context.length_dim, memory_length)
     if self.shared_kv:
         kv = params.compute_kv(m)
     else:
         if not self.share_qk_rep:
             k = params.mdha_k(m, context)
         v = params.mdha_v(m, context)
     if context.mode == "incremental":
         one_hot = mtf.one_hot(context.position,
                               memory_length,
                               dtype=context.activation_dtype)
         inv_one_hot = 1.0 - one_hot
         if self.shared_kv:
             old_kv = context.get_states(1)
             kv = old_kv * inv_one_hot + kv * one_hot
         else:
             old_k, old_v = context.get_states(2)
             k = old_k * inv_one_hot + k * one_hot
             v = old_v * inv_one_hot + v * one_hot
         memory_position = mtf.range(context.mesh, memory_length, tf.int32)
     else:
         memory_position = self.rename_length_to_memory_length(
             context.position, context)
     if context.mode == "incremental" or context.mode == "first_part":
         context.record_new_states([kv] if self.shared_kv else [k, v])
     if self.shared_kv:
         k = kv
         v = kv
     o = self.attention_fn(q,
                           k,
                           v,
                           context=context,
                           memory_length_dim=memory_length,
                           key_dim=self.kv_dim,
                           value_dim=self.kv_dim,
                           bias=self.compute_bias(context, memory_position,
                                                  x,
                                                  params.query_heads_dims,
                                                  q),
                           **self.attention_kwargs_from_context(context))
     attention_output_shape = self.expected_attention_output_shape(
         x, params)
     attention_output = params.compute_output(
         o, output_shape=attention_output_shape)
     return self.layer_output_from_attention_output(context,
                                                    attention_output,
                                                    losses)
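In incremental mode, the cache update above writes exactly one memory slot through the one-hot mask. A small numpy sketch of that update rule (illustrative shapes and values, not mtf tensors):

import numpy as np

# k = old_k * (1 - one_hot) + new_k * one_hot: only the slot at `position`
# changes; every other cached entry is preserved.
memory_length, position = 4, 2
old_k = np.array([1.0, 2.0, 3.0, 4.0])
new_k = np.full(memory_length, 9.0)
one_hot = np.eye(memory_length)[position]
k = old_k * (1.0 - one_hot) + new_k * one_hot
assert (k == np.array([1.0, 2.0, 9.0, 4.0])).all()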
Example #29
  def mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
    cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", 3)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            hparams.rows_size // hparams.row_blocks,
            hparams.col_blocks,
            hparams.num_channels*hparams.cols_size // hparams.col_blocks,
            hparams.num_channels]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, channels_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    x = mtf.to_float(x)
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x,
        initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block - strided block layer - strided block layer] x n
    for layer in range(hparams.num_layers):
      layer_name = "block_layer_%d" % layer
      with tf.variable_scope(layer_name):
        # Residual block layer
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[0],
            blocks=hparams.layer_sizes[0],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer1",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[1],
            blocks=hparams.layer_sizes[1],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer2",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[2],
            blocks=hparams.layer_sizes[2],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer3",
            row_blocks_dim=None,
            col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
Example #30
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28, 1]),
        mtf.Shape([batch_dim, rows_dim, cols_dim, one_channel_dim]))

    fh_dim = mtf.Dimension("fh", 3)
    fw_dim = mtf.Dimension("fw", 3)
    filters1_dim = mtf.Dimension("filters1", FLAGS.num_filters)
    filters2_dim = mtf.Dimension("filters2", FLAGS.num_filters)
    filters3_dim = mtf.Dimension("filters3", FLAGS.num_filters)
    filters4_dim = mtf.Dimension("filters4", FLAGS.num_filters)
    filters5_dim = mtf.Dimension("filters5", FLAGS.num_filters)
    filters6_dim = mtf.Dimension("filters6", FLAGS.num_filters)

    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])
    kernel3 = mtf.get_variable(mesh, "kernel3",
                               [fh_dim, fw_dim, filters2_dim, filters3_dim])
    kernel4 = mtf.get_variable(mesh, "kernel4",
                               [fh_dim, fw_dim, filters3_dim, filters4_dim])
    kernel5 = mtf.get_variable(mesh, "kernel5",
                               [fh_dim, fw_dim, filters4_dim, filters5_dim])
    kernel6 = mtf.get_variable(mesh, "kernel6",
                               [fh_dim, fw_dim, filters5_dim, filters6_dim])

    x = mtf.relu(mtf.conv2d(x, kernel1, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel2, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel3, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel4, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel5, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel6, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.reduce_mean(x, reduced_dim=filters6_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-2:],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf.layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h2, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
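A minimal sketch of the driver code a model function like this expects, assuming a trivial single-device placement mesh (the mesh shape string, empty layout, and device strings are placeholder assumptions; image and labels are assumed tf.Tensors from an input pipeline):

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
logits, loss = mnist_model(image, labels, mesh)

# Single-device mesh; real runs would shard via mesh_shape and layout rules.
mesh_shape = mtf.convert_to_shape("b1:1")
layout_rules = mtf.convert_to_layout_rules("")
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mesh_shape, layout_rules, [""] * mesh_shape.size)

lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_logits = lowering.export_to_tf_tensor(logits)
tf_loss = lowering.export_to_tf_tensor(loss)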
Example #31
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
Example #32
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        filter_h_dim = mtf.Dimension("filter_height", 7)
        filter_w_dim = mtf.Dimension("filter_width", 7)
        filters = mtf.Dimension("filters", hparams.filter_sizes[0])
        rows_dim = mtf.Dimension("rows_size", 32)
        cols_dim = mtf.Dimension("cols_size", 96)
        row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
        col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
        classes_dim = mtf.Dimension("classes", 10)
        one_channel_dim = mtf.Dimension("one_channel", 1)

        inputs = features["inputs"]
        x = mtf.import_tf_tensor(
            mesh,
            tf.reshape(inputs, [
                hparams.batch_size, hparams.row_blocks,
                hparams.rows_size // hparams.row_blocks, hparams.col_blocks,
                hparams.num_channels * hparams.cols_size // hparams.col_blocks,
                1
            ]),
            mtf.Shape([
                batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
                one_channel_dim
            ]))
        x = mtf.transpose(x, [
            batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
            one_channel_dim
        ])

        x = mtf.to_float(x)
        initial_filters = mtf.get_variable(
            mesh, "init_filters",
            mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
        x = mtf.conv2d_with_blocks(x,
                                   initial_filters,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=None,
                                   w_blocks_dim=col_blocks_dim)

        x = batch_norm_relu(x, is_training)

        # Conv blocks
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_layers):
            layer_name = "block_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Residual block layer
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[0],
                                blocks=hparams.layer_sizes[0],
                                strides=[1, 1, 1, 1],
                                is_training=is_training,
                                name="block_layer1",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[1],
                                blocks=hparams.layer_sizes[1],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer2",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[2],
                                blocks=hparams.layer_sizes[2],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer3",
                                row_blocks_dim=None,
                                col_blocks_dim=None)

        # Calculate the logits and loss.
        out = x
        outputs = mtf.layers.dense(out,
                                   hidden_dim,
                                   reduced_dims=out.shape.dims[-5:],
                                   activation=mtf.relu,
                                   name="dense")

        # We assume fixed vocab size for targets
        labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [hparams.batch_size]),
                                      mtf.Shape([batch_dim]))

        logits = mtf.layers.dense(outputs, classes_dim, name="logits")
        soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, classes_dim)

        # Reshape logits so it doesn't break inside t2t.
        logits = mtf.reshape(
            logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
        loss = mtf.reduce_mean(loss)
        return logits, loss
Example #33
def mnist_model(image, labels, mesh):
	"""The model.
	Args:
		image: tf.Tensor with shape [batch, 28*28]
		labels: a tf.Tensor with shape [batch] and dtype tf.int32
		mesh: a mtf.Mesh
	Returns:
		logits: a mtf.Tensor with shape [batch, 10]
		loss: a mtf.Tensor with shape []
	"""
	batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
	row_blocks_dim = mtf.Dimension("row_blocks", 4)
	col_blocks_dim = mtf.Dimension("col_blocks", 4)
	rows_dim = mtf.Dimension("rows_size", 7)
	cols_dim = mtf.Dimension("cols_size", 7)

	classes_dim = mtf.Dimension("classes", 10)
	one_channel_dim = mtf.Dimension("one_channel", 1)

	x = mtf.import_tf_tensor(
		mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
		mtf.Shape(
			[batch_dim, row_blocks_dim, rows_dim,
			col_blocks_dim, cols_dim, one_channel_dim]))
	x = mtf.transpose(x, [
		batch_dim, row_blocks_dim, col_blocks_dim,
		rows_dim, cols_dim, one_channel_dim])
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(x.name,x.shape))
	# add some convolutional layers to demonstrate that convolution works.
	filters1_dim = mtf.Dimension("filters1", 16)
	filters2_dim = mtf.Dimension("filters2", 16)
	f1 = mtf.relu(mtf.layers.conv2d_with_blocks(
		x, filters1_dim, filter_size=[9, 9], strides=[1, 1], padding="SAME",
		h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim, name="conv0"))
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(f1.name,f1.shape))
	f2 = mtf.relu(mtf.layers.conv2d_with_blocks(
		f1, filters2_dim, filter_size=[9, 9], strides=[1, 1], padding="SAME",
		h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim, name="conv1"))
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(f2.name,f2.shape))
	x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(x.name,x.shape))
	# add some fully-connected dense layers.
	hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
	hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

	h1 = mtf.layers.dense(
		x, hidden_dim1,
		reduced_dims=x.shape.dims[-4:],
		activation=mtf.relu, name="hidden1")
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(h1.name,h1.shape))
	h2 = mtf.layers.dense(
		h1, hidden_dim2,
		activation=mtf.relu, name="hidden2")
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(h2.name,h2.shape))
	logits = mtf.layers.dense(h2, classes_dim, name="logits")
	if labels is None:
		loss = None
	else:
		labels = mtf.import_tf_tensor(
			mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
		loss = mtf.layers.softmax_cross_entropy_with_logits(
			logits, mtf.one_hot(labels, classes_dim), classes_dim)
		loss = mtf.reduce_mean(loss)
	return logits, loss
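This example reads FLAGS.batch_size and FLAGS.hidden_size, which are defined elsewhere in the original program. A placeholder sketch of those definitions with assumed defaults, using the TF1-style tf.flags API seen throughout these examples:

import tensorflow as tf

# Assumed flag definitions; the default values here are illustrative only.
tf.flags.DEFINE_integer("batch_size", 200, "Mini-batch size.")
tf.flags.DEFINE_integer("hidden_size", 512, "Size of each hidden layer.")
FLAGS = tf.flags.FLAGS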