Beispiel #1
0
 def get_feat_spec(self, max_seq_length):
     """Build the ``FeaturizationSpec`` for the given maximum sequence length.

     Args:
         max_seq_length: Forwarded unchanged as the spec's ``max_seq_length``.

     Returns:
         A ``FeaturizationSpec`` with ``cls_token_at_end`` and ``pad_on_left``
         disabled, id 0 for the pad token, the pad mask and the
         CLS/pad/sequence-A segments, segment id 1 for sequence B, and
         ``sep_token_extra`` disabled.
     """
     spec_kwargs = {
         # Caller-controlled length limit.
         "max_seq_length": max_seq_length,
         # Layout flags.
         "cls_token_at_end": False,
         "pad_on_left": False,
         "sep_token_extra": False,
         # Padding values.
         "pad_token_id": 0,
         "pad_token_mask_id": 0,
         "pad_token_segment_id": 0,
         # Segment ids: sequence B is distinguished from sequence A.
         "cls_token_segment_id": 0,
         "sequence_a_segment_id": 0,
         "sequence_b_segment_id": 1,
     }
     return FeaturizationSpec(**spec_kwargs)
Beispiel #2
0
 def get_feat_spec(self, max_seq_length):
     """Build the XLM-RoBERTa ``FeaturizationSpec`` for ``max_seq_length``.

     XLM-RoBERTa uses '<s>' as the cls token and '</s>' as the sep token, and
     sentence pairs are joined by two '</s>' tokens (not '</s><s>'), which
     ``sep_token_extra=True`` reflects.

     Args:
         max_seq_length: Forwarded unchanged as the spec's ``max_seq_length``.

     Returns:
         A ``FeaturizationSpec`` with pad token id 1, segment id 0 for every
         segment field, and ``sep_token_extra`` enabled.
     """
     # XLM-RoBERTa has no token_type_ids, so every segment field shares one id.
     shared_segment_id = 0
     # XLM-RoBERTa uses pad_token_id = 1.
     xlmr_pad_token_id = 1

     spec = FeaturizationSpec(
         max_seq_length=max_seq_length,
         cls_token_at_end=False,
         pad_on_left=False,
         sep_token_extra=True,
         pad_token_id=xlmr_pad_token_id,
         pad_token_mask_id=0,
         pad_token_segment_id=shared_segment_id,
         cls_token_segment_id=shared_segment_id,
         sequence_a_segment_id=shared_segment_id,
         sequence_b_segment_id=shared_segment_id,
     )
     return spec
Beispiel #3
0
def build_featurization_spec(model_type, max_seq_length):
    """Return the ``FeaturizationSpec`` matching ``model_type``.

    The model type is resolved to an architecture through
    ``ModelArchitectures.from_model_type``. Only three settings differ between
    the supported architectures (``pad_token_id``, ``sequence_b_segment_id``
    and ``sep_token_extra``); every other field is identical for all of them.

    Args:
        model_type: Value accepted by ``ModelArchitectures.from_model_type``.
        max_seq_length: Forwarded unchanged as the spec's ``max_seq_length``.

    Returns:
        A ``FeaturizationSpec`` configured for the resolved architecture.

    Raises:
        KeyError: If the resolved architecture is not one of the handled
            architectures; the architecture is the exception argument.
    """
    model_arch = ModelArchitectures.from_model_type(model_type)

    # Each profile is (pad_token_id, sequence_b_segment_id, sep_token_extra).
    # Sequence B gets its own segment id, pad id 0, single separator.
    two_segment = (0, 1, False)
    # Single segment id, pad id 0, single separator.
    # NOTE(review): pad_token_id=0 for XLM is kept from the existing code;
    # confirm it against the XLM tokenizer's actual pad id.
    one_segment = (0, 0, False)
    # RoBERTa-family models: '<s>' is the cls token, '</s>' is the sep token,
    # and two '</s>' tokens are used between sentences (not '</s><s>').
    # They use pad_token_id = 1 and have no token_type_ids.
    roberta_style = (1, 0, True)

    # Checked in order with ``==``; the first matching architecture wins.
    profiles_by_arch = (
        (ModelArchitectures.BERT, two_segment),
        (ModelArchitectures.XLM, one_segment),
        (ModelArchitectures.ROBERTA, roberta_style),
        (ModelArchitectures.ALBERT, two_segment),
        (ModelArchitectures.XLM_ROBERTA, roberta_style),
        (ModelArchitectures.BART, roberta_style),
        (ModelArchitectures.MBART, roberta_style),
        (ModelArchitectures.ELECTRA, two_segment),
    )

    for candidate_arch, profile in profiles_by_arch:
        if model_arch == candidate_arch:
            pad_id, segment_b_id, extra_sep = profile
            return FeaturizationSpec(
                max_seq_length=max_seq_length,
                cls_token_at_end=False,
                pad_on_left=False,
                cls_token_segment_id=0,
                pad_token_segment_id=0,
                pad_token_id=pad_id,
                pad_token_mask_id=0,
                sequence_a_segment_id=0,
                sequence_b_segment_id=segment_b_id,
                sep_token_extra=extra_sep,
            )

    raise KeyError(model_arch)