def load(
    cls, model_name_or_path: Union[HFModelResult, FlairModelResult, str]
) -> AdaptiveModel:
    """Class method for loading and constructing this classifier

    * **model_name_or_path** - A key string of one of Flair's pre-trained Sequence Classifier Models or a `HFModelResult`
    """
    # Model-result wrappers carry the hub key in `.name`; unwrap to a plain string
    if risinstance([HFModelResult, FlairModelResult], model_name_or_path):
        model_name_or_path = model_name_or_path.name
    return cls(model_name_or_path)
def _mutate_model_head(self, class_label: ClassLabel) -> None:
    """Manually initialize new linear layers for prediction heads on the
    specific language-model families we support training on."""
    model = self.model
    n_classes = class_label.num_classes
    if risinstance([BertPreTrainedModel, DistilBertPreTrainedModel], model):
        # BERT-style models expose a `classifier` head sized off hidden_size
        model.classifier = nn.Linear(model.config.hidden_size, n_classes)
        model.num_labels = n_classes
    elif isinstance(model, XLMPreTrainedModel):
        # XLM only needs its label count updated
        model.num_labels = n_classes
    elif isinstance(model, XLNetPreTrainedModel):
        # XLNet projects logits from d_model via `logits_proj`
        model.logits_proj = nn.Linear(model.config.d_model, n_classes)
        model.num_labels = n_classes
    elif isinstance(model, ElectraPreTrainedModel):
        # Electra only needs its label count updated
        model.num_labels = n_classes
    else:
        logger.info(f'Sorry, can not train on a model of type {type(self.model)}')
def before_batch(self):
    "Adjusts `token_type_ids` if model is in `_qa_models`"
    # QA-family models do not accept `token_type_ids`; drop them from the inputs
    if risinstance(self._qa_models, self.learn.model):
        del self.learn.inputs["token_type_ids"]
    if len(self.xb) > 3:
        # 4th batch element carries the example indices used downstream
        # (presumably for mapping predictions back to examples — confirm against caller)
        self.example_indices = self.xb[3]
        if isinstance(self.learn.model, self.xmodel_instances):
            # XLNet/XLM-style models additionally expect `cls_index` and `p_mask`
            # NOTE(review): assumes `xb` has at least 6 elements here — TODO confirm
            self.learn.inputs.update({
                'cls_index': self.xb[4],
                'p_mask': self.xb[5]
            })
            # for lang_id-sensitive xlm models
            if nested_attr(self.learn.model, 'config.lang2id', False):
                # Set language id as 0 for now
                self.learn.inputs.update({
                    'langs': (torch.ones(self.xb[0].shape, dtype=torch.int64) * 0)
                })
def predict(
    self,
    text: Union[List[Sentence], Sentence, List[str], str],
    mini_batch_size: int = 32,
    **kwargs,
) -> List[Sentence]:
    """Predict method for running inference using the pre-trained sequence classifier model

    * **text** - String, list of strings, sentences, or list of sentences to run inference on
    * **mini_batch_size** - Mini batch size
    * ****kwargs**(Optional) - Optional arguments for the Transformers classifier

    **return** - A list of Flair `Sentence`s, one per input, each carrying one
    'sc' label per class with its softmax score.
    """
    id2label = self.model.config.id2label
    sentences = text
    results: List[Sentence] = []

    # Empty input: return it unchanged
    if not sentences:
        return sentences

    # Normalize a single string/Sentence into a one-element list
    if risinstance([DataPoint, str], sentences):
        sentences = [sentences]

    # filter empty sentences
    if isinstance(sentences[0], Sentence):
        sentences = [sentence for sentence in sentences if len(sentence) > 0]
        if len(sentences) == 0:
            return sentences

    # reverse sort all sequences by their length so similarly sized inputs
    # share a mini batch; remember the permutation so we can restore order
    rev_order_len_index = sorted(range(len(sentences)),
                                 key=lambda k: len(sentences[k]),
                                 reverse=True)
    original_order_index = sorted(range(len(rev_order_len_index)),
                                  key=lambda k: rev_order_len_index[k])
    reordered_sentences: List[Union[DataPoint, str]] = [
        sentences[index] for index in rev_order_len_index
    ]

    # Turn all Sentence objects into strings.
    # BUG FIX: iterate `reordered_sentences`, not the original-order
    # `sentences` — otherwise the final reordering below scrambles which
    # prediction belongs to which input.
    if isinstance(reordered_sentences[0], Sentence):
        str_reordered_sentences = [
            sentence.to_original_text() for sentence in reordered_sentences
        ]
    else:
        str_reordered_sentences = reordered_sentences

    dataset = self._tokenize(str_reordered_sentences)
    dl = DataLoader(dataset, batch_size=mini_batch_size)
    predictions: List[Tuple[str, float]] = []

    outputs, _ = super().get_preds(dl=dl)
    logits = torch.cat([o['logits'] for o in outputs])
    predictions = torch.softmax(logits, dim=1).tolist()

    for text, pred in zip(str_reordered_sentences, predictions):
        # Initialize and assign labels to each class in each datapoint prediction
        text_sent = Sentence(text)
        for k, v in id2label.items():
            text_sent.add_label(label_type='sc', value=v, score=pred[k])
        results.append(text_sent)

    # Order results back into original order
    results = [results[index] for index in original_order_index]
    return results
def tag_text(
    self,
    text: Union[List[Sentence], Sentence, List[str], str],
    model_name_or_path: Union[str, FlairModelResult, HFModelResult] = "ner-ontonotes",
    mini_batch_size: int = 32,
    **kwargs,
) -> List[Sentence]:
    """Tags tokens with labels the token classification models have been trained on

    * **text** - Text input, it can be a string or any of Flair's `Sentence` input formats
    * **model_name_or_path** - The hosted model name key or model path
    * **mini_batch_size** - The mini batch size for running inference
    * ****kwargs** - Keyword arguments for Flair's `SequenceTagger.predict()` method

    **return** - A list of Flair's `Sentence`'s
    """
    # Load Sequence Tagger Model and Pytorch Module into tagger dict,
    # caching under the model's hub key so repeat calls reuse it
    name = getattr(model_name_or_path, 'name', model_name_or_path)
    if not self.token_taggers[name]:
        if risinstance([FlairModelResult, HFModelResult], model_name_or_path):
            # Try Flair first, then fall back to a Transformers tagger
            try:
                self.token_taggers[name] = FlairTokenTagger.load(name)
            except Exception:
                self.token_taggers[name] = TransformersTokenTagger.load(name)
        elif risinstance([str, Path], model_name_or_path) and (
                Path(model_name_or_path).exists()
                and Path(model_name_or_path).is_dir()):
            # Load in previously existing model from a local directory
            try:
                self.token_taggers[name] = FlairTokenTagger.load(name)
            except Exception:
                self.token_taggers[name] = TransformersTokenTagger.load(name)
        else:
            # Unknown key: search the Flair hub first, then the HF hub
            _flair_hub = FlairModelHub()
            _hf_hub = HFModelHub()
            res = _flair_hub.search_model_by_name(name, user_uploaded=True)
            if len(res) < 1:
                # No Flair models found; try the Hugging Face hub
                res = _hf_hub.search_model_by_name(name, user_uploaded=True)
                if len(res) < 1:
                    logger.info("Not a valid `model_name_or_path` param")
                    return [Sentence('')]
                else:
                    # BUG FIX: the replace() result was previously discarded;
                    # map the old 'flairNLP' org name to 'flair' and use it
                    # consistently as both cache key and load target
                    name = res[0].name.replace('flairNLP', 'flair')
                    self.token_taggers[name] = TransformersTokenTagger.load(name)
            else:
                # Returning the first should always be the non-fast option
                name = res[0].name.replace('flairNLP/', '')
                self.token_taggers[name] = FlairTokenTagger.load(name)

    tagger = self.token_taggers[name]
    return tagger.predict(
        text=text,
        mini_batch_size=mini_batch_size,
        **kwargs,
    )
def tag_text(
    self,
    text: Union[List[Sentence], Sentence, List[str], str],
    model_name_or_path: Union[str, FlairModelResult, HFModelResult] = 'en-sentiment',
    mini_batch_size: int = 32,
    **kwargs,
) -> List[Sentence]:
    """Tags a text sequence with labels the sequence classification models have been trained on

    * **text** - String, list of strings, `Sentence`, or list of `Sentence`s to be classified
    * **model_name_or_path** - The model name key or model path
    * **mini_batch_size** - The mini batch size for running inference
    * ****kwargs** - (Optional) Keyword Arguments for Flair's `TextClassifier.predict()` method params

    **return** A list of Flair's `Sentence`'s
    """
    # Load Text Classifier Model and Pytorch Module into tagger dict,
    # caching under the model's hub key so repeat calls reuse it
    name = getattr(model_name_or_path, 'name', model_name_or_path)
    if not self.sequence_classifiers[name]:
        if risinstance([FlairModelResult, HFModelResult], model_name_or_path):
            # Try Flair first, then fall back to a Transformers classifier
            try:
                self.sequence_classifiers[name] = FlairSequenceClassifier.load(name)
            except Exception:
                self.sequence_classifiers[name] = TransformersSequenceClassifier.load(name)
        elif risinstance([str, Path], model_name_or_path) and (
                Path(model_name_or_path).exists()
                and Path(model_name_or_path).is_dir()):
            # Load in previously existing model from a local directory
            try:
                self.sequence_classifiers[name] = FlairSequenceClassifier.load(name)
            except Exception:
                self.sequence_classifiers[name] = TransformersSequenceClassifier.load(name)
        else:
            # Unknown key: search the Flair hub first, then the HF hub
            res = self.flair_hub.search_model_by_name(name, user_uploaded=True)
            if len(res) < 1:
                # No Flair models found; try the Hugging Face hub
                # (use `name` consistently, as the Flair search above does)
                res = self.hf_hub.search_model_by_name(name, user_uploaded=True)
                if len(res) < 1:
                    logger.info("Not a valid `model_name_or_path` param")
                    return [Sentence('')]
                else:
                    name = res[0].name.replace('flairNLP', 'flair')
                    # BUG FIX: store under `name` (the key looked up below),
                    # not `res[0].name` — otherwise the lookup misses whenever
                    # the 'flairNLP' -> 'flair' rename changes the string
                    self.sequence_classifiers[name] = TransformersSequenceClassifier.load(name)
            else:
                # Returning the first should always be non-fast
                name = res[0].name.replace('flairNLP/', '')
                self.sequence_classifiers[name] = FlairSequenceClassifier.load(name)

    classifier = self.sequence_classifiers[name]
    return classifier.predict(
        text=text,
        mini_batch_size=mini_batch_size,
        **kwargs,
    )