def forward(self, input_ids: torch.LongTensor, attention_mask=None, token_type_ids=None, token_span=None, **kwargs):
    """Encode token ids with the wrapped transformer and optionally mix layers.

    Applies word-level dropout (if configured), runs ``transformer_encode``,
    then combines hidden layers through ``self.scalar_mix`` when one is set.

    Args:
        input_ids: Subword token ids to encode.
        attention_mask: Optional attention mask passed through to the encoder.
        token_type_ids: Optional segment ids passed through to the encoder.
        token_span: Optional subword-to-token spans for pooling subwords.
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        The encoded hidden states, or a ``(hidden_states, raw_hidden_states)``
        pair when ``self.ret_raw_hidden_states`` is enabled.
    """
    if self.word_dropout:
        input_ids = self.word_dropout(input_ids)
    # 0 means "default layer range" when no scalar mixture is configured.
    layers = self.scalar_mix.mixture_range if self.scalar_mix else 0
    encoded = transformer_encode(
        self.transformer,
        input_ids,
        attention_mask,
        token_type_ids,
        token_span,
        layer_range=layers,
        max_sequence_length=self.max_sequence_length,
        average_subwords=self.average_subwords,
        ret_raw_hidden_states=self.ret_raw_hidden_states)
    # Pre-bind so the variable is always defined, regardless of the flag.
    raw_hidden_states = None
    if self.ret_raw_hidden_states:
        encoded, raw_hidden_states = encoded
    if self.scalar_mix:
        encoded = self.scalar_mix(encoded)
    if self.ret_raw_hidden_states:
        return encoded, raw_hidden_states
    return encoded
def run_transformer(self, input_ids, token_span):
    """Encode ``input_ids`` with the transformer, pooling subwords per token.

    A minimal wrapper around ``transformer_encode`` that supplies no
    attention mask and no token type ids.

    Args:
        input_ids: Subword token ids to encode.
        token_span: Subword-to-token spans used for subword pooling.

    Returns:
        The encoder output produced by ``transformer_encode``.
    """
    encoder = self.transformer
    average_subwords = self.config.average_subwords
    return transformer_encode(encoder, input_ids, None, None, token_span,
                              average_subwords=average_subwords)