def forward(self, model_input, attention_mask, tokentype_ids=None):

    extended_attention_mask = bert_extended_attention_mask(attention_mask)

    kwargs = {}
    if mpu.is_pipeline_first_stage():
        input_ids = model_input
        position_ids = bert_position_ids(input_ids)
        args = [input_ids, position_ids, extended_attention_mask]
        kwargs['tokentype_ids'] = tokentype_ids
    else:
        args = [model_input, extended_attention_mask]
    lm_output = self.language_model(*args, **kwargs)
    if mpu.is_pipeline_last_stage():
        _, pooled_output = lm_output
        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(classification_output)

        # Reshape back to [batch, num_classes].
        classification_logits = classification_logits.view(-1, self.num_classes)

        return classification_logits
    return lm_output
def forward(self, input_ids, attention_mask, tokentype_ids):

    # [batch, choices, sequence] --> [batch * choices, sequence] -->
    #    transformer --> [batch, choices] --> softmax

    # Ensure the shape is [batch-size, choices, sequence]
    assert len(input_ids.shape) == 3
    assert len(attention_mask.shape) == 3
    assert len(tokentype_ids.shape) == 3

    # Reshape and treat choice dimension the same as batch.
    num_choices = input_ids.shape[1]
    input_ids = input_ids.view(-1, input_ids.size(-1))
    attention_mask = attention_mask.view(-1, attention_mask.size(-1))
    tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    _, pooled_output = self.language_model(input_ids,
                                           position_ids,
                                           extended_attention_mask,
                                           tokentype_ids=tokentype_ids)

    # Output.
    multichoice_output = self.multichoice_dropout(pooled_output)
    multichoice_logits = self.multichoice_head(multichoice_output)

    # Reshape back to separate choices.
    multichoice_logits = multichoice_logits.view(-1, num_choices)

    return multichoice_logits
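# --- Usage sketch for the multichoice forward above (illustrative only; the
# commented-out model construction is a hypothetical stand-in). It shows the
# shape flow [batch, choices, sequence] -> [batch * choices, sequence] ->
# [batch, choices] that the reshapes in the forward rely on.
import torch

batch_size, num_choices, seq_len = 4, 2, 128
input_ids = torch.randint(0, 30000, (batch_size, num_choices, seq_len))
attention_mask = torch.ones(batch_size, num_choices, seq_len, dtype=torch.long)
tokentype_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)

# model = MultipleChoice(...)  # hypothetical construction, depends on the repo
# logits = model(input_ids, attention_mask, tokentype_ids)
# assert logits.shape == (batch_size, num_choices)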
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, loss_mask, lm_labels, padding_mask, attention_mask, position_ids \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    extended_attention_mask = bert_extended_attention_mask(
        padding_mask) + attention_mask

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens, extended_attention_mask,
                                  tokentype_ids=None,
                                  lm_labels=lm_labels,
                                  position_ids=position_ids)
        else:
            output_tensor = model(tokens, extended_attention_mask,
                                  tokentype_ids=None)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, extended_attention_mask,
                              lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, extended_attention_mask,
                              position_ids=position_ids)

    if mpu.is_pipeline_last_stage():
        lm_loss_, _ = output_tensor

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss

        averaged_losses = average_losses_across_data_parallel_group([lm_loss])

        return loss, {'lm loss': averaged_losses[0]}
    return output_tensor
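# --- Minimal sketch of the loss-mask averaging used in forward_step above:
# per-token LM losses count only where loss_mask == 1, and the sum is
# normalized by the number of masked (predicted) tokens, not by sequence
# length. The tensors below are toy values for illustration.
import torch

lm_loss_ = torch.tensor([[2.0, 0.5, 1.5, 3.0]])   # per-token LM losses
loss_mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # only positions 0 and 2 count
lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
# lm_loss == (2.0 + 1.5) / 2 == 1.75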
def forward(self, input_ids, attention_mask, token_type_ids):
    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)
    sequence_output = self.language_model(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=extended_attention_mask,
        tokentype_ids=token_type_ids,
    )
    return sequence_output
def forward(self, input_ids, attention_mask, token_type_ids):
    extended_attention_mask = bert_extended_attention_mask(
        attention_mask, next(self.language_model.parameters()).dtype)
    position_ids = bert_position_ids(input_ids)
    sequence_output = self.language_model(input_ids,
                                          position_ids,
                                          extended_attention_mask,
                                          tokentype_ids=token_type_ids)
    return sequence_output
def forward(self, input_ids, attention_mask, token_type_ids):
    # Run the deferred initialization once, then clear it.
    if self._lazy_init_fn is not None:
        self._lazy_init_fn()
        self._lazy_init_fn = None
    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)
    sequence_output = self.language_model(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=extended_attention_mask,
        tokentype_ids=token_type_ids,
    )
    return sequence_output
def forward(self, input_ids, attention_mask, tokentype_ids=None):
    extended_attention_mask = bert_extended_attention_mask(
        attention_mask, next(self.language_model.parameters()).dtype)
    position_ids = bert_position_ids(input_ids)

    lm_output, pooled_output = self.language_model(
        input_ids,
        position_ids,
        extended_attention_mask,
        tokentype_ids=tokentype_ids)

    # Output.
    ict_logits = self.ict_head(pooled_output)
    return ict_logits, None
def forward(self, input_ids, attention_mask, token_type_ids):
    app_state = AppState()
    # Finish any deferred initialization when model parallelism is not used.
    if app_state.model_parallel_size is None:
        self.complete_lazy_init()

    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    sequence_output = self.language_model(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=extended_attention_mask,
        tokentype_ids=token_type_ids,
    )
    return sequence_output
def forward(self, input_ids, attention_mask, tokentype_ids):

    extended_attention_mask = bert_extended_attention_mask(attention_mask)
    position_ids = bert_position_ids(input_ids)

    _, pooled_output = self.language_model(input_ids,
                                           position_ids,
                                           extended_attention_mask,
                                           tokentype_ids=tokentype_ids)

    # Output.
    classification_output = self.classification_dropout(pooled_output)
    classification_logits = self.classification_head(classification_output)

    # Reshape back to [batch, num_classes].
    classification_logits = classification_logits.view(-1, self.num_classes)

    return classification_logits
def forward(self, model_input, attention_mask, tokentype_ids=None):

    # [batch, choices, sequence] --> [batch * choices, sequence] -->
    #    transformer --> [batch, choices] --> softmax

    # Ensure the shape is [batch-size, choices, sequence]
    assert len(attention_mask.shape) == 3
    num_choices = attention_mask.shape[1]

    # Reshape and treat choice dimension the same as batch.
    attention_mask = attention_mask.view(-1, attention_mask.size(-1))
    extended_attention_mask = bert_extended_attention_mask(attention_mask)

    kwargs = {}
    if mpu.is_pipeline_first_stage():
        input_ids = model_input
        # Do the same as attention_mask for input_ids, tokentype_ids
        assert len(input_ids.shape) == 3
        assert len(tokentype_ids.shape) == 3
        input_ids = input_ids.view(-1, input_ids.size(-1))
        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        position_ids = bert_position_ids(input_ids)
        args = [input_ids, position_ids, extended_attention_mask]
        kwargs['tokentype_ids'] = tokentype_ids
    else:
        args = [model_input, extended_attention_mask]
    lm_output = self.language_model(*args, **kwargs)
    if mpu.is_pipeline_last_stage():
        _, pooled_output = lm_output
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices.
        multichoice_logits = multichoice_logits.view(-1, num_choices)

        return multichoice_logits
    return lm_output
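# --- For reference, a sketch of the two helpers every forward above relies
# on, written to match the behavior of Megatron-LM's one-argument
# bert_extended_attention_mask / bert_position_ids. Some snippets above pass
# an extra dtype argument to a different implementation of the same helper;
# this sketch is illustrative, not the definitive implementation.
import torch


def bert_extended_attention_mask(attention_mask):
    # [b, 1, s] x [b, s, 1] -> [b, s, s]: a position may attend to another
    # position only if both are real (non-padding) tokens.
    attention_mask_b1s = attention_mask.unsqueeze(1)
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s], then flip to a boolean "masked out" tensor.
    extended_attention_mask = attention_mask_bss.unsqueeze(1)
    extended_attention_mask = (extended_attention_mask < 0.5)
    return extended_attention_mask


def bert_position_ids(token_ids):
    # [b, s]: positions 0..s-1 broadcast across the batch.
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    return position_ids.unsqueeze(0).expand_as(token_ids)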