Example #1
    def testRecomputeGrad(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        # let's differentiate x^2 + x
        # dy/dx = 2x+1
        def x_squared_plus_x(x):
            return x * x + x

        x = tf.constant([5, 10], dtype=tf.float32)
        dy = tf.constant([2, 3], dtype=tf.float32)
        two = mtf.Dimension("two", 2)
        expected_y = tf.constant([30, 110], dtype=tf.float32)
        expected_dx = tf.constant([22, 63], dtype=tf.float32)
        mtf_x = mtf.import_fully_replicated(mesh, x, shape=mtf.Shape([two]))
        mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=mtf.Shape([two]))
        mtf_y = mtf.recompute_grad(x_squared_plus_x, [mtf_x])
        [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            shape="processors:2", layout="two:processors", devices=["", ""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_y = lowering.export_to_tf_tensor(mtf_y)
        actual_dx = lowering.export_to_tf_tensor(mtf_dx)
        self.assertAllEqual(self.evaluate(actual_y), self.evaluate(expected_y))
        self.assertAllEqual(self.evaluate(actual_dx),
                            self.evaluate(expected_dx))
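
The test above runs under tf.test.TestCase. Below is a minimal standalone sketch of the same check, assuming the pip mesh-tensorflow package and a TF1-style graph session via tf.compat.v1; the script layout is illustrative and not part of the original test:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
two = mtf.Dimension("two", 2)

# y = x^2 + x, so dy/dx = 2x + 1
x = tf.constant([5, 10], dtype=tf.float32)
dy = tf.constant([2, 3], dtype=tf.float32)
mtf_x = mtf.import_fully_replicated(mesh, x, shape=mtf.Shape([two]))
mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=mtf.Shape([two]))

# recompute_grad re-runs the wrapped function on the backward pass
# instead of keeping its intermediate activations in memory
mtf_y = mtf.recompute_grad(lambda t: t * t + t, [mtf_x])
[mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])

mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape="processors:2", layout="two:processors", devices=["", ""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
actual_y = lowering.export_to_tf_tensor(mtf_y)
actual_dx = lowering.export_to_tf_tensor(mtf_dx)

with tf.Session() as sess:
    print(sess.run([actual_y, actual_dx]))  # expect [30., 110.] and [22., 63.]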
Example #2
    def call(self, context, x, run_layers_range=None):
        """Call the layer stack."""
        tf.logging.info("Calling Charformer layer stack")
        x = self._call_sublayers(self._sublayers_initial, x, context)
        context.layer_outputs.append(x)

        if run_layers_range:
            layers = self._layers[run_layers_range[0]:run_layers_range[1]]
        else:
            layers = self._layers

        for lnum, layer in enumerate(layers):
            tf.logging.info("Running=%d | %s", lnum, layer.__class__.__name__)
            tf.logging.info(layer)
            with tf.variable_scope(layer.name or ""):
                if self._recompute_grads:

                    def fn(x, l=layer, c=context):
                        return self._layer_fn(x, l, c)

                    x = mtf.recompute_grad(fn, [x])
                else:
                    x = self._layer_fn(x, layer, context)
            if lnum != len(self._layers) - 1:
                context.layer_outputs.append(x)
            context.layer_index += 1
        x = self._call_sublayers(self._sublayers_final, x, context)
        x = sublayer_mask_padding(x, self, context)
        context.layer_outputs.append(x)
        self.context = context
        return x
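
A note on the def fn(x, l=layer, c=context) idiom above: the default arguments bind the current layer and context to the closure at definition time, avoiding Python's late binding of loop variables in a function that may be invoked again later. A minimal illustration of the difference (names are illustrative only):

fns_late = [lambda: i for i in range(3)]         # every call returns 2
fns_bound = [lambda i=i: i for i in range(3)]    # calls return 0, 1, 2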
Example #3
    def encoder(self, x):
        with tf.variable_scope("encoder"):
            for n in range(self.num_layers):
                is_last = n == self.num_layers - 1
                block_fn = self.encoder_block(n, is_last)
                # Recompute activations on the backward pass only during training
                if self.params.get("recompute_grad", False) and self.mode == "train":
                    x = mtf.recompute_grad(block_fn, [x])
                else:
                    x = block_fn(x)
            return x
Example #4
    def transformer(self, x, mask):
        for layer in range(self.n_layers):
            # attn blocks
            block_fn = self.block(mask, f"layer_{layer}")
            # If true and in train mode, enable gradient checkpointing
            if self.params.get("recompute_grad", False) and self.mode == "train":
                x = mtf.recompute_grad(block_fn, [x])
            else:
                x = block_fn(x)
        return x
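
Examples 3 and 4 gate checkpointing on the same two conditions. A hedged sketch of that shared logic factored into a helper; the name maybe_recompute is hypothetical and not part of either codebase:

def maybe_recompute(block_fn, x, params, mode):
    # Wrap block_fn with mtf.recompute_grad only when the config requests
    # gradient checkpointing and the model is in train mode.
    if params.get("recompute_grad", False) and mode == "train":
        return mtf.recompute_grad(block_fn, [x])
    return block_fn(x)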
Example #5
    def call(self, context, x):
        """Call the layer stack."""
        x = self._call_sublayers(self._sublayers_initial, x, context, 0)
        context.layer_outputs.append(x)
        for lnum, layer in enumerate(self._layers):
            with tf.variable_scope(layer.name or ""):
                if self._recompute_grads:

                    def fn(x, l=layer, c=context, lnum_arg=lnum):
                        return self._layer_fn(x, l, c, lnum_arg)

                    x = mtf.recompute_grad(fn, [x])
                else:
                    x = self._layer_fn(x, layer, context, lnum)
            if lnum != len(self._layers) - 1:
                context.layer_outputs.append(x)
            context.layer_index += 1
        x = self._call_sublayers(self._sublayers_final, x, context, 0)
        x = transformer.sublayer_mask_padding(x, self, context)
        context.layer_outputs.append(x)
        return x
Example #6
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT style model implemented in mesh tensorflow."""

    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = params["axial_pos_emb"] is not None

    if not use_axial_pos_emb:
        # Use standard position encoding
        wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
                               initializer=tf.random_normal_initializer(stddev=0.01),
                               master_dtype=variable_dtype.master_dtype,
                               slice_dtype=variable_dtype.slice_dtype,
                               activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Text encoding
    wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
                           initializer=tf.random_normal_initializer(stddev=0.02),
                           master_dtype=variable_dtype.master_dtype,
                           slice_dtype=variable_dtype.slice_dtype,
                           activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (
                context.position - 1)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
        h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = params["recompute_grad"] and params["mode"] == "train"
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get("z_loss", 1e-4) # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits 
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
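
In this example block_fn(h) and mtf.recompute_grad(block_fn, [h]) both yield an (activations, aux_loss) pair, which is why the same h, loss unpacking covers both branches. A minimal sketch of a compatible block factory, assuming mtf.zeros for the scalar auxiliary loss; the body is a stand-in, not GPT-Neo's actual attention and feed-forward block:

def block(params, scope, layer_num, bias, sequence_dim,
          memory_length_dim, variable_dtype, context=None):
    def block_fn(x):
        with tf.variable_scope(scope):
            out = x * x + x                              # stand-in for the real sub-blocks
            aux_loss = mtf.zeros(x.mesh, mtf.Shape([]))  # MoE layers return a real loss here
            return out, aux_loss
    return block_fn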