def bilstm_crf_train_eval(config, import_model, train_examples, dev_examples=None, test_examples=None):
    """Train a BiLSTM-CRF model and evaluate it on optional dev/test splits.

    Args:
        config: experiment configuration object; this function sets
            ``config.vocab_size`` as a side effect (read from the vocabulary).
        import_model: module exposing a ``Model(config)`` class to instantiate.
        train_examples: raw training examples for feature conversion.
        dev_examples: optional dev examples; when falsy no dev loader is built.
        test_examples: optional test examples; when falsy the test phase is skipped.

    Returns:
        The best model produced by ``model_train`` (backward-compatible
        addition; the original returned ``None``).
    """
    processor = DataProcessor(config.data_dir, config.do_lower_case)
    word2id = processor.get_vocab()
    config.vocab_size = len(word2id)

    def _build_loader(examples, shuffle):
        # One pipeline for all three splits: examples -> CRF features -> DataLoader.
        features = convert_examples_to_features_crf(
            examples=examples,
            word2id=word2id,
            label_list=config.label_list,
        )
        return DataLoader(
            BuildDataSet(features),
            batch_size=config.batch_size,
            shuffle=shuffle,
        )

    train_loader = _build_loader(train_examples, shuffle=True)
    # Evaluation loaders are not shuffled: batch order must not affect the
    # aggregated metrics, and a fixed order keeps runs reproducible.  The
    # original shuffled the dev loader while leaving the test loader
    # unshuffled — made consistent here.
    dev_loader = _build_loader(dev_examples, shuffle=False) if dev_examples else None
    test_loader = _build_loader(test_examples, shuffle=False) if test_examples else None

    logger.info("self config:\n {}".format(config_to_json_string(config)))

    model = import_model.Model(config).to(config.device)
    best_model = model_train(config, model, train_loader, dev_loader, to_crf=True)

    # The original called model_test unconditionally; with no test examples
    # test_loader is None and model_test would receive a None loader — guard it.
    if test_loader is not None:
        model_test(config, best_model, test_loader, to_crf=True)
    return best_model
def Task(config):
    """Run the full train / cross-validate / predict pipeline for one config.

    Loads data via DataProcessor, mutates ``config`` with the derived label
    list, label count and (optionally) speaker map, runs ``cross_validation``
    in the configured pattern, logs the dev metrics, and writes the predicted
    labels to ``config.result_save_path``.
    """
    # Pin the CUDA device before any model/tensor is created.
    if config.device.type == 'cuda':
        torch.cuda.set_device(config.device_id)

    tokenizer = BertTokenizer.from_pretrained(
        config.tokenizer_file, do_lower_case=config.do_lower_case)
    processor = DataProcessor(
        data_dir=config.data_dir,
        do_lower_case=config.do_lower_case,
        language=config.language,
        do_preprocessing=config.do_preprocessing,
        split=config.split)

    # Derive label metadata (and optional speaker map) onto the config.
    config.class_list = processor.get_labels()
    config.num_labels = len(config.class_list)
    if config.speaker_tag:
        config.speaker_dict = processor.get_speaker_map(config.speaker_threshold)
    else:
        config.speaker_dict = None

    examples_train = processor.get_train_examples()
    examples_dev = processor.get_dev_examples()
    examples_test = processor.get_test_examples(config.test_file)

    model_cls = MODEL_CLASSES[config.use_model]
    model = model_cls(config)
    logger.info("self config: {}\n".format(config_to_json_string(config)))

    model_example, metrics_result, predict_label = cross_validation(
        config=config,
        train_examples=examples_train,
        dev_examples=examples_dev,
        model=model,
        tokenizer=tokenizer,
        pattern=config.pattern,
        test_examples=examples_test)

    if config.pattern == 'k_fold':
        # Per-fold metrics are logged raw, then averaged; fold predictions
        # are merged into a single set of labels.
        fold_accs, fold_f1s = metrics_result[0], metrics_result[1]
        logger.info('K({})-fold models dev acc: {}'.format(config.k_fold, fold_accs))
        logger.info('K({})-fold models dev f1: {}'.format(config.k_fold, fold_f1s))
        dev_acc = np.array(fold_accs).mean()
        dev_f1 = np.array(fold_f1s).mean()
        predict_label = combined_result(predict_label, pattern='average')
    else:
        dev_acc, dev_f1 = metrics_result[0], metrics_result[1]

    logger.info("dev evaluate average Acc: {}, F1:{}".format(dev_acc, dev_f1))

    # Result file is stamped with time, language and the dev F1 score.
    file_name = '{}_{}_{:>.6f}.csv'.format(
        strftime("%m%d-%H%M%S", localtime()), config.language, dev_f1)
    predict_to_save(
        predict_label,
        path=config.result_save_path,
        file=file_name,
        prob_threshold=config.prob_threshold)