def _evaluate_for_train_valid(self):
    """Evaluate the model on the train and valid splits.

    Returns:
        train_acc, train_f1, valid_acc, valid_f1
    """
    def _score_split(loader_key, label_file):
        # Run inference on one split and score predictions against gold labels.
        predictions = evaluate(
            model=self.model,
            data_loader=self.data_loader[loader_key],
            device=self.device)
        gold = get_labels_from_file(label_file)
        return calculate_accuracy_f1(gold, predictions)

    train_acc, train_f1 = _score_split('valid_train', self.config.train_file_path)
    valid_acc, valid_f1 = _score_split('valid_valid', self.config.valid_file_path)
    return train_acc, train_f1, valid_acc, valid_f1
def main(out_file='output/result.txt', model_config='config/rnn_config.json'):
    """Run the model on the configured test set and write predicted words to file.

    Also prints accuracy/F1 of the predictions against the gold labels that
    `Data.load_file` returns for the test file.

    Args:
        out_file: path of the output file the predicted words are written to
        model_config: path of the JSON model-config file
    """
    # 0. Load config (JSON objects become attribute-accessible namespaces)
    with open(model_config) as fin:
        config = json.load(fin, object_hook=lambda d: SimpleNamespace(**d))
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # 1. Load data
    data = Data(vocab_file=os.path.join(config.model_path, 'vocab.txt'),
                max_seq_len=config.max_seq_len,
                model_type=config.model_type,
                config=config)
    # train=True so gold labels are returned for scoring below.
    test_set, sc_list, label_list = data.load_file(config.test_file_path, train=True)
    token_list = [data.tokenizer.convert_ids_to_tokens(line) for line in sc_list]
    data_loader_test = DataLoader(test_set, batch_size=config.batch_size, shuffle=False)

    # 2. Load model
    model = MODEL_MAP[config.model_type](config)
    model = load_torch_model(
        model, model_path=os.path.join(config.model_path, 'model.bin'))
    model.to(device)

    # 3. Evaluate
    # NOTE(review): isTest=False despite this being the test path — presumably
    # needed so `evaluate` also returns per-example lengths; confirm.
    answer_list, length_list = evaluate(model, data_loader_test, device, isTest=False)

    def flatten(ll):
        # Collapse a list of lists into one flat list.
        return list(itertools.chain(*ll))

    gold_answers = flatten(handy_tool(label_list, length_list))  # gold labels
    predictions = flatten(answer_list)  # model predictions
    acc, f1 = calculate_accuracy_f1(gold_answers, predictions)
    print(acc, f1)

    # Trim token sequences to their true lengths and pair with predicted tags.
    mod_tokens_list = handy_tool(token_list, length_list)
    result = [result_to_json(t, s) for t, s in zip(mod_tokens_list, answer_list)]

    # 4. Write answers to file: one line per example, predicted entity words
    # separated by spaces.
    with open(out_file, 'w', encoding='utf8') as fout:
        for item in result:
            words = [d['word'] for d in item['entities']]
            fout.write(" ".join(words) + "\n")