def get_feat_spec(self, max_seq_length):
    """Build the ``FeaturizationSpec`` describing how inputs are laid out.

    Args:
        max_seq_length: Forwarded unchanged as the spec's ``max_seq_length``.

    Returns:
        A ``FeaturizationSpec`` with ``cls_token_at_end`` and ``pad_on_left``
        disabled, pad token id 0, segment id 0 for sequence A and 1 for
        sequence B, and no extra separator token.
    """
    settings = {
        "max_seq_length": max_seq_length,
        # Special-token placement and padding side.
        "cls_token_at_end": False,
        "pad_on_left": False,
        "sep_token_extra": False,
        # Padding ids.
        "pad_token_id": 0,
        "pad_token_mask_id": 0,
        "pad_token_segment_id": 0,
        # Segment ids: sequence B gets its own id (1).
        "cls_token_segment_id": 0,
        "sequence_a_segment_id": 0,
        "sequence_b_segment_id": 1,
    }
    return FeaturizationSpec(**settings)
def get_feat_spec(self, max_seq_length):
    """Build the ``FeaturizationSpec`` for XLM-RoBERTa-style inputs.

    Args:
        max_seq_length: Forwarded unchanged as the spec's ``max_seq_length``.

    Returns:
        A ``FeaturizationSpec`` with pad token id 1, a single shared segment
        id (0) for both sequences, and ``sep_token_extra`` enabled.
    """
    # Per the original author's notes, XLM-RoBERTa is unusual:
    #   * '<s>' (token 0) serves as the cls_token;
    #   * '</s>' serves as the sep_token;
    #   * sentence pairs are joined by two '</s>' tokens (not '</s><s>'),
    #     presumably the reason sep_token_extra is True.
    # NOTE(review): the original note listed '</s>' as token 1, which collides
    # with pad_token_id=1 below -- TODO confirm against the tokenizer vocab.
    settings = {
        "max_seq_length": max_seq_length,
        # Special-token placement and padding side.
        "cls_token_at_end": False,
        "pad_on_left": False,
        "sep_token_extra": True,
        # Padding ids: XLM-RoBERTa uses pad_token_id = 1.
        "pad_token_id": 1,
        "pad_token_mask_id": 0,
        "pad_token_segment_id": 0,
        # Segment ids: XLM-RoBERTa has no token_type_ids, so both sequences
        # share segment id 0.
        "cls_token_segment_id": 0,
        "sequence_a_segment_id": 0,
        "sequence_b_segment_id": 0,
    }
    return FeaturizationSpec(**settings)
def build_featurization_spec(model_type, max_seq_length):
    """Return the ``FeaturizationSpec`` matching a model type's architecture.

    The architecture is resolved with ``ModelArchitectures.from_model_type``.
    Every supported architecture shares the same spec except for three
    fields: ``pad_token_id``, ``sequence_b_segment_id`` and
    ``sep_token_extra``.

    Args:
        model_type: Passed to ``ModelArchitectures.from_model_type``.
        max_seq_length: Forwarded unchanged as the spec's ``max_seq_length``.

    Returns:
        A ``FeaturizationSpec`` for the resolved architecture.

    Raises:
        KeyError: If the resolved architecture is not one of BERT, XLM,
            ROBERTA, ALBERT, XLM_ROBERTA, BART, MBART or ELECTRA.
    """

    def make_spec(*, pad_token_id, sequence_b_segment_id, sep_token_extra):
        # All remaining fields are identical for every supported architecture.
        return FeaturizationSpec(
            max_seq_length=max_seq_length,
            cls_token_at_end=False,
            pad_on_left=False,
            cls_token_segment_id=0,
            pad_token_segment_id=0,
            pad_token_id=pad_token_id,
            pad_token_mask_id=0,
            sequence_a_segment_id=0,
            sequence_b_segment_id=sequence_b_segment_id,
            sep_token_extra=sep_token_extra,
        )

    arch = ModelArchitectures.from_model_type(model_type)

    if arch == ModelArchitectures.BERT:
        return make_spec(pad_token_id=0, sequence_b_segment_id=1, sep_token_extra=False)

    if arch == ModelArchitectures.XLM:
        # Both sequences share segment id 0. The original comment on this
        # branch cited RoBERTa's lack of token_type_ids (presumably copied);
        # TODO confirm the rationale for XLM.
        return make_spec(pad_token_id=0, sequence_b_segment_id=0, sep_token_extra=False)

    if arch == ModelArchitectures.ROBERTA:
        # Per the original author's notes, RoBERTa is unusual: '<s>' (token 0)
        # is the cls_token, '</s>' is the sep_token, and two '</s>' tokens
        # (not '</s><s>') separate sentences. It uses pad_token_id = 1 and has
        # no token_type_ids.
        # NOTE(review): the original note listed '</s>' as token 1, which
        # collides with pad_token_id=1 -- TODO confirm against the vocab.
        return make_spec(pad_token_id=1, sequence_b_segment_id=0, sep_token_extra=True)

    if arch == ModelArchitectures.ALBERT:
        # NOTE(review): the original author marked cls_token_at_end,
        # pad_token_id, pad_token_mask_id and both sequence segment ids as
        # uncertain ("I think?") for ALBERT -- TODO confirm.
        return make_spec(pad_token_id=0, sequence_b_segment_id=1, sep_token_extra=False)

    if arch == ModelArchitectures.XLM_ROBERTA:
        # Same conventions as noted for RoBERTa: pad_token_id = 1, no
        # token_type_ids, two '</s>' tokens between sentences.
        return make_spec(pad_token_id=1, sequence_b_segment_id=0, sep_token_extra=True)

    if arch == ModelArchitectures.BART:
        # Same conventions as noted for RoBERTa: pad_token_id = 1, no
        # token_type_ids, two '</s>' tokens between sentences.
        return make_spec(pad_token_id=1, sequence_b_segment_id=0, sep_token_extra=True)

    if arch == ModelArchitectures.MBART:
        # Same conventions as noted for RoBERTa: pad_token_id = 1, no
        # token_type_ids, two '</s>' tokens between sentences.
        return make_spec(pad_token_id=1, sequence_b_segment_id=0, sep_token_extra=True)

    if arch == ModelArchitectures.ELECTRA:
        return make_spec(pad_token_id=0, sequence_b_segment_id=1, sep_token_extra=False)

    # Unsupported architecture: surface it as the missing key.
    raise KeyError(arch)