Example #1
0
def data_initialization(args):
    """Load a cached dataset if one exists, otherwise build it from raw files.

    If ``<data_stored_directory>/<dataset_name>_dataset.dset`` exists and
    ``args.refresh`` is false, the previously saved ``Data`` object is loaded.
    Otherwise a fresh ``Data`` object is built (gazetteer, train/dev/test
    instances, pretrained char/gaz embeddings) and saved back to disk.

    Args:
        args: Parsed command-line namespace. Fields read here:
            data_stored_directory, dataset_name, refresh, norm_char_emb,
            norm_gaz_emb, number_normalized, max_sentence_length, gaz_file,
            train_file, dev_file, test_file, char_embedding_path.

    Returns:
        The loaded or newly constructed ``Data`` object.
    """
    data_stored_directory = args.data_stored_directory
    # os.path.join tolerates a missing trailing separator on the directory;
    # the previous raw string concatenation produced a wrong path in that case.
    # NOTE(review): load/save_data_setting build their own path from the same
    # directory + name — confirm they agree with this one.
    dataset_file = os.path.join(
        data_stored_directory, args.dataset_name + "_dataset.dset"
    )
    if os.path.exists(dataset_file) and not args.refresh:
        data = load_data_setting(data_stored_directory, args.dataset_name)
    else:
        data = Data()
        data.dataset_name = args.dataset_name
        data.norm_char_emb = args.norm_char_emb
        data.norm_gaz_emb = args.norm_gaz_emb
        data.number_normalized = args.number_normalized
        data.max_sentence_length = args.max_sentence_length
        data.build_gaz_file(args.gaz_file)
        # Train uses the three-argument form (explicit False flag); dev/test
        # rely on the default — kept as in the original call sites.
        data.generate_instance(args.train_file, "train", False)
        data.generate_instance(args.dev_file, "dev")
        data.generate_instance(args.test_file, "test")
        data.build_char_pretrain_emb(args.char_embedding_path)
        data.build_gaz_pretrain_emb(args.gaz_file)
        data.fix_alphabet()
        data.get_tag_scheme()
        save_data_setting(data, data_stored_directory)
    return data
Example #2
0
        data.MAX_SENTENCE_LENGTH = conf_dict['MAX_SENTENCE_LENGTH']
        data.HP_lstm_layer = conf_dict['HP_lstm_layer']
        data_initialization(data, gaz_file, train_file, dev_file, test_file)

        if data.model_name in ['CNN_model', 'LSTM_model']:
            data.generate_instance_with_gaz_2(train_file, 'train')
            data.generate_instance_with_gaz_2(dev_file, 'dev')
            data.generate_instance_with_gaz_2(test_file, 'test')
        elif data.model_name in ['WC-LSTM_model']:
            data.generate_instance_with_gaz_3(train_file, 'train')
            data.generate_instance_with_gaz_3(dev_file, 'dev')
            data.generate_instance_with_gaz_3(test_file, 'test')
        else:
            print("model_name is not set!")
            sys.exit(1)
        data.build_char_pretrain_emb(char_emb)
        data.build_bichar_pretrain_emb(bichar_emb)
        data.build_gaz_pretrain_emb(gaz_file)
        train(data, save_model_dir, dset_dir, seg)
    elif status == 'test':
        data = load_data_setting(dset_dir)
        if data.model_name == 'CNN_model':
            data.generate_instance_with_gaz_2(test_file, 'test')
        elif data.model_name == 'LSTM_model':
            data.generate_instance_with_gaz(test_file, 'test')
        load_model_decode(model_dir, data, 'test', gpu, seg)
    else:
        print(
            "Invalid argument! Please use valid argumentguments! (train/test/decode)"
        )
Example #3
0
    saved_model_path = args.saved_model
    char_file = args.char_emb
    word_file = args.word_emb

    if status == 'train':
        if os.path.exists(saved_set_path):
            print('Loading saved data set...')
            with open(saved_set_path, 'rb') as f:
                data = pickle.load(f)
        else:
            data = Data()
            data_initialization(data, word_file, train_file, dev_file, test_file)
            data.generate_instance_with_words(train_file, 'train')
            data.generate_instance_with_words(dev_file, 'dev')
            data.generate_instance_with_words(test_file, 'test')
            data.build_char_pretrain_emb(char_file)
            data.build_word_pretrain_emb(word_file)
            if saved_set_path is not None:
                print('Dumping data...')
                with open(saved_set_path, 'wb') as f:
                    pickle.dump(data, f)
        data.show_data_summary()
        args.word_alphabet_size = data.word_alphabet.size()
        args.char_alphabet_size = data.char_alphabet.size()
        args.label_alphabet_size = data.label_alphabet.size()
        args.char_dim = data.char_emb_dim
        args.word_dim = data.word_emb_dim
        print_args(args)
        train(data, args, saved_model_path)

    elif status == 'test':