Example #1
    def forward(self, input_ids, position_ids, attention_mask):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output, *moe_losses = self.transformer(
            embeddings, attention_mask)

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            return (logits_parallel, *moe_losses)

        return (mpu.gather_from_model_parallel_region(logits_parallel),
                *moe_losses)
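Most of these examples share the same output pattern: the transformer hidden states are copied into the model-parallel region with mpu.copy_to_model_parallel_region, projected to vocabulary logits by reusing self.word_embeddings.weight (weight tying), and then either returned as vocabulary-parallel shards or gathered with mpu.gather_from_model_parallel_region depending on self.parallel_output. A minimal single-device sketch of the weight-tied projection, with the Megatron-style mpu calls omitted and all module names and sizes chosen here purely for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedLogitsSketch(nn.Module):
    """Single-device sketch of the weight-tied logit projection used above."""

    def __init__(self, vocab_size=100, hidden_size=32, max_positions=64):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)
        self.embedding_dropout = nn.Dropout(0.1)

    def forward(self, input_ids, position_ids):
        embeddings = (self.word_embeddings(input_ids)
                      + self.position_embeddings(position_ids))
        hidden = self.embedding_dropout(embeddings)  # stand-in for the transformer stack
        # Reuse the input embedding matrix as the output projection (weight tying).
        return F.linear(hidden, self.word_embeddings.weight)

input_ids = torch.randint(0, 100, (2, 8))
position_ids = torch.arange(8).unsqueeze(0).expand(2, -1)
logits = TiedLogitsSketch()(input_ids, position_ids)  # (2, 8, vocab_size)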
Example #2
    def forward(self, input_ids, position_ids, attention_mask):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output = self.transformer(embeddings, attention_mask)

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        # logits_parallel = F.linear(transformer_output_parallel,
        #                            self.word_embeddings.weight)

        pooler = self.linear(transformer_output_parallel)
        gpt_classifier_output = self.classifier(pooler)

        logits_parallel = gpt_classifier_output
        if self.parallel_output:
            return logits_parallel

        return mpu.gather_from_model_parallel_region(logits_parallel)
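In this variant the weight-tied vocabulary projection is commented out and replaced by a linear pooler plus a classifier, so the forward pass returns classification scores instead of token logits. A minimal sketch of that head on its own; the layer sizes are assumptions, since the original self.linear and self.classifier definitions are not shown in the example:

import torch
import torch.nn as nn

class ClassifierHeadSketch(nn.Module):
    """Pooler + classifier applied to the transformer hidden states."""

    def __init__(self, hidden_size=32, num_labels=2):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)     # pooler
        self.classifier = nn.Linear(hidden_size, num_labels)  # classification head

    def forward(self, transformer_output):
        pooler = self.linear(transformer_output)
        return self.classifier(pooler)

hidden_states = torch.randn(2, 8, 32)           # (batch, seq, hidden)
scores = ClassifierHeadSketch()(hidden_states)  # (batch, seq, num_labels)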
Example #3
    def forward(self, input_ids, position_ids, attention_mask, token_type_ids):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings

        embeddings = self.input_layernorm(embeddings)
        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output, *moe_losses = self.transformer(
            embeddings, attention_mask)

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        pooled_output = torch.squeeze(transformer_output_parallel[:, 0, :])
        ##############
        #hrs_scores = self.hrs_head(pooled_output)
        #click_scores = self.click_head(pooled_output)
        #############
        hrs_head0 = self.dense_hrs0(pooled_output)
        hrs_scores = self.hrs_head(torch.tanh(hrs_head0))

        click_head0 = self.dense_click0(pooled_output)
        click_scores = self.click_head(torch.tanh(click_head0))

        lpsat_head0 = self.dense_hrs0(pooled_output)
        lpsat_scores = self.hrs_head(torch.tanh(lpsat_head0))

        qc_head0 = self.dense_hrs0(pooled_output)
        qc_scores = self.hrs_head(torch.tanh(qc_head0))

        eff_head0 = self.dense_hrs0(pooled_output)
        eff_scores = self.hrs_head(torch.tanh(eff_head0))

        local_head0 = self.dense_hrs0(pooled_output)
        local_scores = self.hrs_head(torch.tanh(local_head0))

        fresh_head0 = self.dense_hrs0(pooled_output)
        fresh_scores = self.hrs_head(torch.tanh(fresh_head0))
        #############
        if self.parallel_output:
            return (logits_parallel, hrs_scores, click_scores, *moe_losses)

        return (mpu.gather_from_model_parallel_region(logits_parallel),
                hrs_scores, click_scores, *moe_losses)
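Example #3 adds token-type embeddings and an input LayerNorm, and attaches several scoring heads to the pooled first-token representation: a dense layer followed by tanh and a final linear score. Note that the lpsat/qc/eff/local/fresh heads in the example reuse self.dense_hrs0 and self.hrs_head, and only hrs_scores and click_scores are returned. A small sketch of the pooling-plus-head pattern, with module names and sizes assumed for illustration:

import torch
import torch.nn as nn

class ScoreHeadSketch(nn.Module):
    """Pool the first token, then dense -> tanh -> linear score."""

    def __init__(self, hidden_size=32):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, transformer_output):
        # First-token ("CLS"-style) representation, squeezed as in the example above.
        pooled = torch.squeeze(transformer_output[:, 0, :])
        return self.head(torch.tanh(self.dense(pooled)))

hidden_states = torch.randn(4, 16, 32)         # (batch, seq, hidden)
hrs_scores = ScoreHeadSketch()(hidden_states)  # (batch, 1)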
Example #4
    def forward(self,
                input_ids,
                position_ids,
                attention_mask,
                *mems,
                return_memory=False,
                detach_memory=True,
                prompt_pos=None):
        # Embeddings.
        batch_size = input_ids.size(0)
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        if prompt_pos is not None:
            embeddings = embeddings.clone()
            prompt_embeds = self.spell_embeddings.weight.unsqueeze(0)
            prompt_embeds = self.lstm_head(prompt_embeds)[0]
            prompt_embeds = self.mlp_head(prompt_embeds)
            batch_index = torch.arange(batch_size,
                                       device=input_ids.device).unsqueeze(1)
            embeddings[batch_index, prompt_pos] = prompt_embeds
        # Transformer.
        transformer_output = self.transformer(embeddings,
                                              position_ids,
                                              attention_mask,
                                              mems,
                                              return_memory=return_memory,
                                              detach_memory=detach_memory)
        logits, hidden_layers = transformer_output
        outputs = hidden_layers

        if self.output_predict:
            # Parallel logits.
            logits_parallel = mpu.copy_to_model_parallel_region(logits)
            logits_parallel = F.linear(logits_parallel,
                                       self.word_embeddings.weight)

            if self.parallel_output:
                return (logits_parallel, *outputs)

            return (mpu.gather_from_model_parallel_region(logits_parallel),
                    *outputs)
        else:
            return (logits, *outputs)
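The first forward in Example #4 injects a learned prompt: self.spell_embeddings is reparameterized through an LSTM and an MLP head, and the resulting prompt vectors are written into the token embeddings at the positions given by prompt_pos using advanced indexing with a per-batch index. A self-contained sketch of just that injection step; the shapes, layer sizes, and the single-layer LSTM here are assumptions:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size, num_prompts, vocab_size = 2, 10, 16, 3, 50

word_embeddings = nn.Embedding(vocab_size, hidden_size)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
prompt_pos = torch.tensor([[0, 1, 2], [3, 4, 5]])     # prompt slots per sample

# Reparameterize the learned prompt vectors (spell_embeddings -> LSTM -> MLP), as above.
spell_embeddings = nn.Embedding(num_prompts, hidden_size)
lstm_head = nn.LSTM(hidden_size, hidden_size, batch_first=True)
mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU(),
                         nn.Linear(hidden_size, hidden_size))

prompt_embeds = spell_embeddings.weight.unsqueeze(0)  # (1, num_prompts, hidden)
prompt_embeds = mlp_head(lstm_head(prompt_embeds)[0])

# Write the prompt vectors into the token embeddings at prompt_pos (broadcast over batch).
embeddings = word_embeddings(input_ids).clone()
batch_index = torch.arange(batch_size).unsqueeze(1)   # (batch, 1)
embeddings[batch_index, prompt_pos] = prompt_embeds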
Example #5
    def forward(self, input_ids, position_ids, attention_mask, *mems):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings

        # Transformer.
        transformer_output = self.transformer(embeddings, position_ids,
                                              attention_mask, *mems)
        logits, *hidden_layers = transformer_output
        # Parallel logits.
        logits_parallel = mpu.copy_to_model_parallel_region(logits)
        logits_parallel = F.linear(logits_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            return (logits_parallel, *hidden_layers)

        return (mpu.gather_from_model_parallel_region(logits_parallel),
                *hidden_layers)
Example #6
    def forward(self, source_ids, target_ids, source_position_ids,
                target_position_ids, source_mask, target_mask):
        # Embeddings.
        source_embeddings = self.word_embeddings(source_ids)
        target_embeddings = self.word_embeddings(target_ids)

        # Transformer.
        encoder_output, _ = self.encoder(source_embeddings,
                                         source_position_ids, source_mask)
        decoder_output, _ = self.decoder(target_embeddings,
                                         target_position_ids, target_mask)
        if self.output_predict:
            # Parallel logits.
            output_parallel = mpu.copy_to_model_parallel_region(decoder_output)
            logits_parallel = F.linear(output_parallel,
                                       self.word_embeddings.weight)

            if self.parallel_output:
                return (logits_parallel, )

            return (mpu.gather_from_model_parallel_region(logits_parallel), )
        else:
            return (decoder_output, )
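Example #6 is an encoder-decoder variant: a single shared word_embeddings table embeds both source and target tokens, and only the decoder output is projected back through that shared table to produce logits (the encoder output is not consumed in this snippet). A minimal single-device sketch of the shared-embedding skeleton, with linear layers standing in for the real encoder and decoder stacks:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedEmbeddingSeq2SeqSketch(nn.Module):
    """One embedding table shared by encoder and decoder; decoder logits are weight-tied."""

    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.encoder = nn.Linear(hidden_size, hidden_size)  # stand-in for the encoder stack
        self.decoder = nn.Linear(hidden_size, hidden_size)  # stand-in for the decoder stack

    def forward(self, source_ids, target_ids):
        encoder_output = self.encoder(self.word_embeddings(source_ids))  # unused, as above
        decoder_output = self.decoder(self.word_embeddings(target_ids))
        # Only the decoder output is projected against the shared embedding table.
        return F.linear(decoder_output, self.word_embeddings.weight)

source_ids = torch.randint(0, 100, (2, 7))
target_ids = torch.randint(0, 100, (2, 5))
logits = SharedEmbeddingSeq2SeqSketch()(source_ids, target_ids)  # (2, 5, vocab_size)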
Example #7
    def forward(self, input_ids, position_ids, attention_mask):

        # Embeddings.
        #         print('input ids tensor', input_ids.size(), input_ids[0,:2])
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output = self.transformer(embeddings, attention_mask)

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            return logits_parallel

        return mpu.gather_from_model_parallel_region(logits_parallel)
Example #8
    def forward(self, input_ids, position_ids, attention_mask,
                layer_past=None, get_present=False, tokentype_ids=None):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output = self.transformer(embeddings, attention_mask,
                                              layer_past=layer_past,
                                              get_present=get_present)
        if get_present:
            transformer_output, presents = transformer_output

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            output = logits_parallel
        else:
            output = mpu.gather_from_model_parallel_region(logits_parallel)
        if get_present:
            output = [output, presents]
        return output
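Example #8 threads an incremental-decoding cache through the transformer: layer_past feeds cached keys/values back in, and get_present=True makes the call also return the updated cache (presents) alongside the logits. A sketch of that call pattern with a dummy transformer stand-in; the cache handling here is illustrative only, since the real caching happens inside the transformer stack:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedDecoderSketch(nn.Module):
    """Illustrates the layer_past / get_present plumbing from the example above."""

    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)  # stand-in for the transformer stack

    def forward(self, input_ids, layer_past=None, get_present=False):
        hidden = self.proj(self.word_embeddings(input_ids))
        # A real transformer would extend layer_past with this step's keys/values;
        # here `presents` is just the hidden state to show the return-shape handling.
        presents = hidden
        logits = F.linear(hidden, self.word_embeddings.weight)
        return (logits, presents) if get_present else logits

model = CachedDecoderSketch()
input_ids = torch.randint(0, 100, (1, 4))
logits, presents = model(input_ids, get_present=True)        # prefill: build the cache
next_logits = model(input_ids[:, -1:], layer_past=presents)  # decode step reuses the cache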