def get_ncrf_data_object(model_name, input_path=None, output_path=None):
    """Build an NCRF++ ``Data`` object ready to decode with a registered model.

    Looks ``model_name`` up in ``MODEL_PATHS`` and restores the saved data
    settings from the entry's ``'dset'`` path. It then forces CPU execution,
    points ``load_model_dir`` at the entry's ``'model'`` path and disables
    n-best decoding.

    Args:
        model_name: key into ``MODEL_PATHS``; the entry must provide
            ``'dset'`` and ``'model'`` paths.
        input_path: optional raw-input path, stored as ``data.raw_dir``.
            Left untouched when ``None``.
        output_path: optional decode-output path, stored as
            ``data.decode_dir``. Left untouched when ``None``.

    Returns:
        The configured ``Data`` instance.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODEL_PATHS``.
    """
    data = Data()
    if model_name not in MODEL_PATHS:
        raise KeyError(
            "Unknown NCRF model %r; known models: %s"
            % (model_name, list(MODEL_PATHS))
        )
    model = MODEL_PATHS[model_name]
    data.dset_dir = model['dset']
    data.load(data.dset_dir)
    # The overrides below come after load(), presumably so the restored
    # settings do not overwrite them.
    data.HP_gpu = False
    if input_path is not None:
        data.raw_dir = input_path
    if output_path is not None:
        data.decode_dir = output_path
    data.load_model_dir = model['model']
    data.nbest = None
    return data
parser.add_argument('--savedset', help='Dir of saved data setting') parser.add_argument('--train', default="data/conll03/train.bmes") parser.add_argument('--dev', default="data/conll03/dev.bmes") parser.add_argument('--test', default="data/conll03/test.bmes") parser.add_argument('--seg', default="True") parser.add_argument('--raw') parser.add_argument('--loadmodel') parser.add_argument('--output') args = parser.parse_args() data = Data() data.train_dir = args.train data.dev_dir = args.dev data.test_dir = args.test data.model_dir = args.savemodel data.dset_dir = args.savedset print("aaa", data.dset_dir) status = args.status.lower() save_model_dir = args.savemodel data.HP_gpu = torch.cuda.is_available() print("Seed num:", seed_num) data.number_normalized = True data.word_emb_dir = "../data/glove.6B.100d.txt" if status == 'train': print("MODEL: train") data_initialization(data) data.use_char = True data.HP_batch_size = 10 data.HP_lr = 0.015 data.char_seq_feature = "CNN"
def _none_if_sentinel(value):
    """Return ``None`` for ``None`` or the literal string ``'None'``, else *value*.

    Several command-line options default to the *string* ``'None'``. This
    helper keeps that value from being mistaken for a real path.
    """
    return None if value is None or value == 'None' else value


def _build_ncrf_arg_parser():
    """Return the command-line parser used by ``main``.

    ``--config``, ``--wordemb`` and ``--charemb`` keep their historical
    default of the string ``'None'`` so existing invocations are unchanged.
    Normalise that value with ``_none_if_sentinel``.
    """
    parser = argparse.ArgumentParser(description='Tuning with NCRF++')
    parser.add_argument('--config', help='Configuration File', default='None')
    parser.add_argument('--wordemb', help='Embedding for words', default='None')
    parser.add_argument('--charemb', help='Embedding for chars', default='None')
    parser.add_argument('--status', choices=['train', 'decode'],
                        help='update algorithm', default='train')
    parser.add_argument('--savemodel', default="data/model/saved_model.lstmcrf.")
    parser.add_argument('--savedset', help='Dir of saved data setting')
    parser.add_argument('--train', default="data/conll03/train.bmes")
    parser.add_argument('--dev', default="data/conll03/dev.bmes")
    parser.add_argument('--test', default="data/conll03/test.bmes")
    parser.add_argument('--seg', default="True")
    parser.add_argument('--random-seed', type=int, default=42)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--batch-size', type=int)
    parser.add_argument('--raw')
    parser.add_argument('--loadmodel')
    parser.add_argument('--output')
    parser.add_argument('--output-tsv')
    parser.add_argument('--model-prefix')
    parser.add_argument('--cpu', action='store_true')
    return parser


def _seed_everything(seed_num):
    """Seed the Python, torch and numpy random number generators."""
    random.seed(seed_num)
    torch.manual_seed(seed_num)
    np.random.seed(seed_num)


def _apply_decode_overrides(data, args):
    """Apply the decode-related command-line flags that were actually given.

    Called after ``data.load()``. NOTE(review): load() appears to restore the
    saved settings, which would overwrite values set before it. Confirm this
    against ``Data.load``. Re-applying is harmless either way.
    """
    if args.raw:
        data.raw_dir = args.raw
    if args.output:
        data.decode_dir = args.output
    if args.loadmodel:
        data.load_model_dir = args.loadmodel
    if args.cpu:
        data.HP_gpu = False


def main():
    """Command-line entry point: train an NCRF++ model or decode raw text.

    Settings come from ``--config`` (read with ``Data.read_config``) or, when
    no config is given, from the individual command-line flags. ``--lr``,
    ``--batch-size``, ``--output-tsv``, ``--cpu`` and ``--model-prefix`` are
    applied on top of either source.

    The run mode is taken from ``data.status``:

    * ``'train'`` builds the data, generates train/dev/test instances, loads
      pretrained embeddings and calls ``train``.
    * ``'decode'`` restores the saved data settings, decodes the raw input
      and writes the results (n-best when configured).
    """
    args = _build_ncrf_arg_parser().parse_args()
    config_path = _none_if_sentinel(args.config)

    seed_num = args.random_seed
    _seed_everything(seed_num)

    data = Data()
    data.random_seed = seed_num
    data.HP_gpu = torch.cuda.is_available()
    if config_path is None:
        data.train_dir = args.train
        data.dev_dir = args.dev
        data.test_dir = args.test
        data.model_dir = args.savemodel
        data.dset_dir = args.savedset
        print("Save dset directory:", data.dset_dir)
        # Store a real None instead of the 'None' default so it is not
        # treated as an embedding file path. This assumes Data treats None as
        # "no pretrained embedding" (TODO confirm).
        data.word_emb_dir = _none_if_sentinel(args.wordemb)
        data.char_emb_dir = _none_if_sentinel(args.charemb)
        data.seg = args.seg.lower() == 'true'
        # Without a config file, nothing else sets the run mode, so honour
        # --status. It used to be parsed and then ignored.
        data.status = args.status
    else:
        data.read_config(config_path)

    # Command-line overrides take precedence over the config file.
    if args.lr:
        data.HP_lr = args.lr
    if args.batch_size:
        data.HP_batch_size = args.batch_size
    data.output_tsv_path = args.output_tsv
    if args.cpu:
        data.HP_gpu = False
    if args.model_prefix:
        data.model_dir = args.model_prefix

    status = data.status.lower()
    print("Seed num:", seed_num)
    if status == 'train':
        print("MODEL: train")
        data_initialization(data)
        data.generate_instance('train')
        data.generate_instance('dev')
        data.generate_instance('test')
        data.build_pretrain_emb()
        train(data)
    elif status == 'decode':
        print("MODEL: decode")
        data.load(data.dset_dir)
        # Re-read the config after load(), presumably so its decode settings
        # win over the restored ones. Skip this when no config was given;
        # the original code would try to open a file named 'None'.
        if config_path is not None:
            data.read_config(config_path)
        _apply_decode_overrides(data, args)
        print(data.raw_dir)
        data.show_data_summary()
        data.generate_instance('raw')
        print("nbest: %s" % (data.nbest))
        decode_results, pred_scores = load_model_decode(data, 'raw')
        if data.nbest and not data.sentence_classification:
            data.write_nbest_decoded_results(decode_results, pred_scores, 'raw')
        else:
            data.write_decoded_results(decode_results, 'raw')
    else:
        print(
            "Invalid argument! Please use valid arguments! (train/test/decode)"
        )
parser.add_argument('--savedset', help='Dir of saved data setting') parser.add_argument('--train', default="data/conll03/train.bmes") parser.add_argument('--dev', default="data/conll03/dev.bmes" ) parser.add_argument('--test', default="data/conll03/test.bmes") parser.add_argument('--seg', default="True") parser.add_argument('--raw') parser.add_argument('--loadmodel') parser.add_argument('--output') args = parser.parse_args() data = Data() data.train_dir = args.train data.dev_dir = args.dev data.test_dir = args.test data.model_dir = args.savemodel data.dset_dir = args.savedset print("aaa",data.dset_dir) status = args.status.lower() save_model_dir = args.savemodel data.HP_gpu = torch.cuda.is_available() print("Seed num:",seed_num) data.number_normalized = True data.word_emb_dir = "../data/glove.6B.100d.txt" if status == 'train': print("MODEL: train") data_initialization(data) data.use_char = True data.HP_batch_size = 10 data.HP_lr = 0.015 data.char_seq_feature = "CNN"