Example #1
    def forward(self, input_ids, position_ids, attention_mask):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output = self.transformer(embeddings, attention_mask)

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        # logits_parallel = F.linear(transformer_output_parallel,
        #                            self.word_embeddings.weight)

        pooler = self.linear(transformer_output_parallel)
        gpt_classifier_output = self.classifier(pooler)

        logits_parallel = gpt_classifier_output
        if self.parallel_output:
            return logits_parallel

        return mpu.gather_from_model_parallel_region(logits_parallel)
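This variant replaces the usual tied-embedding projection (left above as a comment) with a classification head: the parallel transformer output goes through a linear layer and then a classifier. A minimal standalone sketch of that head pattern in plain PyTorch, with hypothetical layer sizes and without the mpu copy/gather calls:

import torch
import torch.nn as nn

class ClassifierHead(nn.Module):
    """Hypothetical sketch of the `self.linear` -> `self.classifier` pattern above."""

    def __init__(self, hidden_size=1024, num_classes=2):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, transformer_output):
        # transformer_output: [batch, seq_len, hidden_size]
        pooler = self.linear(transformer_output)
        return self.classifier(pooler)

# Usage with random activations standing in for the transformer output.
head = ClassifierHead()
logits = head(torch.randn(2, 16, 1024))  # -> [2, 16, 2]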
Example #2
    def forward(self, input_ids, position_ids, attention_mask):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output, *moe_losses = self.transformer(
            embeddings, attention_mask)

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            return (logits_parallel, *moe_losses)

        return (mpu.gather_from_model_parallel_region(logits_parallel),
                *moe_losses)
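Here the transformer also returns mixture-of-experts auxiliary losses, which are threaded through to the caller alongside the logits. A hedged sketch of how a training step might fold such auxiliary losses into the language-model loss; the unpacking and the coefficient are assumptions, not taken from this codebase:

import torch

def combine_moe_loss(lm_loss, moe_losses, moe_loss_coeff=0.01):
    """Hypothetical helper: add scaled MoE auxiliary losses to the LM loss."""
    if moe_losses:
        lm_loss = lm_loss + moe_loss_coeff * sum(moe_losses)
    return lm_loss

# Usage with dummy scalars in place of the returned (logits, *moe_losses).
lm_loss = torch.tensor(2.5)
moe_losses = [torch.tensor(0.1), torch.tensor(0.2)]
total_loss = combine_moe_loss(lm_loss, moe_losses)  # 2.5 + 0.01 * 0.3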
Example #3
    def forward(self, input_ids, position_ids, attention_mask, token_type_ids):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings

        embeddings = self.input_layernorm(embeddings)
        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output, *moe_losses = self.transformer(
            embeddings, attention_mask)

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        # Pool the first-token ([CLS]) hidden state as the sequence representation.
        pooled_output = torch.squeeze(transformer_output_parallel[:, 0, :])
        # hrs_scores = self.hrs_head(pooled_output)
        # click_scores = self.click_head(pooled_output)
        hrs_head0 = self.dense_hrs0(pooled_output)
        hrs_scores = self.hrs_head(torch.tanh(hrs_head0))

        click_head0 = self.dense_click0(pooled_output)
        click_scores = self.click_head(torch.tanh(click_head0))

        # Note: the heads below reuse self.dense_hrs0 / self.hrs_head, and their
        # scores are not part of the return value.
        lpsat_head0 = self.dense_hrs0(pooled_output)
        lpsat_scores = self.hrs_head(torch.tanh(lpsat_head0))

        qc_head0 = self.dense_hrs0(pooled_output)
        qc_scores = self.hrs_head(torch.tanh(qc_head0))

        eff_head0 = self.dense_hrs0(pooled_output)
        eff_scores = self.hrs_head(torch.tanh(eff_head0))

        local_head0 = self.dense_hrs0(pooled_output)
        local_scores = self.hrs_head(torch.tanh(local_head0))

        fresh_head0 = self.dense_hrs0(pooled_output)
        fresh_scores = self.hrs_head(torch.tanh(fresh_head0))
        if self.parallel_output:
            return (logits_parallel, hrs_scores, click_scores, *moe_losses)

        return (mpu.gather_from_model_parallel_region(logits_parallel),
                hrs_scores, click_scores, *moe_losses)
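The ranking heads above all follow the same pattern: pool the first-token hidden state, apply a dense layer with a tanh activation, then a one-unit output head. A standalone sketch of a single such head (names and sizes are illustrative only):

import torch
import torch.nn as nn

class ScoreHead(nn.Module):
    """Hypothetical sketch of the dense -> tanh -> head scoring pattern."""

    def __init__(self, hidden_size=1024):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, transformer_output):
        # Pool the first-token hidden state: [batch, seq, hidden] -> [batch, hidden].
        pooled = transformer_output[:, 0, :]
        return self.head(torch.tanh(self.dense(pooled)))

scores = ScoreHead()(torch.randn(4, 16, 1024))  # -> [4, 1]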
Example #4
    def forward(self,
                input_ids,
                position_ids,
                attention_mask,
                *mems,
                return_memory=False,
                detach_memory=True,
                prompt_pos=None):
        # Embeddings.
        batch_size = input_ids.size(0)
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        if prompt_pos is not None:
            embeddings = embeddings.clone()
            prompt_embeds = self.spell_embeddings.weight.unsqueeze(0)
            prompt_embeds = self.lstm_head(prompt_embeds)[0]
            prompt_embeds = self.mlp_head(prompt_embeds)
            batch_index = torch.arange(batch_size,
                                       device=input_ids.device).unsqueeze(1)
            embeddings[batch_index, prompt_pos] = prompt_embeds
        # Transformer.
        transformer_output = self.transformer(embeddings,
                                              position_ids,
                                              attention_mask,
                                              mems,
                                              return_memory=return_memory,
                                              detach_memory=detach_memory)
        logits, hidden_layers = transformer_output
        outputs = hidden_layers

        if self.output_predict:
            # Parallel logits.
            logits_parallel = mpu.copy_to_model_parallel_region(logits)
            logits_parallel = F.linear(logits_parallel,
                                       self.word_embeddings.weight)

            if self.parallel_output:
                return (logits_parallel, *outputs)

            return (mpu.gather_from_model_parallel_region(logits_parallel),
                    *outputs)
        else:
            return (logits, *outputs)
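The prompt_pos branch implements continuous-prompt (P-tuning-style) injection: learned prompt vectors are reparameterized through an LSTM and an MLP, then written into the word-embedding tensor at the given positions. A self-contained sketch of that scatter step with made-up shapes; the spell_embeddings, lstm_head, and mlp_head objects below are plain-PyTorch stand-ins:

import torch
import torch.nn as nn

batch_size, seq_len, hidden, n_prompt = 2, 16, 64, 4

# Stand-ins for the modules used in the forward above.
spell_embeddings = nn.Embedding(n_prompt, hidden)
lstm_head = nn.LSTM(hidden, hidden, batch_first=True)
mlp_head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))

embeddings = torch.randn(batch_size, seq_len, hidden)     # word embeddings
prompt_pos = torch.tensor([[0, 1, 2, 3], [5, 6, 7, 8]])   # prompt positions per sample

prompt_embeds = spell_embeddings.weight.unsqueeze(0)      # [1, n_prompt, hidden]
prompt_embeds = lstm_head(prompt_embeds)[0]               # LSTM reparameterization
prompt_embeds = mlp_head(prompt_embeds)                   # [1, n_prompt, hidden]

batch_index = torch.arange(batch_size).unsqueeze(1)       # [batch, 1]
embeddings = embeddings.clone()
embeddings[batch_index, prompt_pos] = prompt_embeds       # scatter prompts in place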
Example #5
    def forward(self, input_ids, position_ids, attention_mask, *mems):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings

        # Transformer.
        transformer_output = self.transformer(embeddings, position_ids,
                                              attention_mask, *mems)
        logits, *hidden_layers = transformer_output
        # Parallel logits.
        logits_parallel = mpu.copy_to_model_parallel_region(logits)
        logits_parallel = F.linear(logits_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            return (logits_parallel, *hidden_layers)

        return (mpu.gather_from_model_parallel_region(logits_parallel),
                *hidden_layers)
Example #6
    def forward(self, source_ids, target_ids, source_position_ids,
                target_position_ids, source_mask, target_mask):
        # Embeddings.
        source_embeddings = self.word_embeddings(source_ids)
        target_embeddings = self.word_embeddings(target_ids)

        # Transformer.
        encoder_output, _ = self.encoder(source_embeddings,
                                         source_position_ids, source_mask)
        decoder_output, _ = self.decoder(target_embeddings,
                                         target_position_ids, target_mask)
        if self.output_predict:
            # Parallel logits.
            output_parallel = mpu.copy_to_model_parallel_region(decoder_output)
            logits_parallel = F.linear(output_parallel,
                                       self.word_embeddings.weight)

            if self.parallel_output:
                return (logits_parallel, )

            return (mpu.gather_from_model_parallel_region(logits_parallel), )
        else:
            return (decoder_output, )
Example #7
    def forward(self, input_ids, position_ids, attention_mask):

        # Embeddings.
        # print('input ids tensor', input_ids.size(), input_ids[0, :2])
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output = self.transformer(embeddings, attention_mask)

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            return logits_parallel

        return mpu.gather_from_model_parallel_region(logits_parallel)
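As in most of these examples, the output projection reuses the input embedding matrix via F.linear(hidden_states, self.word_embeddings.weight), i.e. the LM head is weight-tied to the input embeddings. A standalone illustration of that projection without the model-parallel copy/gather calls (sizes are arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, hidden = 100, 32
word_embeddings = nn.Embedding(vocab_size, hidden)

hidden_states = torch.randn(2, 8, hidden)  # stand-in for the transformer output

# Tied LM head: project hidden states onto the embedding matrix.
logits = F.linear(hidden_states, word_embeddings.weight)
assert logits.shape == (2, 8, vocab_size)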
Example #8
    def forward(self, input_ids, position_ids, attention_mask,
                layer_past=None, get_present=False, tokentype_ids=None):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output = self.transformer(embeddings, attention_mask,
                                              layer_past=layer_past,
                                              get_present=get_present)
        if get_present:
            transformer_output, presents = transformer_output

        # Parallel logits.
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            output = logits_parallel
        else:
            output = mpu.gather_from_model_parallel_region(logits_parallel)
        if get_present:
            output = [output, presents]
        return output
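The layer_past / get_present arguments support incremental decoding: each call can return the current keys and values ("presents") so the next step attends over the cached prefix instead of recomputing it. A self-contained sketch of that caching pattern for one simplified attention step; the (key, value) tuple structure of layer_past is an assumption here:

import torch

def attention_step(query, key, value, layer_past=None):
    """Hypothetical single-head attention step with a key/value cache."""
    if layer_past is not None:
        past_key, past_value = layer_past
        key = torch.cat([past_key, key], dim=1)        # extend cached keys
        value = torch.cat([past_value, value], dim=1)  # extend cached values
    present = (key, value)                             # cache for the next step
    scores = torch.softmax(query @ key.transpose(-2, -1) / key.size(-1) ** 0.5, dim=-1)
    return scores @ value, present

# Decode two tokens one at a time, reusing the cache.
hidden = 16
past = None
for _ in range(2):
    q = k = v = torch.randn(1, 1, hidden)
    out, past = attention_step(q, k, v, layer_past=past)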
Example #9
def lm_forward_step(data, model, args, timers, mems, eval_metric=None):
    """Forward step."""
    # Get the batch.
    if timers is not None:
        timers('batch generator').start()
    # `data` may be an iterator (yielding batches) or an already-materialized batch.
    try:
        data = next(data)
    except BaseException:
        pass

    if 'mask' in data:
        # finetune SQuAD
        data['attention_mask'] = data.pop('mask')
        data['position_id'] = data.pop('position')
        data['loss_mask'] = data.pop('logit_mask')

    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data, args)
    if timers is not None:
        timers('batch generator').stop()

    if tokens.dim() == 3:
        tokens = tokens.squeeze(1)
        labels = labels.squeeze(1)
        loss_mask = loss_mask.squeeze(1)
        attention_mask = attention_mask.squeeze(1)
        position_ids = position_ids.squeeze(1)

    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        output_tokens = []
        sep = attention_mask[batch_id].item()
        for i, token in enumerate(tokens[batch_id, :sep].tolist()):
            if global_tokenizer is not None:
                token = global_tokenizer.IdToToken(token)
                if token.startswith('[MASK'):
                    token = f"[{position_ids_[batch_id, i].item()}, {token}]"
                if token.startswith('##') and len(output_tokens) > 0 and not output_tokens[-1].endswith(']'):
                    output_tokens[-1] += token[2:]
                else:
                    output_tokens.append(token)
            else:
                output_tokens.append(str(token))
        print(" ".join(output_tokens))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if global_tokenizer.IdToToken(tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(global_tokenizer.DecodeIds(tokens[batch_id, last_index: i].tolist()), "|",
                          global_tokenizer.DecodeIds(labels[batch_id, last_index: i].tolist()))
                    print(position_ids_[batch_id, last_index: i].tolist(),
                          block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(global_tokenizer.DecodeIds(tokens[batch_id, last_index:].tolist()), "|",
                  global_tokenizer.DecodeIds(labels[batch_id, last_index:].tolist()))
            print(position_ids_[batch_id, last_index:].tolist(), block_position_ids[batch_id, last_index:].tolist())

    # Forward model.
    if args.continuous_prompt:
        prompt_pos = data["prompt_pos"].long().cuda()
        logits, *mems = model(tokens, position_ids, attention_mask, *mems, prompt_pos=prompt_pos)
    else:
        logits, *mems = model(tokens, position_ids, attention_mask, *mems)
        
    if eval_metric is None or eval_metric == 'loss':
        losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
        loss_mask = loss_mask.view(-1)
        # The loss is not normalized for fair comparison
        loss = torch.sum(losses.view(-1) * loss_mask)
        if eval_metric is None:
            loss = loss / loss_mask.sum()
        return loss, mems, 'bert'
    elif eval_metric == 'accuracy' or eval_metric == 'classify':
        logits = mpu.gather_from_model_parallel_region(logits)
        outputs = torch.argmax(logits, -1)
        correct = (outputs == labels).float()
        correct[(1 - loss_mask).bool()] = 1
        correct = correct.prod(-1)
        if eval_metric == 'accuracy':
            correct = correct.sum()
        return correct, mems, 'bert'
    else:
        raise NotImplementedError("Metric {} not implemented".format(eval_metric))
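Both evaluation branches hinge on loss_mask: the loss branch averages token-level cross-entropy over unmasked positions only, and the accuracy branch marks masked positions as correct before taking a per-sample product, so a sample counts as correct only when every unmasked token matches. A standalone sketch of both reductions with dummy tensors; plain F.cross_entropy stands in for mpu.vocab_parallel_cross_entropy:

import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))
loss_mask = (torch.rand(batch, seq) > 0.3).float()

# Masked, normalized LM loss.
losses = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction='none')
loss = torch.sum(losses * loss_mask.view(-1)) / loss_mask.sum()

# Exact-match "accuracy": masked positions are forced to count as correct,
# then prod(-1) requires every unmasked token to match.
outputs = torch.argmax(logits, -1)
correct = (outputs == labels).float()
correct[(1 - loss_mask).bool()] = 1
correct = correct.prod(-1)
num_correct = correct.sum()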