Example 1
 def forward(self, x, bias=None):
     """Gated activation: split the last dim into value and gate halves.

     The value half (optionally shifted by the first bias half) is
     multiplied with the activated gate half (optionally shifted by the
     second bias half).

     Args:
         x: tensor whose last dimension is split into two equal halves.
         bias: optional tensor, chunked the same way; its first half is
             added to the value half, its second half to the gate half.

     Returns:
         activation(gate + bias_2) * (x + bias_1), with half the last
         dimension of the input.
     """
     x, gate = x.chunk(2, dim=-1)
     if bias is not None:
         bias_1, bias_2 = bias.chunk(2, dim=-1)
         x = x + bias_1
     else:
         # No bias given: only the gate-side bias is read below, so a
         # value-side bias variable is unnecessary (was a dead assignment).
         bias_2 = 0
     if self.bias_gelu_fusion:
         # Fused bias + GeLU kernel for the gate path.
         intermediate_parallel = \
             bias_gelu_impl(gate, bias_2)
     else:
         intermediate_parallel = \
             self.activation_func(gate + bias_2)
     return intermediate_parallel * x
Example 2
    def forward(self, hidden_states):
        """Run the MLP block: project up, activate, project back down.

        The down-projection's bias is returned alongside the output
        rather than added here, so the caller can fuse it downstream.

        Args:
            hidden_states: input of shape [s, b, h].

        Returns:
            (output, output_bias) where output has shape [s, b, h].
        """
        # Up-projection to [s, b, 4hp]; bias comes back separately.
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            # Single fused bias-add + GeLU kernel.
            activated = bias_gelu_impl(projected, projected_bias)
        else:
            activated = self.activation_func(projected + projected_bias)

        # Down-projection back to [s, b, h].
        output, output_bias = self.dense_4h_to_h(activated)
        return output, output_bias
Example 3
    def forward(self, hidden_states):
        """Run the MLP block: project up, activate, project back down.

        Supports ``activation_type`` of "gelu" (optionally bias-fused)
        or "geglu"; any other value raises ``ValueError``.

        Args:
            hidden_states: input of shape [s, b, h].

        Returns:
            (output, output_bias) where output has shape [s, b, h].
        """
        # Up-projection to [s, b, 4hp]; bias comes back separately.
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if self.activation_type == "geglu":
            # NOTE(review): projected_bias is not applied on this path —
            # presumably handled inside the geglu activation; confirm.
            activated = self.activation_func(projected)
        elif self.activation_type == "gelu":
            if self.bias_gelu_fusion:
                # Single fused bias-add + GeLU kernel.
                activated = bias_gelu_impl(projected, projected_bias)
            else:
                activated = self.activation_func(projected + projected_bias)
        else:
            raise ValueError(f'Activation type {self.activation_type} not recognized')

        # Down-projection back to [s, b, h].
        output, output_bias = self.dense_4h_to_h(activated)
        return output, output_bias