def train():
    # Load the datasets
    train_sentences = data_loader.load_sentences(FLAGS.train_file)
    dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
    test_sentences = data_loader.load_sentences(FLAGS.test_file)

    # Convert the tag encoding from BIO to BIOES
    data_loader.update_tag_scheme(train_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(test_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(dev_sentences, FLAGS.tag_schema)
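`update_tag_scheme` rewrites each sentence's labels in place. For reference, here is a minimal sketch of the BIO-to-BIOES conversion it performs; `bio_to_bioes` is a hypothetical stand-in, not the repository's actual implementation:

def bio_to_bioes(tags):
    """Convert one sentence's BIO tags to BIOES (illustrative sketch)."""
    new_tags = []
    for i, tag in enumerate(tags):
        followed_by_inside = i + 1 < len(tags) and tags[i + 1].startswith('I-')
        if tag == 'O':
            new_tags.append(tag)
        elif tag.startswith('B-'):
            # A single-token entity becomes S-; otherwise B- is kept.
            new_tags.append(tag if followed_by_inside else tag.replace('B-', 'S-'))
        elif tag.startswith('I-'):
            # The last token of a multi-token entity becomes E-.
            new_tags.append(tag if followed_by_inside else tag.replace('I-', 'E-'))
        else:
            raise ValueError('Invalid BIO tag: %s' % tag)
    return new_tags

For example, bio_to_bioes(['B-PER', 'I-PER', 'O', 'B-LOC']) returns ['B-PER', 'E-PER', 'O', 'S-LOC']: entity ends and singletons become explicit, which generally helps the downstream CRF score boundaries.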
Example #2
def train():
    # Load the datasets
    train_sentences = dl.load_sentences(FLAGS.train_file)
    dev_sentences = dl.load_sentences(FLAGS.dev_file)
    test_sentences = dl.load_sentences(FLAGS.test_file)

    # Convert the tag encoding from BIO to BIOES
    dl.update_tag_scheme(train_sentences, FLAGS.tag_schema)
    dl.update_tag_scheme(test_sentences, FLAGS.tag_schema)
    dl.update_tag_scheme(dev_sentences, FLAGS.tag_schema)

    # Build the word and tag mappings
    if not os.path.isfile(FLAGS.map_file):
        _, word_to_id, id_to_word = dl.word_mapping(train_sentences)
        _, tag_to_id, id_to_tag = dl.tag_mapping(train_sentences)

        with open(FLAGS.map_file, 'wb') as f:
            pickle.dump([word_to_id, id_to_word, tag_to_id, id_to_tag], f)
    else:
        with open(FLAGS.map_file, 'rb') as f:
            word_to_id, id_to_word, tag_to_id, id_to_tag = pickle.load(f)

    train_data = dl.prepare_dataset(train_sentences, word_to_id, tag_to_id)
    dev_data = dl.prepare_dataset(dev_sentences, word_to_id, tag_to_id)
    test_data = dl.prepare_dataset(test_sentences, word_to_id, tag_to_id)

    print('train_data_num %i, dev_data_num %i, test_data_num %i' %
          (len(train_data), len(dev_data), len(test_data)))

    mu.make_path(FLAGS)
    if os.path.isfile(FLAGS.config_file):
        config = mu.load_config(FLAGS.config_file)
    else:
        config = mu.config_model(FLAGS, word_to_id, tag_to_id)
        mu.save_config(config, FLAGS.config_file)
    log_path = os.path.join('log', FLAGS.log_file)
    logger = mu.get_log(log_path)
    mu.print_config(config, logger)
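The mapping helpers themselves are not shown on this page. A plausible sketch of `word_mapping`, assuming it counts word frequencies and assigns ids in descending-frequency order (the `<PAD>`/`<UNK>` specials are assumptions; the real data_loader may differ):

def word_mapping(sentences):
    """Build (frequency dict, word_to_id, id_to_word) from [[(word, tag), ...], ...]."""
    dico = {}
    for sentence in sentences:
        for word, _tag in sentence:
            dico[word] = dico.get(word, 0) + 1
    dico['<PAD>'] = 100000001  # assumed padding token, forced to id 0
    dico['<UNK>'] = 100000000  # assumed out-of-vocabulary token, id 1
    sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0]))
    id_to_word = {i: w for i, (w, _count) in enumerate(sorted_items)}
    word_to_id = {w: i for i, w in id_to_word.items()}
    return dico, word_to_id, id_to_word

`tag_mapping` would do the same over the tag column. Pickling the four maps into map_file keeps train-time and inference-time ids consistent across runs.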
Example #3
def train():
    # 1. Load the datasets
    train_sentences = data_loader.load_sentences(FLAGS.train_file)
    dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
    test_sentences = data_loader.load_sentences(FLAGS.test_file)

    # 2. Convert the tag encoding: BIO -> BIOES
    data_loader.update_tag_scheme(train_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(dev_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(test_sentences, FLAGS.tag_schema)

    # 3. Build the word and tag mappings
    if not os.path.isfile(FLAGS.map_file):
        _, word_to_id, id_to_word = data_loader.word_mapping(train_sentences)
        _, tag_to_id, id_to_tag = data_loader.tag_mapping(train_sentences)

        with open(FLAGS.map_file, "wb") as f:
            # pickle.dump(obj, file[, protocol]) serializes obj and writes it to file.
            pickle.dump([word_to_id, id_to_word, tag_to_id, id_to_tag], f)
    else:
        # pickle.load deserializes the file's contents back into a Python object;
        # the file object must provide read() and readline().
        with open(FLAGS.map_file, "rb") as f:
            word_to_id, id_to_word, tag_to_id, id_to_tag = pickle.load(f)

    # 4. Preprocess the data
    train_data = data_loader.prepare_dataset(train_sentences, word_to_id,
                                             tag_to_id)

    dev_data = data_loader.prepare_dataset(dev_sentences, word_to_id,
                                           tag_to_id)

    test_data = data_loader.prepare_dataset(test_sentences, word_to_id,
                                            tag_to_id)

    model_utils.make_path(FLAGS)

    config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)
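`config_model` is likewise not shown. A minimal sketch under the assumption that it simply gathers vocabulary sizes and FLAGS values into a dict (every key and flag name below is a guess):

def config_model(FLAGS, word_to_id, tag_to_id):
    """Collect model hyperparameters into a plain dict (illustrative only)."""
    return {
        'num_words': len(word_to_id),   # vocabulary size
        'num_tags': len(tag_to_id),     # label-set size after BIOES conversion
        'batch_size': FLAGS.batch_size,
        'tag_schema': FLAGS.tag_schema,
        # The flags below are assumed to exist; names vary between repos.
        'word_dim': FLAGS.word_dim,     # embedding dimension
        'lstm_dim': FLAGS.lstm_dim,     # BiLSTM hidden size
        'learning_rate': FLAGS.lr,
    }

Saving the dict with save_config and reloading it with load_config (as Examples #4 and #5 do) ensures later runs rebuild the network with the same shapes.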
Example #4
def train():
    # Load the datasets
    train_sentences = data_loader.load_sentences(FLAGS.train_file)
    dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
    test_sentences = data_loader.load_sentences(FLAGS.test_file)

    # Convert the tag encoding
    data_loader.update_tag_scheme(train_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(dev_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(test_sentences, FLAGS.tag_schema)

    # Build the word and tag mappings
    if not os.path.isfile(FLAGS.map_file):
        if FLAGS.pre_emb:
            dico_words_train = data_loader.word_mapping(train_sentences)[0]
            dico_word, word_to_id, id_to_word = data_utils.augment_with_pretrained(
                dico_words_train.copy(), FLAGS.emb_file,
                list(
                    itertools.chain.from_iterable([[w[0] for w in s]
                                                   for s in test_sentences])))
        else:
            _, word_to_id, id_to_word = data_loader.word_mapping(
                train_sentences)
        _, tag_to_id, id_to_tag = data_loader.tag_mapping(train_sentences)
        with open(FLAGS.map_file, 'wb') as f:
            pickle.dump([word_to_id, id_to_word, tag_to_id, id_to_tag], f)
    else:
        with open(FLAGS.map_file, 'rb') as f:
            word_to_id, id_to_word, tag_to_id, id_to_tag = pickle.load(f)

    # Prepare the data
    train_data = data_loader.prepare_dataset(train_sentences, word_to_id,
                                             tag_to_id)
    dev_data = data_loader.prepare_dataset(dev_sentences, word_to_id,
                                           tag_to_id)
    test_data = data_loader.prepare_dataset(test_sentences, word_to_id,
                                            tag_to_id)

    # Batch the data
    train_manager = data_utils.BatchManager(train_data, FLAGS.batch_size)
    dev_manager = data_utils.BatchManager(dev_data, FLAGS.batch_size)
    test_manager = data_utils.BatchManager(test_data, FLAGS.batch_size)

    # Create any missing directories
    model_utils.make_path(FLAGS)

    # Use the config file if present, otherwise build and save one
    if os.path.isfile(FLAGS.config_file):
        config = model_utils.load_config(FLAGS.config_file)
    else:
        config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)
        model_utils.save_config(config, FLAGS.config_file)

    # Configure the logger
    log_path = os.path.join('log', FLAGS.log_file)
    logger = model_utils.get_logger(log_path)
    model_utils.print_config(config, logger)

    tf_config = tf.ConfigProto(allow_soft_placement=True)
    tf_config.gpu_options.allow_growth = True

    step_per_epoch = train_manager.len_data
    with tf.Session(config=tf_config) as sess:
        model = model_utils.create(sess, Model, FLAGS.ckpt_path, load_word2vec,
                                   config, id_to_word, logger)
        logger.info('Start training')
        loss = []
        start = time.time()
        for i in range(100):  # fixed number of training epochs
            for batch in train_manager.iter_batch(shuffle=True):
                step, batch_loss = model.run_step(sess, True, batch)
                loss.append(batch_loss)
                if step % FLAGS.steps_check == 0:
                    iteration = step // step_per_epoch + 1
                    logger.info(
                        "iteration{}: step{}/{}, NER loss:{:>9.6f}".format(
                            iteration, step % step_per_epoch, step_per_epoch,
                            np.mean(loss)))
                    loss = []
            best = evaluate(sess, model, 'dev', dev_manager, id_to_tag, logger)

            if best:
                model_utils.save_model(sess, model, FLAGS.ckpt_path, logger)
            evaluate(sess, model, 'test', test_manager, id_to_tag, logger)
        t = time.time() - start
        logger.info('elapsed time: %f' % t)
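When FLAGS.pre_emb is set, augment_with_pretrained widens the training vocabulary with test-set words that have a pretrained vector, so they map to a real embedding rather than <UNK>. A sketch under the assumption that emb_file stores one "word v1 v2 ..." entry per line (the repository's data_utils version may differ):

def augment_with_pretrained(dictionary, emb_path, words):
    """Add words covered by the pretrained embeddings to the frequency dict."""
    # Vocabulary of the pretrained embedding file.
    pretrained = set()
    with open(emb_path, encoding='utf-8') as f:
        for line in f:
            pretrained.add(line.rstrip().split()[0])
    # Give covered out-of-vocabulary words an entry (frequency 0),
    # so they receive an id and, later, their pretrained vector.
    for word in words:
        if word not in dictionary and word in pretrained:
            dictionary[word] = 0
    sorted_items = sorted(dictionary.items(), key=lambda x: (-x[1], x[0]))
    id_to_word = {i: w for i, (w, _count) in enumerate(sorted_items)}
    word_to_id = {w: i for i, w in id_to_word.items()}
    return dictionary, word_to_id, id_to_word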
Example #5
def train():
    # Load the datasets
    train_sentences = data_loader.load_sentences(FLAGS.train_file)
    dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
    test_sentences = data_loader.load_sentences(FLAGS.test_file)

    # Convert the tag encoding from BIO to BIOES
    data_loader.update_tag_scheme(train_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(test_sentences, FLAGS.tag_schema)
    data_loader.update_tag_scheme(dev_sentences, FLAGS.tag_schema)

    # Build the word and tag mappings
    if not os.path.isfile(FLAGS.map_file):
        if FLAGS.pre_emb:
            dico_words_train = data_loader.word_mapping(train_sentences)[0]
            dico_word, word_to_id, id_to_word = data_utils.augment_with_pretrained(
                dico_words_train.copy(),
                FLAGS.emb_file,
                list(
                    itertools.chain.from_iterable(
                        [[w[0] for w in s] for s in test_sentences]
                    )
                )
            )
        else:
            _, word_to_id, id_to_word = data_loader.word_mapping(train_sentences)

        _, tag_to_id, id_to_tag = data_loader.tag_mapping(train_sentences)

        with open(FLAGS.map_file, "wb") as f:
            pickle.dump([word_to_id, id_to_word, tag_to_id, id_to_tag], f)
    else:
        with open(FLAGS.map_file, 'rb') as f:
            word_to_id, id_to_word, tag_to_id, id_to_tag = pickle.load(f)

    train_data = data_loader.prepare_dataset(
        train_sentences, word_to_id, tag_to_id
    )

    dev_data = data_loader.prepare_dataset(
        dev_sentences, word_to_id, tag_to_id
    )

    test_data = data_loader.prepare_dataset(
        test_sentences, word_to_id, tag_to_id
    )

    train_manager = data_utils.BatchManager(train_data, FLAGS.batch_size)
    dev_manager = data_utils.BatchManager(dev_data, FLAGS.batch_size)
    test_manager = data_utils.BatchManager(test_data, FLAGS.batch_size)

    print('train_data_num %i, dev_data_num %i, test_data_num %i' % (len(train_data), len(dev_data), len(test_data)))

    model_utils.make_path(FLAGS)

    if os.path.isfile(FLAGS.config_file):
        config = model_utils.load_config(FLAGS.config_file)
    else:
        config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)
        model_utils.save_config(config, FLAGS.config_file)

    log_path = os.path.join("log", FLAGS.log_file)
    logger = model_utils.get_logger(log_path)
    model_utils.print_config(config, logger)

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    steps_per_epoch = train_manager.len_data
    with tf.Session(config=tf_config) as sess:
        model = model_utils.create(sess, Model, FLAGS.ckpt_path, load_word2vec, config, id_to_word, logger)
        logger.info("开始训练")
        loss = []
        for i in range(100):  # fixed number of training epochs
            for batch in train_manager.iter_batch(shuffle=True):
                step, batch_loss = model.run_step(sess, True, batch)
                loss.append(batch_loss)
                if step % FLAGS.steps_check == 0:
                    iteration = step // steps_per_epoch + 1
                    logger.info("iteration:{} step {}/{}, NER loss:{:>9.6f}".format(iteration, step % steps_per_epoch, steps_per_epoch, np.mean(loss)))
                    loss = []

            best = evaluate(sess, model, "dev", dev_manager, id_to_tag, logger)

            if best:
                model_utils.save_model(sess, model, FLAGS.ckpt_path, logger)
            evaluate(sess, model, "test", test_manager, id_to_tag, logger)
Example #6
def main(args):

    lossFunction = nn.CrossEntropyLoss()

    hidden_size = args.hidden_size
    epsilon = args.epsilon
    training_mode = args.training_mode
    learning_rate = args.learning_rate

    epochs = args.epochs
    K = args.K_shot_learning
    N = args.N_way_learning
    inner_epoch = args.inner_gradient_update
    max_len = 116

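    # Load the raw sentence/tag data for every language in the meta-learning pool.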
    sanskrit_train, marathi_train, marathi_test, marathi_dev, hindi_train, bhojpuri_train, magahi_train, english_train, german_train, dutch_train, danish_train = load_sentences(
    )
    tokens_dict, dict_token, n_tokens = get_tokens(bhojpuri_train)

    marathi, marathi_tags = get_sentences(marathi_train, None, tokens_dict,
                                          max_len)
    marathi_d, marathi_tags_d = get_sentences(marathi_dev, None, tokens_dict,
                                              max_len)
    marathi_t, marathi_tags_t = get_sentences(marathi_test, None, tokens_dict,
                                              max_len)
    hindi, hindi_tags = get_sentences(hindi_train, None, tokens_dict, max_len)
    bhojpuri, bhojpuri_tags = get_sentences(bhojpuri_train, None, tokens_dict,
                                            max_len)
    magahi, magahi_tags = get_sentences(magahi_train, None, tokens_dict,
                                        max_len)
    sanskrit, sanskrit_tags = get_sentences(sanskrit_train, None, tokens_dict,
                                            max_len)
    marathi = marathi + marathi_d + marathi_t
    marathi_tags = marathi_tags + marathi_tags_d + marathi_tags_t

    english, english_tags = get_sentences(english_train, None, tokens_dict,
                                          max_len)
    dutch, dutch_tags = get_sentences(dutch_train, None, tokens_dict, max_len)
    danish, danish_tags = get_sentences(danish_train, None, tokens_dict,
                                        max_len)
    german, german_tags = get_sentences(german_train, None, tokens_dict,
                                        max_len)

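    # Train one Word2Vec embedding model per language (gensim 3.x API,
    # where size= is the embedding dimension).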
    model_hindi = gs.Word2Vec(hindi, min_count=1, size=hidden_size)
    model_marathi = gs.Word2Vec(marathi, min_count=1, size=hidden_size)
    model_sanskrit = gs.Word2Vec(sanskrit, min_count=1, size=hidden_size)
    model_bhojpuri = gs.Word2Vec(bhojpuri, min_count=1, size=hidden_size)
    model_magahi = gs.Word2Vec(magahi, min_count=1, size=hidden_size)
    model_german = gs.Word2Vec(german, min_count=1, size=hidden_size)
    model_english = gs.Word2Vec(english, min_count=1, size=hidden_size)
    model_dutch = gs.Word2Vec(dutch, min_count=1, size=hidden_size)
    model_danish = gs.Word2Vec(danish, min_count=1, size=hidden_size)

    char_dict, n_chars = get_characters(marathi + hindi + bhojpuri + sanskrit +
                                        magahi + english + dutch + danish +
                                        german)

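    # Wrap each language's sentences, tags, and embedding model in a
    # task-specific DataLoader.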
    hindi_data_loader = DataLoader(hindi, None, hindi_tags, None, max_len,
                                   model_hindi)
    marathi_data_loader = DataLoader(marathi, None, marathi_tags, None,
                                     max_len, model_marathi)
    sanskrit_data_loader = DataLoader(sanskrit, None, sanskrit_tags, None,
                                      max_len, model_sanskrit)
    bhojpuri_data_loader = DataLoader(bhojpuri, None, bhojpuri_tags, None,
                                      max_len, model_bhojpuri)
    magahi_data_loader = DataLoader(magahi, None, magahi_tags, None, max_len,
                                    model_magahi)
    english_data_loader = DataLoader(english, None, english_tags, None,
                                     max_len, model_english)
    german_data_loader = DataLoader(german, None, german_tags, None, max_len,
                                    model_german)
    danish_data_loader = DataLoader(danish, None, danish_tags, None, max_len,
                                    model_danish)
    dutch_data_loader = DataLoader(dutch, None, dutch_tags, None, max_len,
                                   model_dutch)

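    # Assemble the meta-learner over all per-language data loaders.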
    metaLearn = MetaLearn(hindi_data_loader, marathi_data_loader,
                          sanskrit_data_loader, bhojpuri_data_loader,
                          magahi_data_loader, english_data_loader,
                          german_data_loader, dutch_data_loader,
                          danish_data_loader, lossFunction, hidden_size,
                          epochs, inner_epoch, max_len, n_tokens, tokens_dict,
                          dict_token, char_dict, n_chars, N, K, learning_rate)

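    # Resume from a checkpoint, evaluate a saved model, or fall through
    # to fresh MAML/Reptile training below.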
    if args.resume_training:
        model = torch.load(args.checkpoint_path)
        metaLearn.epochs = model['epoch']
        metaLearn.load_state_dict(model['model'])

        if args.resume_training_type == 'MAML':
            metaLearn.train()
            _ = metaLearn.test()
        elif args.resume_training_type == 'Reptile':
            metaLearn.train_Reptile()
            _ = metaLearn.test()

    elif args.load_model:
        metaLearn.load_state_dict(torch.load(args.model_path))
        _ = metaLearn.test()

    if training_mode == 'MAML':
        metaLearn.train_MAML()
        _ = metaLearn.test()
    elif training_mode == 'Reptile':
        metaLearn.train_Reptile(epsilon)
        _ = metaLearn.test()
    else:
        raise NotImplementedError('This algorithm has not been implemented')
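For reference, the Reptile outer update that train_Reptile(epsilon) presumably applies is θ ← θ + ε·(θ_task − θ): after the inner-loop gradient steps on a sampled task, the meta-parameters move a fraction epsilon toward the adapted weights. A minimal sketch in PyTorch terms (function and argument names are assumptions, not MetaLearn's real API):

import torch

def reptile_outer_step(model, adapted_state, epsilon):
    """Move model's parameters a fraction epsilon toward task-adapted weights."""
    meta_state = model.state_dict()
    with torch.no_grad():
        for name, meta_param in meta_state.items():
            meta_state[name] = meta_param + epsilon * (adapted_state[name] - meta_param)
    model.load_state_dict(meta_state)

Unlike MAML, this needs no second-order gradients: the meta-update is a plain interpolation between parameter vectors, which is why epsilon is passed directly to train_Reptile.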