def post_language_model_processing(lm_output, labels, logit_weights,
                                   get_key_value, parallel_output,
                                   forward_method_parallel_output,
                                   fp16_lm_cross_entropy):
    if get_key_value:
        lm_output, presents = lm_output

    # Output.
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output
    output = parallel_lm_logits(
        lm_output,
        logit_weights,
        parallel_output)

    if get_key_value:
        output = [output, presents]

    if labels is None:
        return output
    else:
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
        return loss

def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
            decoder_attn_mask, encoder_decoder_attn_mask,
            tokentype_ids=None, lm_labels=None, enc_hidden_states=None):

    # Converting the attention masks to proper parameter settings
    encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = \
        t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])

    encoder_position_ids = t5_position_ids(encoder_input_ids)
    decoder_position_ids = t5_position_ids(decoder_input_ids)

    lm_output = self.language_model(encoder_input_ids,
                                    encoder_position_ids,
                                    encoder_attn_mask,
                                    decoder_input_ids,
                                    decoder_position_ids,
                                    decoder_attn_mask,
                                    encoder_decoder_attn_mask,
                                    tokentype_ids=tokentype_ids,
                                    enc_hidden_states=enc_hidden_states)

    if self.post_process and self.add_decoder:
        decoder_output, encoder_output = lm_output

        # Output.
        lm_logits = self.lm_head(decoder_output,
                                 self.word_embeddings_weight())

        if lm_labels is None:
            return lm_logits
        else:
            if self.fp16_lm_cross_entropy:
                assert lm_logits.dtype == torch.half
                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits,
                                                           lm_labels)
            else:
                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                           lm_labels)
            return lm_loss
    elif self.add_decoder and not self.add_encoder:
        decoder_output, encoder_output = lm_output
        return decoder_output
    else:
        encoder_output = lm_output
        return encoder_output

def post_language_model_processing(lm_output, labels, logit_weights,
                                   parallel_output,
                                   fp16_lm_cross_entropy):
    # Output.
    output = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    if labels is None:
        return output
    else:
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
        return loss

def CrossEntropy(output, labels):
    """From pretrain_gpt2:forward_step()"""
    labels, loss_mask = labels[0], labels[1]

    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss

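# Hedged sketch (not from the source): the masked-mean pattern used by
# CrossEntropy above, reproduced with plain torch.nn.functional.cross_entropy
# standing in for mpu.vocab_parallel_cross_entropy so it runs without tensor
# parallelism. Assumed shapes: logits [batch, seq, vocab], labels and
# loss_mask [batch, seq]; the helper name is illustrative only.
import torch
import torch.nn.functional as F

def masked_lm_loss_sketch(logits, labels, loss_mask):
    # Per-token losses with no reduction, so the mask can be applied afterwards.
    per_token = F.cross_entropy(logits.float().view(-1, logits.size(-1)),
                                labels.view(-1),
                                reduction='none')
    loss_mask = loss_mask.view(-1).float()
    # Average only over positions where loss_mask == 1.
    return torch.sum(per_token * loss_mask) / loss_mask.sum()

# Example usage:
# logits = torch.randn(2, 4, 10); labels = torch.randint(0, 10, (2, 4))
# loss = masked_lm_loss_sketch(logits, labels, torch.ones(2, 4))
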
def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)

    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
                               sentence_order.view(-1).contiguous(),
                               ignore_index=-1)

    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                lm_labels.contiguous())
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    loss = lm_loss + sop_loss

    reduced_losses = reduce_losses([lm_loss, sop_loss])

    return loss, {'lm loss': reduced_losses[0],
                  'sop loss': reduced_losses[1]}

def forward_step(batch, model, eval_metric):
    """Forward step."""

    # Get the batch.
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
        batch)

    # Forward model.
    output = model(tokens, position_ids, attention_mask)

    # For loss, return the unreduced loss.
    if eval_metric == 'loss':
        losses = mpu.vocab_parallel_cross_entropy(
            output.contiguous().float(), labels.contiguous())
        loss = torch.sum(
            losses.view(-1) * loss_mask.contiguous().view(-1).float())
        return loss

    # For accuracy, return the number of correctly predicted samples.
    if eval_metric == 'accuracy':
        outputs = torch.argmax(output, -1)
        correct = (outputs == labels).float()
        correct[(1 - loss_mask).bool()] = 1
        correct = correct.prod(-1)
        return correct.sum()

    raise NotImplementedError('forward method for evaluation metric {} '
                              'is not implemented.'.format(eval_metric))

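# A minimal usage sketch (not from the source) for the 'accuracy' branch above:
# a sample counts as correct only when every unmasked position matches, because
# masked positions are forced to 1 before the per-sample product is taken.
# Shapes and values below are illustrative only.
import torch

logits = torch.tensor([[[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]],    # sample 0 predicts 0, 1, 0
                       [[2.0, 0.0], [2.0, 0.0], [0.0, 2.0]]])   # sample 1 predicts 0, 0, 1
labels = torch.tensor([[0, 1, 1],
                       [0, 0, 1]])
loss_mask = torch.tensor([[1.0, 1.0, 0.0],                      # last position of sample 0 is padding
                          [1.0, 1.0, 1.0]])

outputs = torch.argmax(logits, -1)
correct = (outputs == labels).float()
correct[(1 - loss_mask).bool()] = 1         # padded position no longer counts against sample 0
num_correct = correct.prod(-1).sum()        # tensor(2.) -> both samples are exact matches
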
def forward_step(batch, model, eval_metric):
    """Forward step."""

    # Get the batch.
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
        batch)

    # Tell the model what our actual batch size will be.
    args = get_args()
    args.micro_batch_size = len(labels)

    # Forward model: receive the input activation from the previous pipeline
    # stage, if there is one.
    if not mpu.is_pipeline_first_stage():
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
    else:
        input_tensor = None

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        output = model(tokens, position_ids, attention_mask)
    else:
        assert input_tensor is not None
        output = model(input_tensor, attention_mask)

    # Send the output activation on to the next pipeline stage, if any.
    if not mpu.is_pipeline_last_stage():
        communicate(tensor_send_next=output,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        return None

    if mpu.is_pipeline_last_stage():
        # For loss, return the unreduced loss.
        if eval_metric == 'loss':
            losses = mpu.vocab_parallel_cross_entropy(
                output.contiguous().float(), labels.contiguous())
            loss = torch.sum(
                losses.view(-1) * loss_mask.contiguous().view(-1).float())
            return loss

        # For accuracy, return the number of correctly predicted samples.
        if eval_metric == 'accuracy':
            outputs = torch.argmax(output, -1)
            correct = (outputs == labels).float()
            correct[(1 - loss_mask).bool()] = 1
            correct = correct.prod(-1)
            return correct.sum()

        raise NotImplementedError('forward method for evaluation metric {} '
                                  'is not implemented.'.format(eval_metric))

    return None

def forward(self, input_ids, position_ids, attention_mask, labels=None,
            tokentype_ids=None, layer_past=None, get_key_value=False,
            forward_method_parallel_output=None):

    # Language model.
    lm_output = self.language_model(input_ids,
                                    position_ids,
                                    attention_mask,
                                    tokentype_ids=tokentype_ids,
                                    layer_past=layer_past,
                                    get_key_value=get_key_value)

    if get_key_value:
        lm_output, presents = lm_output

    # Output.
    parallel_output = self.parallel_output
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output

    if self.weight_tying:
        output = parallel_lm_logits(
            lm_output,
            self.language_model.embedding.word_embeddings.weight,
            parallel_output)
    else:
        output, bias = self.final_linear(lm_output)

    if get_key_value:
        output = [output, presents]

    if labels is None:
        return output
    else:
        if self.fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
        return loss

def cross_entropy(output, labels, _fp16=False):
    """From pretrain_gpt2:forward_step()"""
    """
    if self.fp16_lm_cross_entropy:
        assert output.dtype == torch.half
        loss = mpu.vocab_parallel_cross_entropy(output, labels)
    else:
        loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
        return loss
    """
    labels, loss_mask = labels[0], labels[1]
    if _fp16:
        assert output.dtype == torch.half and loss_mask.dtype == torch.half
        losses = mpu.vocab_parallel_cross_entropy(output.contiguous(), labels)
    else:
        losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(),
                                                  labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss

def forward(self, input_ids, attention_mask,
            tokentype_ids=None, lm_labels=None):

    extended_attention_mask = bert_extended_attention_mask(
        attention_mask, next(self.language_model.parameters()).dtype)
    position_ids = bert_position_ids(input_ids)

    if self.add_binary_head:
        lm_output, pooled_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)
    else:
        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)

    # Output.
    lm_logits = self.lm_head(
        lm_output, self.language_model.embedding.word_embeddings.weight)

    binary_logits = None
    if self.add_binary_head:
        binary_logits = self.binary_head(pooled_output)

    if lm_labels is None:
        return lm_logits, binary_logits
    else:
        if self.fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                       lm_labels)
        return lm_loss, binary_logits

def post_language_model_processing(lm_output, pooled_output,
                                   lm_head, binary_head,
                                   lm_labels,
                                   logit_weights,
                                   fp16_lm_cross_entropy):
    # Output.
    lm_logits = lm_head(lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        return lm_logits, binary_logits
    else:
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                       lm_labels)
        return lm_loss, binary_logits

def forward_step(batch, model, eval_metric):
    """Forward step."""

    # Get the batch.
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
        batch)

    # Tell the model what our actual batch size will be.
    args = get_args()
    args.micro_batch_size = len(labels)

    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output = model(tokens, position_ids, attention_mask)

    send_forward(output)

    if mpu.is_pipeline_last_stage():
        # For loss, return the unreduced loss.
        if eval_metric == 'loss':
            losses = mpu.vocab_parallel_cross_entropy(
                output.contiguous().float(), labels.contiguous())
            loss = torch.sum(
                losses.view(-1) * loss_mask.contiguous().view(-1).float())
            return loss

        # For accuracy, return the number of correctly predicted samples.
        if eval_metric == 'accuracy':
            outputs = torch.argmax(output, -1)
            correct = (outputs == labels).float()
            correct[(1 - loss_mask).bool()] = 1
            correct = correct.prod(-1)
            return correct.sum()

        raise NotImplementedError('forward method for evaluation metric {} '
                                  'is not implemented.'.format(eval_metric))

    return None

def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch generator').stop()

    # Forward model.
    output = model(tokens, position_ids, attention_mask)
    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    reduced_loss = reduce_losses([loss])

    return loss, {'lm loss': reduced_loss[0]}

def forward(self, input_ids, position_ids, attention_mask, labels=None,
            tokentype_ids=None, layer_past=None, get_key_value=False,
            forward_method_parallel_output=None, curriculum_seqlen=None):
    args = get_args()
    if curriculum_seqlen is not None:
        args.curriculum_seqlen = curriculum_seqlen
        if curriculum_seqlen < input_ids.size()[1]:
            # seqlen-based curriculum learning
            # input_ids, position_ids, labels have size [batch size, seqlen]
            input_ids = input_ids[:, :curriculum_seqlen].contiguous()
            position_ids = position_ids[:, :curriculum_seqlen].contiguous()
            labels = labels[:, :curriculum_seqlen].contiguous()

            # attention_mask has size [1, 1, seqlen, seqlen]
            attention_mask = attention_mask[:, :, :curriculum_seqlen,
                                            :curriculum_seqlen].contiguous()
    else:
        if args.curriculum_learning:
            # If got a None input, need to reset curriculum_seqlen on user side
            args.curriculum_seqlen = args.seq_length

    # Language model.
    lm_output = self.language_model(input_ids,
                                    position_ids,
                                    attention_mask,
                                    tokentype_ids=tokentype_ids,
                                    layer_past=layer_past,
                                    get_key_value=get_key_value)

    if get_key_value:
        lm_output, presents = lm_output

    # Output.
    parallel_output = self.parallel_output
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output
    output = parallel_lm_logits(
        lm_output,
        self.language_model.embedding.word_embeddings.weight,
        parallel_output)

    if get_key_value:
        output = [output, presents]

    if labels is None:
        return output
    else:
        if self.fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
        return loss