def predict(self, input):
    """Classify one message and return its top label with a confidence percentage.

    On every call this loads the label mapping, the pickled model config and
    the model weights from ``./output_artefacts``. It then preprocesses the
    message with ``preprocessing`` / ``prepare_features`` and scores it with a
    ``DistilBertForSequenceClassification`` model.

    Args:
        input: The message to classify; converted with ``str()`` first.

    Returns:
        A one-entry dict ``{label: percent}``. ``percent`` is the highest
        per-class sigmoid probability multiplied by 100.
    """
    # Label-name -> class-index mapping.
    with open('./output_artefacts/{LABLE_TO_IX.JSON}', encoding='utf-8') as json_file:
        label_to_ix = json.load(json_file)

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted artefacts.
    with open('./output_artefacts/{CONFIG.PKL}', 'rb') as f:
        config = pickle.load(f)
    config.num_labels = len(label_to_ix)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DistilBertForSequenceClassification(config)
    model_path = './output_artefacts/{MODEL.PTH}'
    model.load_state_dict(torch.load(model_path, map_location=device))
    # The model must live on the same device as its input. Previously only the
    # input was moved to CUDA, which fails with a device mismatch on GPU hosts.
    model.to(device)
    model.eval()

    msg = preprocessing(str(input))
    input_msg, _ = prepare_features(msg)
    input_msg = input_msg.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(input_msg)[0]
        # Mathematically equal to exp(x) / (exp(x) + 1), but does not overflow
        # to inf / inf = nan for large logits.
        probability = torch.sigmoid(output)
        top_prob, pred_label = probability.topk(1)

    # The (1, 1) reshape assumes prepare_features yields a batch of one
    # (it raises otherwise).
    percent = torch.mul(top_prob, 100).reshape(1, 1)[0].tolist()

    # NOTE(review): this maps a class index to the label at that *position* in
    # the JSON's key order, assuming it matches the stored indices. Confirm.
    labels = list(label_to_ix.keys())
    prediction = [labels[i] for i in pred_label[0]]
    return dict(zip(prediction, percent))
def create_and_check_distilbert_for_sequence_classification(
    self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Build a sequence-classification model and assert the shape of its logits.

    ``config.num_labels`` is overwritten with ``self.num_labels``. The model is
    placed on ``torch_device`` in eval mode and called with ``input_ids``,
    ``attention_mask=input_mask`` and ``labels=sequence_labels``.
    ``token_labels`` and ``choice_labels`` are accepted but unused.
    """
    config.num_labels = self.num_labels
    classifier = DistilBertForSequenceClassification(config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(input_ids, attention_mask=input_mask, labels=sequence_labels)

    # One row of logits per example, one column per label.
    expected_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)
def create_and_check_distilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
    """Build a sequence-classification model, run it with labels, and check loss and logits.

    ``config.num_labels`` is overwritten with ``self.num_labels``. The logits
    must have shape ``[self.batch_size, self.num_labels]``, and the loss is
    validated by ``self.check_loss_output``. ``token_labels`` and
    ``choice_labels`` are accepted but unused.

    NOTE(review): this file defines another method with this same name; if
    both live in one class, only the later definition takes effect.
    """
    config.num_labels = self.num_labels
    model = DistilBertForSequenceClassification(config)
    # Place the model on the same device as the test inputs (torch_device).
    model.to(torch_device)
    model.eval()
    outputs = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
    # Index positionally rather than tuple-unpacking. This works whether the
    # model returns a plain tuple or a dict-like output object; iterating the
    # latter would yield key names instead of tensors.
    result = {
        "loss": outputs[0],
        "logits": outputs[1],
    }
    self.parent.assertListEqual(
        list(result["logits"].size()), [self.batch_size, self.num_labels])
    self.check_loss_output(result)