def forward(self, x, bias=None):
    x, gate = x.chunk(2, dim=-1)
    if bias is not None:
        bias_1, bias_2 = bias.chunk(2, dim=-1)
        x = x + bias_1
    else:
        bias_1 = bias_2 = 0
    if self.bias_gelu_fusion:
        intermediate_parallel = bias_gelu_impl(gate, bias_2)
    else:
        intermediate_parallel = self.activation_func(gate + bias_2)
    return intermediate_parallel * x
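Stripped of the fusion and bias plumbing, the gating above reduces to a few lines of plain PyTorch. The sketch below (the name naive_geglu is ours, not from the codebase) shows the core idea: split the last dimension into a value half and a gate half, apply GELU to the gate, and multiply.

import torch
import torch.nn.functional as F

def naive_geglu(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into a value half and a gate half,
    # apply GELU to the gate, then gate the values elementwise.
    value, gate = x.chunk(2, dim=-1)
    return F.gelu(gate) * value

# The output is half as wide as the input, e.g. [s, b, 16] -> [s, b, 8].
x = torch.randn(4, 2, 16)
assert naive_geglu(x).shape == (4, 2, 8)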
def forward(self, hidden_states):

    # [s, b, 4hp]
    intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

    if self.bias_gelu_fusion:
        intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
    else:
        intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)

    # [s, b, h]
    output, output_bias = self.dense_4h_to_h(intermediate_parallel)
    return output, output_bias
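For reference, the fused path relies on bias_gelu_impl, which evaluates the tanh-approximate GELU on the biased input as a single TorchScript-compiled elementwise pass. A minimal sketch of the forward computation is below; the real Megatron helper also provides a hand-written backward, which is omitted here.

import torch

@torch.jit.script
def bias_gelu(bias, y):
    # Tanh-approximate GELU applied to (y + bias); scripting lets the bias add
    # and the activation run as one fused elementwise kernel instead of two.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))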
def forward(self, hidden_states):

    # [s, b, 4hp]
    intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

    if self.activation_type == "gelu":
        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
    elif self.activation_type == "geglu":
        # GEGLU consumes the un-added bias itself: it chunks both the
        # activations and the bias in half before gating (bias may be None).
        intermediate_parallel = self.activation_func(intermediate_parallel, bias_parallel)
    else:
        raise ValueError(f'Activation type {self.activation_type} not recognized')

    # [s, b, h]
    output, output_bias = self.dense_4h_to_h(intermediate_parallel)
    return output, output_bias
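One wrinkle the dispatch above hides: because GEGLU chunks the last dimension in half, the h-to-4h projection has to be made twice as wide for the effective feed-forward size to stay at 4h. The self-contained sketch below mirrors the same dispatch with plain nn.Linear layers instead of the tensor-parallel ones; the class name and sizing convention are illustrative assumptions, not the codebase's.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGegluMLP(nn.Module):
    """Framework-free analogue of the MLP forward above (hypothetical name)."""

    def __init__(self, hidden_size: int, activation_type: str = "geglu"):
        super().__init__()
        self.activation_type = activation_type
        # GEGLU halves the last dimension, so double the up-projection width
        # to keep the gated intermediate at 4h.
        mult = 2 if activation_type == "geglu" else 1
        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size * mult)
        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        intermediate = self.dense_h_to_4h(hidden_states)
        if self.activation_type == "gelu":
            intermediate = F.gelu(intermediate)
        elif self.activation_type == "geglu":
            value, gate = intermediate.chunk(2, dim=-1)
            intermediate = F.gelu(gate) * value
        else:
            raise ValueError(f"Activation type {self.activation_type} not recognized")
        return self.dense_4h_to_h(intermediate)

# Both variants map [s, b, h] -> [s, b, h]; only the intermediate width differs.
mlp = TinyGegluMLP(hidden_size=8, activation_type="geglu")
assert mlp(torch.randn(4, 2, 8)).shape == (4, 2, 8)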