# Module-level imports this snippet relies on (an assumption, following the
# layout of PaddleHub's BERT module; adjust to your project structure):
import os
from typing import Dict

import paddle
from paddlehub.utils.log import logger
from paddlenlp.metrics import ChunkEvaluator
from paddlenlp.transformers import (BertForSequenceClassification,
                                    BertForTokenClassification, BertModel)


def __init__(
        self,
        task: str = None,
        load_checkpoint: str = None,
        label_map: Dict = None,
        num_classes: int = 2,
        **kwargs,
):
    super(Bert, self).__init__()

    # Prefer an explicit label map: it determines the class count and, for
    # token classification, the label list fed to ChunkEvaluator.
    if label_map:
        self.label_map = label_map
        self.num_classes = len(label_map)
    else:
        self.num_classes = num_classes

    # Map the deprecated task name onto its replacement before dispatching.
    if task == 'sequence_classification':
        task = 'seq-cls'
        logger.warning(
            "The task name 'sequence_classification' was renamed to 'seq-cls'; "
            "'sequence_classification' is deprecated and will be removed in the future.")

    if task == 'seq-cls':
        self.model = BertForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path='bert-base-uncased',
            num_classes=self.num_classes,
            **kwargs)
        self.criterion = paddle.nn.loss.CrossEntropyLoss()
        self.metric = paddle.metric.Accuracy()
    elif task == 'token-cls':
        # Note: this branch requires `label_map`, since ChunkEvaluator is
        # built from it; without one, self.label_map is never set.
        self.model = BertForTokenClassification.from_pretrained(
            pretrained_model_name_or_path='bert-base-uncased',
            num_classes=self.num_classes,
            **kwargs)
        self.criterion = paddle.nn.loss.CrossEntropyLoss()
        self.metric = ChunkEvaluator(
            label_list=[self.label_map[i] for i in sorted(self.label_map.keys())])
    elif task == 'text-matching':
        self.model = BertModel.from_pretrained(
            pretrained_model_name_or_path='bert-base-uncased', **kwargs)
        self.dropout = paddle.nn.Dropout(0.1)
        # The classifier consumes three concatenated hidden vectors,
        # hence the hidden_size * 3 input width.
        self.classifier = paddle.nn.Linear(self.model.config['hidden_size'] * 3, 2)
        self.criterion = paddle.nn.loss.CrossEntropyLoss()
        self.metric = paddle.metric.Accuracy()
    elif task is None:
        # No task given: expose the bare encoder, e.g. for feature extraction.
        self.model = BertModel.from_pretrained(
            pretrained_model_name_or_path='bert-base-uncased', **kwargs)
    else:
        raise RuntimeError("Unknown task {}, task should be one in {}".format(
            task, self._tasks_supported))

    self.task = task

    # Optionally restore fine-tuned weights saved with paddle.save().
    if load_checkpoint is not None and os.path.isfile(load_checkpoint):
        state_dict = paddle.load(load_checkpoint)
        self.set_state_dict(state_dict)
        logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
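
# Usage sketch (illustrative, not from the original source). It assumes the
# method above belongs to a class named `Bert`, as the super(Bert, self)
# call suggests; the label map and checkpoint path are placeholder values.
label_map = {0: 'negative', 1: 'positive'}
model = Bert(task='seq-cls', label_map=label_map)

# Resume from a previously saved fine-tuned checkpoint:
model = Bert(task='seq-cls', label_map=label_map,
             load_checkpoint='./model.pdparams')
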
def __init__(self, config: SimpleConfig):
    model = BertModel.from_pretrained(config.pretrained_model)
    super().__init__(model, config.num_labels, config.dropout)
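
# A minimal sketch of the SimpleConfig object this variant expects
# (hypothetical: the original source does not show its definition).
# It only needs the three attributes read above.
from dataclasses import dataclass

@dataclass
class SimpleConfig:
    pretrained_model: str = 'bert-base-uncased'  # passed to BertModel.from_pretrained
    num_labels: int = 2                          # size of the classification head
    dropout: float = 0.1                         # dropout rate for the classifier head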