def forward(self, model_input, attention_mask, tokentype_ids=None):

    extended_attention_mask = bert_extended_attention_mask(attention_mask)

    kwargs = {}
    if mpu.is_pipeline_first_stage():
        input_ids = model_input
        position_ids = bert_position_ids(input_ids)
        args = [input_ids, position_ids, extended_attention_mask]
        kwargs['tokentype_ids'] = tokentype_ids
    else:
        args = [model_input, extended_attention_mask]
    lm_output = self.language_model(*args, **kwargs)
    if mpu.is_pipeline_last_stage():
        _, pooled_output = lm_output
        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(classification_output)

        # Reshape to [batch, num_classes].
        classification_logits = classification_logits.view(-1, self.num_classes)

        return classification_logits
    return lm_output

def forward(self, input_ids, attention_mask, tokentype_ids):

    # [batch, choices, sequence] --> [batch * choices, sequence] -->
    #    transformer --> [batch, choices] --> softmax

    # Ensure the shape is [batch-size, choices, sequence]
    assert len(input_ids.shape) == 3
    assert len(attention_mask.shape) == 3
    assert len(tokentype_ids.shape) == 3

    # Reshape and treat choice dimension the same as batch.
    num_choices = input_ids.shape[1]
    input_ids = input_ids.view(-1, input_ids.size(-1))
    attention_mask = attention_mask.view(-1, attention_mask.size(-1))
    tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    _, pooled_output = self.language_model(input_ids,
                                           position_ids,
                                           extended_attention_mask,
                                           tokentype_ids=tokentype_ids)

    # Output.
    multichoice_output = self.multichoice_dropout(pooled_output)
    multichoice_logits = self.multichoice_head(multichoice_output)

    # Reshape back to separate choices.
    multichoice_logits = multichoice_logits.view(-1, num_choices)

    return multichoice_logits

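# A minimal usage sketch for the multichoice forward above (assumptions, not
# part of the original code: a constructed `model` exposing this forward, the
# vocabulary size, and integer `labels` holding the index of the correct
# choice). The logits come back as [batch, choices], so a standard
# cross-entropy over the choice dimension gives the multiple-choice loss.
import torch
import torch.nn.functional as F

batch_size, num_choices, seq_len = 4, 2, 128
input_ids = torch.randint(0, 30522, (batch_size, num_choices, seq_len))
attention_mask = torch.ones(batch_size, num_choices, seq_len, dtype=torch.long)
tokentype_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)
labels = torch.randint(0, num_choices, (batch_size,))

logits = model(input_ids, attention_mask, tokentype_ids)   # [batch, choices]
loss = F.cross_entropy(logits, labels)
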
def forward(self, input_ids, attention_mask, token_type_ids):
    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    sequence_output = self.language_model(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=extended_attention_mask,
        tokentype_ids=token_type_ids,
    )
    return sequence_output

def forward(self, input_ids, attention_mask, token_type_ids):
    extended_attention_mask = bert_extended_attention_mask(
        attention_mask, next(self.language_model.parameters()).dtype)
    position_ids = bert_position_ids(input_ids)

    sequence_output = self.language_model(input_ids,
                                          position_ids,
                                          extended_attention_mask,
                                          tokentype_ids=token_type_ids)
    return sequence_output

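# For reference, a sketch of what the two helpers used throughout these
# forwards typically compute in Megatron-style BERT code. This is an
# assumption -- the real implementations are not shown in this section, and the
# two-argument variant above additionally takes a dtype and may return an
# additive mask rather than a boolean one.
import torch

def bert_position_ids_sketch(token_ids):
    # [batch, seq] token ids -> [batch, seq] positions 0..seq-1 per sample.
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    return position_ids.unsqueeze(0).expand_as(token_ids)

def bert_extended_attention_mask_sketch(attention_mask):
    # [batch, seq] padding mask -> [batch, 1, seq, seq] self-attention mask.
    attention_mask_bss = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    extended_attention_mask = attention_mask_bss.unsqueeze(1)
    # True marks positions the attention should ignore.
    return extended_attention_mask < 0.5
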
def forward(self, input_ids, attention_mask, token_type_ids):
    if self._lazy_init_fn is not None:
        self._lazy_init_fn()
        self._lazy_init_fn = None

    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    sequence_output = self.language_model(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=extended_attention_mask,
        tokentype_ids=token_type_ids,
    )
    return sequence_output

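# The `_lazy_init_fn` check above is a deferred-initialization pattern: the
# constructor stores a callable and the first forward() runs it exactly once.
# A toy sketch of that pattern (an assumption about how `_lazy_init_fn` is
# populated; the constructor is not shown in this section):
class LazyInitSketch:
    def __init__(self, build_fn):
        # e.g. a functools.partial that finishes building the language model
        # once the distributed / model-parallel state exists.
        self._lazy_init_fn = build_fn

    def materialize_once(self):
        if self._lazy_init_fn is not None:
            self._lazy_init_fn()
            self._lazy_init_fn = None
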
def forward(self, input_ids, attention_mask, tokentype_ids=None):
    extended_attention_mask = bert_extended_attention_mask(
        attention_mask, next(self.language_model.parameters()).dtype)
    position_ids = bert_position_ids(input_ids)

    lm_output, pooled_output = self.language_model(
        input_ids,
        position_ids,
        extended_attention_mask,
        tokentype_ids=tokentype_ids)

    # Output.
    ict_logits = self.ict_head(pooled_output)
    return ict_logits, None

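# The ICT forward above returns one embedding per input sequence. A sketch of
# how such embeddings are commonly scored during Inverse Cloze Task / retrieval
# training (an assumption about the surrounding training loop, not code from
# this model): in-batch negatives with the positive pairs on the diagonal.
import torch
import torch.nn.functional as F

def in_batch_retrieval_loss(query_emb, block_emb):
    # query_emb, block_emb: [batch, hidden]
    scores = torch.matmul(query_emb, block_emb.t())              # [batch, batch]
    labels = torch.arange(scores.size(0), device=scores.device)  # diagonal is positive
    return F.cross_entropy(scores, labels)
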
def forward(self, input_ids, attention_mask, token_type_ids):
    app_state = AppState()
    if app_state.model_parallel_size is None:
        self.complete_lazy_init()

    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    sequence_output = self.language_model(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=extended_attention_mask,
        tokentype_ids=token_type_ids,
    )
    return sequence_output

def forward(self, input_ids, attention_mask, tokentype_ids):

    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    _, pooled_output = self.language_model(input_ids,
                                           position_ids,
                                           extended_attention_mask,
                                           tokentype_ids=tokentype_ids)

    # Output.
    classification_output = self.classification_dropout(pooled_output)
    classification_logits = self.classification_head(classification_output)

    # Reshape to [batch, num_classes].
    classification_logits = classification_logits.view(-1, self.num_classes)

    return classification_logits

def forward(self, model_input, attention_mask, tokentype_ids=None):

    # [batch, choices, sequence] --> [batch * choices, sequence] -->
    #    transformer --> [batch, choices] --> softmax

    # Ensure the shape is [batch-size, choices, sequence]
    assert len(attention_mask.shape) == 3
    num_choices = attention_mask.shape[1]

    # Reshape and treat choice dimension the same as batch.
    attention_mask = attention_mask.view(-1, attention_mask.size(-1))
    extended_attention_mask = bert_extended_attention_mask(attention_mask)

    kwargs = {}
    if mpu.is_pipeline_first_stage():
        input_ids = model_input
        # Do the same as attention_mask for input_ids, tokentype_ids
        assert len(input_ids.shape) == 3
        assert len(tokentype_ids.shape) == 3
        input_ids = input_ids.view(-1, input_ids.size(-1))
        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        position_ids = bert_position_ids(input_ids)
        args = [input_ids, position_ids, extended_attention_mask]
        kwargs['tokentype_ids'] = tokentype_ids
    else:
        args = [model_input, extended_attention_mask]
    lm_output = self.language_model(*args, **kwargs)
    if mpu.is_pipeline_last_stage():
        _, pooled_output = lm_output
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices.
        multichoice_logits = multichoice_logits.view(-1, num_choices)

        return multichoice_logits
    return lm_output

def forward(self, input_ids, attention_mask, tokentype_ids=None):
    extended_attention_mask = attention_mask.unsqueeze(1)
    #extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    lm_output = self.language_model(input_ids,
                                    position_ids,
                                    extended_attention_mask,
                                    tokentype_ids=tokentype_ids)
    # This mask will be used in average-pooling and max-pooling
    pool_mask = (input_ids == self.pad_id).unsqueeze(2)

    # Taking the representation of the [CLS] token of BERT
    pooled_output = lm_output[:, 0, :]

    # Converting to float16 dtype
    pooled_output = pooled_output.to(lm_output.dtype)

    # Output.
    if self.biencoder_projection_dim:
        pooled_output = self.projection_enc(pooled_output)

    return pooled_output

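# `pool_mask` above marks padding positions but is not consumed in this
# snippet, which keeps the [CLS] vector. A sketch of how that mask could drive
# masked average pooling instead (an assumption, not the original pooling
# code): zero out pad positions, then divide by the count of real tokens.
import torch

def masked_average_pool(lm_output, pool_mask):
    # lm_output: [batch, seq, hidden]; pool_mask: [batch, seq, 1], True at pads.
    keep = (~pool_mask).to(lm_output.dtype)
    summed = (lm_output * keep).sum(dim=1)
    counts = keep.sum(dim=1).clamp(min=1.0)
    return summed / counts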