Example #1
    def call(self, context, x, run_layers_range=None):
        """Call the layer stack."""
        tf.logging.info("Calling Charformer layer stack")
        x = self._call_sublayers(self._sublayers_initial, x, context)
        context.layer_outputs.append(x)

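        # Optionally run only a contiguous sub-range [start, end) of the layers.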
        if run_layers_range:
            layers = self._layers[run_layers_range[0]:run_layers_range[1]]
        else:
            layers = self._layers

        for lnum, layer in enumerate(layers):
            tf.logging.info("Running=%d | %s", lnum, layer.__class__.__name__)
            tf.logging.info(layer)
            with tf.variable_scope(layer.name or ""):
                if self._recompute_grads:
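                    # Trade compute for memory: mtf.recompute_grad re-runs fn
                    # during the backward pass instead of storing activations.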

                    def fn(x, l=layer, c=context):
                        return self._layer_fn(x, l, c)

                    x = mtf.recompute_grad(fn, [x])
                else:
                    x = self._layer_fn(x, layer, context)
            if lnum != len(self._layers) - 1:
                context.layer_outputs.append(x)
            context.layer_index += 1
        x = self._call_sublayers(self._sublayers_final, x, context)
        x = sublayer_mask_padding(x, self, context)
        context.layer_outputs.append(x)
        self.context = context
        return x
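
A note on the inner fn defined under self._recompute_grads: it binds layer and context through default arguments rather than relying on the closure. Because mtf.recompute_grad calls fn again during the backward pass, after the loop has moved on, a late-bound closure would re-run the last layer every time. A minimal, framework-free illustration of the late- versus early-binding difference:

late_bound = [lambda: i for i in range(3)]
early_bound = [lambda i=i: i for i in range(3)]

print([f() for f in late_bound])   # [2, 2, 2]  every closure sees the final i
print([f() for f in early_bound])  # [0, 1, 2]  default args freeze each iteration's value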
Example #2
    def call(self, context, x):
        """Call the layer stack."""
        x = self._call_sublayers(self._sublayers_initial, x, context)
        context.layer_outputs.append(x)

        assert context.layer_index == 0

        for block_idx in range(self.n_blocks):
            for param_idx in range(self.block_param_size[block_idx]):
                # Number of layers to (locally) share parameters.
                cur_repeat_size = self.block_repeat_size[block_idx]
                for repeat_idx in range(cur_repeat_size):
                    # context.do_pooling = block_idx > 0 and sub_idx == 0

                    # Submodules are transformer.TransformerLayer objects such as
                    # SelfAttention and DenseReluDense.
                    for submodule_idx in range(self.n_submodules):
                        layer = self._layers[context.layer_index]
                        name = (f"funnel_block_{block_idx:03d}/"
                                f"param_idx_{param_idx:03d}/"
                                f"submodule_{submodule_idx:03d}")
                        # Override the layer name given in transformer.make_layer_stack.
                        layer.set_name(name)

                        with tf.variable_scope(layer.name or ""):
                            x = self._layer_fn(x, layer, context)

                        # Do pooling if the current layer
                        # 1) does not belong to the first block
                        # 2) is the first layer within the current block
                        # 3) is the first submodule (typically SelfAttention).
                        sub_idx = (param_idx * cur_repeat_size + repeat_idx)
                        if block_idx > 0 and sub_idx == 0 and submodule_idx == 0:
                            x = mtf.pool_tensor_1d(x,
                                                   pool_dim=context.length_dim,
                                                   reduce_fn=self.pool_fn,
                                                   pool_size=self.pooling_size)
                            self.update_context(context,
                                                x,
                                                pool_dim_name="length")

                        if context.layer_index != len(self._layers) - 1:
                            context.layer_outputs.append(x)
                        context.layer_index += 1

        x = self._call_sublayers(self._sublayers_final, x, context)
        x = transformer.sublayer_mask_padding(x, self, context)
        context.layer_outputs.append(x)
        self.set_context(context)
        return x
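
The nested loops above fix a schedule: each block holds block_param_size[block_idx] parameter groups, each group is repeated block_repeat_size[block_idx] times with locally shared parameters (the variable-scope name omits repeat_idx), every repeat runs n_submodules layers (typically SelfAttention followed by DenseReluDense), and the sequence is pooled once at the first layer of every block after the first. A framework-free sketch of that schedule; the sizes below are illustrative assumptions, not values from any particular config:

def funnel_schedule(n_blocks=3, block_param_size=(2, 2, 2),
                    block_repeat_size=(1, 1, 1), n_submodules=2,
                    pooling_size=2, seq_len=64):
    """Yields (layer_index, block_idx, submodule_idx, pooled_here, seq_len)."""
    layer_index = 0
    for block_idx in range(n_blocks):
        repeats = block_repeat_size[block_idx]
        for param_idx in range(block_param_size[block_idx]):
            for repeat_idx in range(repeats):
                for submodule_idx in range(n_submodules):
                    sub_idx = param_idx * repeats + repeat_idx
                    pooled = block_idx > 0 and sub_idx == 0 and submodule_idx == 0
                    if pooled:
                        seq_len //= pooling_size  # e.g. 64 -> 32 -> 16 with pooling_size=2
                    yield layer_index, block_idx, submodule_idx, pooled, seq_len
                    layer_index += 1

for row in funnel_schedule():
    print(row)

With these toy sizes the stack has 3 * 2 * 1 * 2 = 12 layers and the sequence length shrinks from 64 to 16.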
Example #3
    def call(self, context, x):
        """Call the layer stack."""
        x = self._call_sublayers(self._sublayers_initial, x, context, 0)
        context.layer_outputs.append(x)
        for lnum, layer in enumerate(self._layers):
            with tf.variable_scope(layer.name or ""):
                if self._recompute_grads:

                    def fn(x, l=layer, c=context, lnum_arg=lnum):
                        return self._layer_fn(x, l, c, lnum_arg)

                    x = mtf.recompute_grad(fn, [x])
                else:
                    x = self._layer_fn(x, layer, context, lnum)
            if lnum != len(self._layers) - 1:
                context.layer_outputs.append(x)
            context.layer_index += 1
        x = self._call_sublayers(self._sublayers_final, x, context, 0)
        x = transformer.sublayer_mask_padding(x, self, context)
        context.layer_outputs.append(x)
        return x
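
For readers outside Mesh TensorFlow: the recompute_grads branch is standard activation (gradient) checkpointing. A rough equivalent in plain TensorFlow 2, shown only as an illustration of the memory/compute trade, not as the library's implementation:

import tensorflow as tf

w1 = tf.random.normal([128, 256])
w2 = tf.random.normal([256, 128])

@tf.recompute_grad
def block(x):
    # Intermediate activations are dropped after the forward pass and
    # recomputed when the backward pass needs them.
    return tf.matmul(tf.nn.relu(tf.matmul(x, w1)), w2)

x = tf.random.normal([8, 128])
with tf.GradientTape() as tape:
    tape.watch(x)
    loss = tf.reduce_sum(block(x))
dx = tape.gradient(loss, x)  # same gradients as the un-checkpointed version, lower peak memory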
Example #4
def sublayer_call_layer(x, layer_stack, context):
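    """Mask padding, then run the current layer inside a variable scope named after its class."""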
    x = sublayer_mask_padding(x, layer_stack, context)
    layer = context.current_layer
    with tf.variable_scope(layer.__class__.__name__):
        return layer.call(context, x)
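
The tf.variable_scope(layer.__class__.__name__) wrapper simply prefixes every variable the layer creates with the layer's class name, which keeps variable and checkpoint names organized per layer type. A tiny graph-mode illustration (variable scopes are a TF1-style API):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # variable_scope / get_variable are graph-mode APIs

with tf.variable_scope("SelfAttention"):
    w = tf.get_variable("w", shape=[4, 4])

print(w.name)  # SelfAttention/w:0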