Example #1
    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output,
                               enc_dec_attn_mask)
                return x_

            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers), hidden_states,
                attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers

        return hidden_states
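In this pattern, mpu.checkpoint plays the role of torch.utils.checkpoint.checkpoint: only the inputs at each chunk boundary are kept, and the checkpoint_num_layers layers inside a chunk are rerun during the backward pass. A minimal, self-contained sketch of the same chunked checkpointing using the standard PyTorch API (the module and parameter names below are illustrative, not Megatron-LM's):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class ChunkedTransformer(nn.Module):
        """Illustrative stand-in for the module above."""

        def __init__(self, num_layers=8, hidden=64, checkpoint_num_layers=2):
            super().__init__()
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(d_model=hidden, nhead=4,
                                           batch_first=True)
                for _ in range(num_layers))
            self.checkpoint_num_layers = checkpoint_num_layers

        def _chunk(self, start, end):
            # Closure over layers [start, end), replayed during backward.
            def custom_forward(x):
                for layer in self.layers[start:end]:
                    x = layer(x)
                return x
            return custom_forward

        def forward(self, x):
            for start in range(0, len(self.layers), self.checkpoint_num_layers):
                end = min(start + self.checkpoint_num_layers, len(self.layers))
                # Only the chunk-boundary activation is stored; everything
                # inside the chunk is recomputed in backward. use_reentrant=False
                # needs PyTorch >= 1.11; drop the argument on older versions.
                x = checkpoint(self._chunk(start, end), x, use_reentrant=False)
            return x

    x = torch.randn(2, 16, 64, requires_grad=True)
    ChunkedTransformer()(x).sum().backward()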
Example #2
    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_
            return custom_forward

        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states
Example #3
    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                layers_ = self.layers[start:end]
                x_ = inputs[0]
                for layer in layers_:
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        l = 0
        num_layers = len(self.layers)
        while l < num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers), hidden_states,
                attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states
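All three variants assume self.num_layers (or len(self.layers)) is a multiple of self.checkpoint_num_layers: the final custom(l, l + self.checkpoint_num_layers) chunk otherwise asks for indices past the last layer, which raises an IndexError if _get_layer indexes into a layer list (the slicing variant above merely truncates silently). A hypothetical guard that clamps the chunk end:

        end = min(l + self.checkpoint_num_layers, self.num_layers)
        hidden_states = mpu.checkpoint(custom(l, end),
                                       hidden_states, attention_mask)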