Example #1
    def forward(self, model_input, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = model_input
            position_ids = bert_position_ids(input_ids)

            args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            args = [model_input, extended_attention_mask]
        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage():
            _, pooled_output = lm_output
            classification_output = self.classification_dropout(pooled_output)
            classification_logits = self.classification_head(
                classification_output)

            # Reshape to [batch, num_classes].
            classification_logits = classification_logits.view(
                -1, self.num_classes)

            return classification_logits
        return lm_output
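All of these examples lean on the helpers bert_extended_attention_mask and bert_position_ids. As a point of reference, here is a minimal sketch of what they are assumed to do, following the typical Megatron-LM definitions (some variants also take a dtype argument, omitted here; your version may differ):

import torch

def bert_extended_attention_mask(attention_mask):
    # [b, s] -> [b, 1, s, s]: a broadcastable self-attention mask where
    # True marks positions that must be masked out.
    attention_mask_b1s = attention_mask.unsqueeze(1)
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    extended_attention_mask = attention_mask_bss.unsqueeze(1)
    return extended_attention_mask < 0.5

def bert_position_ids(token_ids):
    # [b, s] -> [b, s]: positions 0..s-1, broadcast over the batch.
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    return position_ids.unsqueeze(0).expand_as(token_ids)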
Example #2
    def forward(self, input_ids, attention_mask, tokentype_ids):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        #    transformer --> [batch, choices] --> softmax

        # Ensure the shape is [batch-size, choices, sequence]
        assert len(input_ids.shape) == 3
        assert len(attention_mask.shape) == 3
        assert len(tokentype_ids.shape) == 3

        # Reshape and treat choice dimension the same as batch.
        num_choices = input_ids.shape[1]
        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices.
        multichoice_logits = multichoice_logits.view(-1, num_choices)

        return multichoice_logits
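The multiple-choice pattern in Example #2 folds the choice dimension into the batch before the transformer and unfolds it afterwards. A toy shape check (illustrative sizes only, not taken from the snippet):

import torch

batch, choices, seq = 2, 4, 16                      # hypothetical sizes
input_ids = torch.randint(0, 30000, (batch, choices, seq))

flat_ids = input_ids.view(-1, input_ids.size(-1))   # [batch * choices, seq]
assert flat_ids.shape == (batch * choices, seq)

# The head emits one score per flattened row; reshaping groups the scores
# of each question's choices back together for the softmax.
flat_logits = torch.randn(batch * choices, 1)
multichoice_logits = flat_logits.view(-1, choices)  # [batch, choices]
assert multichoice_logits.shape == (batch, choices)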
Example #3
    def forward(self, input_ids, attention_mask, token_type_ids):
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        sequence_output = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=extended_attention_mask,
            tokentype_ids=token_type_ids,
        )
        return sequence_output
Example #4
    def forward(self, input_ids, attention_mask, token_type_ids):
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask,
            next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        sequence_output = self.language_model(input_ids,
                                              position_ids,
                                              extended_attention_mask,
                                              tokentype_ids=token_type_ids)
        return sequence_output
Example #5
    def forward(self, input_ids, attention_mask, token_type_ids):
        if self._lazy_init_fn is not None:
            self._lazy_init_fn()
            self._lazy_init_fn = None
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        sequence_output = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=extended_attention_mask,
            tokentype_ids=token_type_ids,
        )
        return sequence_output
Example #6
    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        lm_output, pooled_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)

        # Output.
        ict_logits = self.ict_head(pooled_output)
        return ict_logits, None
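Example #6 returns per-sample embeddings from an ICT (inverse cloze task) head. In a bi-encoder retrieval setup such embeddings are usually scored against a second encoder's context embeddings with a dot product; a hedged sketch (the function and tensor names are hypothetical, not part of the snippet above):

import torch

def retrieval_scores(query_logits, context_logits):
    # query_logits: [num_queries, dim], context_logits: [num_contexts, dim].
    # Higher score means a closer query/context match.
    return torch.matmul(query_logits, context_logits.t())

scores = retrieval_scores(torch.randn(8, 128), torch.randn(32, 128))  # [8, 32]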
Example #7
    def forward(self, input_ids, attention_mask, token_type_ids):
        app_state = AppState()
        if app_state.model_parallel_size is None:
            self.complete_lazy_init()

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        sequence_output = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=extended_attention_mask,
            tokentype_ids=token_type_ids,
        )
        return sequence_output
Example #8
    def forward(self, input_ids, attention_mask, tokentype_ids):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(classification_output)

        # Reshape to [batch, num_classes].
        classification_logits = classification_logits.view(-1, self.num_classes)

        return classification_logits
Example #9
    def forward(self, model_input, attention_mask, tokentype_ids=None):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        #    transformer --> [batch, choices] --> softmax

        # Ensure the shape is [batch-size, choices, sequence]
        assert len(attention_mask.shape) == 3
        num_choices = attention_mask.shape[1]

        # Reshape and treat choice dimension the same as batch.
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = model_input
            # Do the same as attention_mask for input_ids, tokentype_ids
            assert len(input_ids.shape) == 3
            assert len(tokentype_ids.shape) == 3
            input_ids = input_ids.view(-1, input_ids.size(-1))
            tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

            position_ids = bert_position_ids(input_ids)
            args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            args = [model_input, extended_attention_mask]
        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage():
            _, pooled_output = lm_output
            multichoice_output = self.multichoice_dropout(pooled_output)
            multichoice_logits = self.multichoice_head(multichoice_output)

            # Reshape back to separate choices.
            multichoice_logits = multichoice_logits.view(-1, num_choices)

            return multichoice_logits
        return lm_output
Example #10
    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = attention_mask.unsqueeze(1)
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)
        # This mask will be used in average-pooling and max-pooling
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        # Taking the representation of the [CLS] token of BERT
        pooled_output = lm_output[:, 0, :]

        # Ensure the pooled output matches the language-model output dtype.
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.biencoder_projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output
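Example #10 computes pool_mask but, in the lines shown, only performs [CLS] pooling. For reference, a sketch of how such a mask is typically applied for masked mean pooling (illustrative only; the real model may pool differently):

import torch

def masked_mean_pool(lm_output, pool_mask):
    # lm_output: [b, s, h]; pool_mask: [b, s, 1], True at padding positions.
    summed = lm_output.masked_fill(pool_mask, 0.0).sum(dim=1)
    token_counts = (~pool_mask).sum(dim=1).clamp(min=1)
    return summed / token_counts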