def forward(self, input_ids, position_ids, attention_mask):
    # Embeddings.
    words_embeddings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    embeddings = words_embeddings + position_embeddings

    # Dropout.
    embeddings = self.embedding_dropout(embeddings)

    # Transformer.
    transformer_output, *moe_losses = self.transformer(embeddings, attention_mask)

    # Parallel logits.
    transformer_output_parallel = mpu.copy_to_model_parallel_region(
        transformer_output)
    logits_parallel = F.linear(transformer_output_parallel,
                               self.word_embeddings.weight)

    if self.parallel_output:
        return (logits_parallel, *moe_losses)

    return (mpu.gather_from_model_parallel_region(logits_parallel), *moe_losses)
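# A minimal standalone sketch (not the repo's code) of the tied output
# projection used above: logits come from multiplying hidden states with the
# input embedding matrix via F.linear, so no separate output weight is needed.
# Sizes are illustrative assumptions, and model parallelism (mpu) is omitted.
import torch
import torch.nn.functional as F

vocab_size, hidden_size = 50257, 8                          # assumed toy sizes
embedding = torch.nn.Embedding(vocab_size, hidden_size)

input_ids = torch.randint(0, vocab_size, (2, 5))            # [batch, seq]
hidden = embedding(input_ids)                               # stand-in for the transformer output
logits = F.linear(hidden, embedding.weight)                 # [batch, seq, vocab]
assert logits.shape == (2, 5, vocab_size)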
def forward(self, input_ids, position_ids, attention_mask):
    # Embeddings.
    words_embeddings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    embeddings = words_embeddings + position_embeddings

    # Dropout.
    embeddings = self.embedding_dropout(embeddings)

    # Transformer.
    transformer_output = self.transformer(embeddings, attention_mask)

    # Parallel logits.
    transformer_output_parallel = mpu.copy_to_model_parallel_region(
        transformer_output)
    # logits_parallel = F.linear(transformer_output_parallel,
    #                            self.word_embeddings.weight)
    pooler = self.linear(transformer_output_parallel)
    gpt_classifier_output = self.classifier(pooler)
    logits_parallel = gpt_classifier_output

    if self.parallel_output:
        return logits_parallel

    return mpu.gather_from_model_parallel_region(logits_parallel)
def forward(self, input_ids, position_ids, attention_mask, token_type_ids):
    # Embeddings.
    words_embeddings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
    embeddings = words_embeddings + position_embeddings + token_type_embeddings
    embeddings = self.input_layernorm(embeddings)

    # Dropout.
    embeddings = self.embedding_dropout(embeddings)

    # Transformer.
    transformer_output, *moe_losses = self.transformer(
        embeddings, attention_mask)

    # Parallel logits.
    transformer_output_parallel = mpu.copy_to_model_parallel_region(
        transformer_output)
    logits_parallel = F.linear(transformer_output_parallel,
                               self.word_embeddings.weight)

    pooled_output = torch.squeeze(transformer_output_parallel[:, 0, :])
    ##############
    # hrs_scores = self.hrs_head(pooled_output)
    # click_scores = self.click_head(pooled_output)
    #############
    hrs_head0 = self.dense_hrs0(pooled_output)
    hrs_scores = self.hrs_head(torch.tanh(hrs_head0))
    click_head0 = self.dense_click0(pooled_output)
    click_scores = self.click_head(torch.tanh(click_head0))
    lpsat_head0 = self.dense_hrs0(pooled_output)
    lpsat_scores = self.hrs_head(torch.tanh(lpsat_head0))
    qc_head0 = self.dense_hrs0(pooled_output)
    qc_scores = self.hrs_head(torch.tanh(qc_head0))
    eff_head0 = self.dense_hrs0(pooled_output)
    eff_scores = self.hrs_head(torch.tanh(eff_head0))
    local_head0 = self.dense_hrs0(pooled_output)
    local_scores = self.hrs_head(torch.tanh(local_head0))
    fresh_head0 = self.dense_hrs0(pooled_output)
    fresh_scores = self.hrs_head(torch.tanh(fresh_head0))
    #############

    if self.parallel_output:
        return (logits_parallel, hrs_scores, click_scores, *moe_losses)

    return (mpu.gather_from_model_parallel_region(logits_parallel),
            hrs_scores, click_scores, *moe_losses)
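# Hedged standalone sketch of the pooling-plus-head pattern above: take the
# hidden state of the first token, pass it through a dense layer, tanh, and a
# one-dimensional scoring head. The layer names and sizes are assumptions for
# illustration, not the repo's actual modules, and model parallelism is omitted.
import torch
import torch.nn as nn

hidden_size = 16                                    # assumed
dense_hrs0 = nn.Linear(hidden_size, hidden_size)    # stand-in for self.dense_hrs0
hrs_head = nn.Linear(hidden_size, 1)                # stand-in for self.hrs_head

transformer_output = torch.randn(4, 10, hidden_size)           # [batch, seq, hidden]
pooled_output = transformer_output[:, 0, :]                     # first-token pooling
hrs_scores = hrs_head(torch.tanh(dense_hrs0(pooled_output)))    # [batch, 1]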
def forward(self, input_ids, position_ids, attention_mask, *mems,
            return_memory=False, detach_memory=True, prompt_pos=None):
    # Embeddings.
    batch_size = input_ids.size(0)
    words_embeddings = self.word_embeddings(input_ids)
    embeddings = words_embeddings
    if prompt_pos is not None:
        embeddings = embeddings.clone()
        prompt_embeds = self.spell_embeddings.weight.unsqueeze(0)
        prompt_embeds = self.lstm_head(prompt_embeds)[0]
        prompt_embeds = self.mlp_head(prompt_embeds)
        batch_index = torch.arange(batch_size, device=input_ids.device).unsqueeze(1)
        embeddings[batch_index, prompt_pos] = prompt_embeds

    # Transformer.
    transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems,
                                          return_memory=return_memory,
                                          detach_memory=detach_memory)
    logits, hidden_layers = transformer_output
    outputs = hidden_layers

    if self.output_predict:
        # Parallel logits.
        logits_parallel = mpu.copy_to_model_parallel_region(logits)
        logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)

        if self.parallel_output:
            return (logits_parallel, *outputs)

        return (mpu.gather_from_model_parallel_region(logits_parallel), *outputs)
    else:
        return (logits, *outputs)
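# Minimal standalone illustration (with assumed toy shapes) of how the
# prompt-tuning branch above scatters learned prompt vectors into the word
# embeddings via advanced indexing: batch_index broadcasts against prompt_pos,
# so each sample's prompt positions are overwritten. This is not the repo's
# code; the LSTM/MLP reparameterization is replaced by a constant tensor.
import torch

batch_size, seq_len, hidden_size, num_prompts = 2, 8, 4, 3    # assumed toy sizes
embeddings = torch.zeros(batch_size, seq_len, hidden_size)
prompt_embeds = torch.ones(batch_size, num_prompts, hidden_size)   # stand-in for lstm/mlp output
prompt_pos = torch.tensor([[0, 1, 2], [5, 6, 7]])                   # per-sample prompt positions

batch_index = torch.arange(batch_size).unsqueeze(1)    # [batch, 1]
embeddings[batch_index, prompt_pos] = prompt_embeds    # rows (b, prompt_pos[b, i]) are replaced
assert embeddings[0, 0].eq(1).all() and embeddings[1, 4].eq(0).all()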
def forward(self, input_ids, position_ids, attention_mask, *mems):
    # Embeddings.
    words_embeddings = self.word_embeddings(input_ids)
    embeddings = words_embeddings

    # Transformer.
    transformer_output = self.transformer(embeddings, position_ids, attention_mask, *mems)
    logits, *hidden_layers = transformer_output

    # Parallel logits.
    logits_parallel = mpu.copy_to_model_parallel_region(logits)
    logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)

    if self.parallel_output:
        return (logits_parallel, *hidden_layers)

    return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)
def forward(self, source_ids, target_ids, source_position_ids, target_position_ids,
            source_mask, target_mask):
    # Embeddings.
    source_embeddings = self.word_embeddings(source_ids)
    target_embeddings = self.word_embeddings(target_ids)

    # Transformer.
    encoder_output, _ = self.encoder(source_embeddings, source_position_ids, source_mask)
    decoder_output, _ = self.decoder(target_embeddings, target_position_ids, target_mask)

    if self.output_predict:
        # Parallel logits.
        output_parallel = mpu.copy_to_model_parallel_region(decoder_output)
        logits_parallel = F.linear(output_parallel, self.word_embeddings.weight)

        if self.parallel_output:
            return (logits_parallel,)

        return (mpu.gather_from_model_parallel_region(logits_parallel),)
    else:
        return (decoder_output,)
def forward(self, input_ids, position_ids, attention_mask):
    # Embeddings.
    # print('input ids tensor', input_ids.size(), input_ids[0, :2])
    words_embeddings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    embeddings = words_embeddings + position_embeddings

    # Dropout.
    embeddings = self.embedding_dropout(embeddings)

    # Transformer.
    transformer_output = self.transformer(embeddings, attention_mask)

    # Parallel logits.
    transformer_output_parallel = mpu.copy_to_model_parallel_region(
        transformer_output)
    logits_parallel = F.linear(transformer_output_parallel,
                               self.word_embeddings.weight)

    if self.parallel_output:
        return logits_parallel

    return mpu.gather_from_model_parallel_region(logits_parallel)
def forward(self, input_ids, position_ids, attention_mask,
            layer_past=None, get_present=False, tokentype_ids=None):
    # Embeddings.
    words_embeddings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    embeddings = words_embeddings + position_embeddings
    if tokentype_ids is not None:
        assert self.tokentype_embeddings is not None
        embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
    else:
        assert self.tokentype_embeddings is None

    # Dropout.
    embeddings = self.embedding_dropout(embeddings)

    # Transformer.
    transformer_output = self.transformer(embeddings, attention_mask,
                                          layer_past=layer_past,
                                          get_present=get_present)
    if get_present:
        transformer_output, presents = transformer_output

    # Parallel logits.
    transformer_output_parallel = mpu.copy_to_model_parallel_region(
        transformer_output)
    logits_parallel = F.linear(transformer_output_parallel,
                               self.word_embeddings.weight)

    if self.parallel_output:
        output = logits_parallel
    else:
        output = mpu.gather_from_model_parallel_region(logits_parallel)

    if get_present:
        output = [output, presents]
    return output
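# Hedged sketch of the incremental-decoding contract implied by layer_past /
# get_present above: the first call over the prompt returns cached "presents",
# and later calls pass them back so only the newly generated token is processed.
# The model here is a trivial stand-in, not the repo's transformer, and the
# "cache" is just the concatenation of ids seen so far.
import torch

def toy_forward(input_ids, layer_past=None, get_present=False):
    # Pretend key/value cache: everything seen so far, concatenated along seq.
    new_cache = input_ids if layer_past is None else torch.cat([layer_past, input_ids], dim=1)
    logits = torch.randn(input_ids.size(0), input_ids.size(1), 10)   # fake vocab logits
    return (logits, new_cache) if get_present else logits

prompt = torch.tensor([[1, 2, 3]])
logits, presents = toy_forward(prompt, get_present=True)             # full prompt pass
next_token = logits[:, -1].argmax(-1, keepdim=True)
logits, presents = toy_forward(next_token, layer_past=presents,      # one-token pass reusing cache
                               get_present=True)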