Example #1
    def forward(self, model_input, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            # First pipeline stage: model_input holds the token ids.
            input_ids = model_input
            position_ids = bert_position_ids(input_ids)

            args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            # Later stages: model_input is the hidden state from the previous stage.
            args = [model_input, extended_attention_mask]

        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage():
            _, pooled_output = lm_output
            classification_output = self.classification_dropout(pooled_output)
            classification_logits = self.classification_head(
                classification_output)

            # Reshape to [batch, num_classes].
            classification_logits = classification_logits.view(
                -1, self.num_classes)

            return classification_logits
        return lm_output
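All of these examples lean on the same two helpers. As a point of reference, here is a minimal sketch of what they typically do in Megatron-style code (an assumption for illustration, not code taken from the examples): bert_position_ids builds a [batch, sequence] tensor of positions, and bert_extended_attention_mask turns a [batch, sequence] padding mask into the [batch, 1, sequence, sequence] mask the transformer consumes.

import torch


def bert_position_ids(token_ids):
    # [batch, sequence] token ids -> [batch, sequence] positions 0..seq_len-1.
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    return position_ids.unsqueeze(0).expand_as(token_ids)


def bert_extended_attention_mask(attention_mask):
    # Sketch only: combine the [batch, sequence] padding mask with itself to
    # get a [batch, 1, sequence, sequence] boolean mask where True marks
    # positions that should be masked out.
    attention_mask_b1s = attention_mask.unsqueeze(1)               # [b, 1, s]
    attention_mask_bs1 = attention_mask.unsqueeze(2)               # [b, s, 1]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1   # [b, s, s]
    return attention_mask_bss.unsqueeze(1) < 0.5                   # [b, 1, s, s]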
Example #2
    def forward(self, input_ids, attention_mask, tokentype_ids):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        #    transformer --> [batch, choices] --> softmax

        # Ensure the shape is [batch-size, choices, sequence]
        assert len(input_ids.shape) == 3
        assert len(attention_mask.shape) == 3
        assert len(tokentype_ids.shape) == 3

        # Reshape and treat choice dimension the same as batch.
        num_choices = input_ids.shape[1]
        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices.
        multichoice_logits = multichoice_logits.view(-1, num_choices)

        return multichoice_logits
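A hypothetical call sketch for the multichoice forward above, only to illustrate the expected shapes (model construction is omitted; the tensor names and sizes are assumptions):

import torch

batch_size, num_choices, seq_len = 4, 2, 128
input_ids = torch.randint(0, 30000, (batch_size, num_choices, seq_len))
attention_mask = torch.ones(batch_size, num_choices, seq_len, dtype=torch.long)
tokentype_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)

# With a constructed multichoice model the call would look like:
#   logits = model(input_ids, attention_mask, tokentype_ids)
#   logits.shape -> torch.Size([batch_size, num_choices])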
Example #3
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, loss_mask, lm_labels, padding_mask, attention_mask, position_ids \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    extended_attention_mask = bert_extended_attention_mask(
        padding_mask) + attention_mask

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            # This rank is both the first and the last pipeline stage.
            output_tensor = model(tokens,
                                  extended_attention_mask,
                                  tokentype_ids=None,
                                  lm_labels=lm_labels,
                                  position_ids=position_ids)
        else:
            output_tensor = model(tokens,
                                  extended_attention_mask,
                                  tokentype_ids=None)
    elif mpu.is_pipeline_last_stage():
        # Last stage: runs on the hidden state received from the previous
        # stage and needs the labels to compute the LM loss.
        assert input_tensor is not None
        output_tensor = model(input_tensor,
                              extended_attention_mask,
                              lm_labels=lm_labels)
    else:
        # Intermediate stage: runs its layers on the received hidden state
        # and passes the result on to the next stage.
        assert input_tensor is not None
        output_tensor = model(input_tensor,
                              extended_attention_mask,
                              position_ids=position_ids)

    if mpu.is_pipeline_last_stage():
        lm_loss_, _ = output_tensor

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss

        averaged_losses = average_losses_across_data_parallel_group([
            lm_loss,
        ])

        return loss, {'lm loss': averaged_losses[0]}
    return output_tensor
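The loss reduction on the last pipeline stage is just a masked mean over tokens. A standalone illustration with dummy per-token losses (no model involved):

import torch

lm_loss_ = torch.tensor([[0.5, 1.0, 2.0],
                         [1.5, 0.0, 3.0]])   # per-token LM losses
loss_mask = torch.tensor([[1.0, 1.0, 0.0],
                          [1.0, 0.0, 0.0]])  # 1.0 marks tokens that count

lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(lm_loss)  # tensor(1.) == (0.5 + 1.0 + 1.5) / 3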
Example #4
    def forward(self, input_ids, attention_mask, token_type_ids):
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        sequence_output = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=extended_attention_mask,
            tokentype_ids=token_type_ids,
        )
        return sequence_output
Example #5
    def forward(self, input_ids, attention_mask, token_type_ids):
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask,
            next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        sequence_output = self.language_model(input_ids,
                                              position_ids,
                                              extended_attention_mask,
                                              tokentype_ids=token_type_ids)
        return sequence_output
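Examples #5 and #7 additionally pass the dtype of the language model's parameters into bert_extended_attention_mask. A hypothetical sketch of what such a dtype-aware variant could look like, assuming it produces an additive bias mask rather than a boolean one (a guess for illustration, not the implementation these snippets use):

import torch


def bert_extended_attention_mask_additive(attention_mask, dtype):
    # Hypothetical: [batch, sequence] padding mask -> [batch, 1, 1, sequence]
    # additive bias in the model's compute dtype, 0.0 where attended and
    # -10000.0 where padded.
    extended = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=dtype)
    return (1.0 - extended) * -10000.0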
Example #6
    def forward(self, input_ids, attention_mask, token_type_ids):
        if self._lazy_init_fn is not None:
            self._lazy_init_fn()
            self._lazy_init_fn = None
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        sequence_output = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=extended_attention_mask,
            tokentype_ids=token_type_ids,
        )
        return sequence_output
Example #7
    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        lm_output, pooled_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)

        # Output.
        ict_logits = self.ict_head(pooled_output)
        return ict_logits, None
Example #8
    def forward(self, input_ids, attention_mask, token_type_ids):
        app_state = AppState()
        if app_state.model_parallel_size is None:
            self.complete_lazy_init()

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        sequence_output = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=extended_attention_mask,
            tokentype_ids=token_type_ids,
        )
        return sequence_output
Example #9
    def forward(self, input_ids, attention_mask, tokentype_ids):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(classification_output)

        # Reshape to [batch, num_classes].
        classification_logits = classification_logits.view(-1, self.num_classes)

        return classification_logits
Example #10
    def forward(self, model_input, attention_mask, tokentype_ids=None):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        #    transformer --> [batch, choices] --> softmax

        # Ensure the shape is [batch-size, choices, sequence]
        assert len(attention_mask.shape) == 3
        num_choices = attention_mask.shape[1]

        # Reshape and treat choice dimension the same as batch.
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = model_input
            # Do the same as attention_mask for input_ids, tokentype_ids
            assert len(input_ids.shape) == 3
            assert len(tokentype_ids.shape) == 3
            input_ids = input_ids.view(-1, input_ids.size(-1))
            tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

            position_ids = bert_position_ids(input_ids)
            args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            # Later stages: model_input is the hidden state from the previous stage.
            args = [model_input, extended_attention_mask]

        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage():
            _, pooled_output = lm_output
            multichoice_output = self.multichoice_dropout(pooled_output)
            multichoice_logits = self.multichoice_head(multichoice_output)

            # Reshape back to separate choices.
            multichoice_logits = multichoice_logits.view(-1, num_choices)

            return multichoice_logits
        return lm_output