def __init__(self, num_classes=None, **kwargs):
    """Create the BERT tokenizer and a sequence classifier from a local checkpoint.

    Args:
        num_classes: Number of output labels; forwarded to the model as
            ``num_labels``.
        **kwargs: Forwarded unchanged to the parent initializer.
    """
    super().__init__(**kwargs)

    def _keep_storage(storage, loc):
        # Hand the deserialized storage back untouched (no device remapping).
        return storage

    pretrained_name = 'bert-base-uncased'
    self.tokenizer = BertTokenizer.from_pretrained(
        pretrained_name, do_lower_case=True)

    # NOTE(review): self.local_paths and self.device are read here but not
    # set in this method -- presumably the parent initializer provides them.
    checkpoint_path = self.local_paths[0]
    weights = torch.load(checkpoint_path, map_location=_keep_storage)

    classifier = BertForSequenceClassification.from_pretrained(
        pretrained_name,
        state_dict=weights,
        num_labels=num_classes)
    self.model = classifier
    self.model.to(self.device)
def load(self):
    """Download the text sentiment checkpoint and build the tokenizer and model.

    The checkpoint is written into the system temporary directory, loaded
    with ``torch.load`` and used as the state dict of a
    ``BertForSequenceClassification`` model, which is then moved to the
    device returned by ``get_device()``.
    """

    def _keep_storage(storage, loc):
        # Hand the deserialized storage back untouched (no device remapping).
        return storage

    pretrained_name = 'bert-base-uncased'

    self.device = get_device()
    self.tokenizer = BertTokenizer.from_pretrained(
        pretrained_name, do_lower_case=True)

    # Fetch the checkpoint into the temp directory before reading it.
    checkpoint_path = os.path.join(
        tempfile.gettempdir(), 'text_sentiment_pytorch_model.bin')
    download_file_from_google_drive(TEXT_SENTIMENT_FILE_ID, checkpoint_path)

    weights = torch.load(checkpoint_path, map_location=_keep_storage)
    classifier = BertForSequenceClassification.from_pretrained(
        pretrained_name, state_dict=weights)
    self.model = classifier
    self.model.to(self.device)
def load(self):
    """Download this instance's checkpoint and build the tokenizer and model.

    The checkpoint identified by ``self.file_id`` is written to
    ``self.model_dir`` inside the system temporary directory, loaded with
    ``torch.load`` and used as the state dict of a
    ``BertForSequenceClassification`` model with ``self.num_classes``
    labels. The model is then moved to the device returned by
    ``get_device()``.
    """

    def _keep_storage(storage, loc):
        # Hand the deserialized storage back untouched (no device remapping).
        return storage

    pretrained_name = 'bert-base-uncased'

    self.device = get_device()
    self.tokenizer = BertTokenizer.from_pretrained(
        pretrained_name, do_lower_case=True)

    # Fetch the checkpoint into the temp directory before reading it.
    checkpoint_path = os.path.join(tempfile.gettempdir(), self.model_dir)
    download_file_from_google_drive(self.file_id, checkpoint_path)

    weights = torch.load(checkpoint_path, map_location=_keep_storage)
    classifier = BertForSequenceClassification.from_pretrained(
        pretrained_name,
        state_dict=weights,
        num_labels=self.num_classes)
    self.model = classifier
    self.model.to(self.device)
def fit(self, x, y, time_limit=None):
    """Train the BERT text classifier on the training data.

    Args:
        x: Training inputs; passed to ``self.preprocess``.
        y: Training labels; each element must be convertible with ``int``.
        time_limit: Accepted for interface compatibility; not used by this
            method.
    """
    # The number of classes is the number of distinct labels. A set is
    # enough to count them, so no intermediate list is built.
    self.num_labels = len(set(y))

    # Prepare model
    model = BertForSequenceClassification.from_pretrained(
        self.bert_model,
        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_-1',
        num_labels=self.num_labels)

    all_input_ids, all_input_mask, all_segment_ids = self.preprocess(x)
    all_label_ids = torch.tensor([int(f) for f in y], dtype=torch.long)

    train_data = TensorDataset(
        all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    # The trainer receives self.output_model_file, the same path predict()
    # loads from -- presumably where the trained weights are saved.
    bert_trainer = BERTTrainer(
        train_data, model, self.output_model_file, self.num_labels)
    bert_trainer.train_model()
def predict(self, x_test):
    """Predict the labels for the provided input data.

    Args:
        x_test: ndarray containing the test data inputs.

    Returns:
        The value returned by ``self.inverse_transform_y`` for the collected
        per-example logits, i.e. the predicted labels/outputs for x_test.
    """
    # Load a trained model that you have fine-tuned. Keep the deserialized
    # tensors where torch.load puts them (no remapping to the device they
    # were saved from), so a checkpoint written on a GPU can still be read
    # on a CPU-only host; the model is moved to self.device right after.
    model_state_dict = torch.load(
        self.output_model_file,
        map_location=lambda storage, loc: storage)
    model = BertForSequenceClassification.from_pretrained(
        self.bert_model,
        state_dict=model_state_dict,
        num_labels=self.num_labels)
    model.to(self.device)

    if self.verbose:
        print("***** Running evaluation *****")
        # Interpolate with %; passing the value as a second print() argument
        # would emit the literal "%d" followed by the number.
        print(" Num examples = %d" % len(x_test))
        print(" Batch size = %d" % self.eval_batch_size)

    all_input_ids, all_input_mask, all_segment_ids = self.preprocess(x_test)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)

    # Run prediction for full data; a sequential sampler keeps the outputs
    # in the same order as x_test.
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(
        eval_data, sampler=eval_sampler, batch_size=self.eval_batch_size)

    model.eval()
    y_preds = []
    for input_ids, input_mask, segment_ids in eval_dataloader:
        input_ids = input_ids.to(self.device)
        input_mask = input_mask.to(self.device)
        segment_ids = segment_ids.to(self.device)

        # No gradients are needed for inference.
        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)

        # Move the batch logits to host memory and append one row per example.
        logits = logits.detach().cpu().numpy()
        y_preds.extend(logits)

    return self.inverse_transform_y(y_preds)
def fit(self, x, y, time_limit=None):
    """Train the text classifier based on the training data.

    Args:
        x: ndarray containing the train data inputs.
        y: ndarray containing the train data outputs/labels; each element
            must be convertible with ``int``.
        time_limit: Maximum time allowed for searching. It does not apply
            for this classifier.
    """
    # The number of classes is the number of distinct labels. A set is
    # enough to count them, so no intermediate list is built.
    self.num_labels = len(set(y))

    # Prepare model
    model = BertForSequenceClassification.from_pretrained(
        self.bert_model,
        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_-1',
        num_labels=self.num_labels)

    all_input_ids, all_input_mask, all_segment_ids = self.preprocess(x)
    all_label_ids = torch.tensor([int(f) for f in y], dtype=torch.long)

    train_data = TensorDataset(
        all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    # The trainer receives self.output_model_file, the same path predict()
    # loads from -- presumably where the trained weights are saved.
    bert_trainer = BERTTrainer(
        train_data, model, self.output_model_file, self.num_labels)
    bert_trainer.train_model()