コード例 #1
0
def pro_data():
    """Load the validation set from ``normal_param.dev_path``.

    The files under the dev path are split first, then ``deal_data`` is
    called with ``part`` set to the number of split files, ``n_part=0`` and
    labels requested.

    :return: tuple ``(data, labels, raw_labels)`` -- the validation samples,
        the labels matching those samples, and the third value returned by
        ``deal_data`` (presumably the unprocessed labels; confirm against
        ``process_data.deal_data``).
    """
    loader = process.process_data()
    loader.split_data_file(normal_param.dev_path)
    file_count = len(loader.all_text_path)
    data, labels, raw_labels = loader.deal_data(
        part=file_count, n_part=0, need_label=True)
    return data, labels, raw_labels
コード例 #2
0
ファイル: LSI_.py プロジェクト: pdaicode/deal_contact
def load_corpus():
    """Load the training corpus as tokenised texts plus their labels.

    Splits the files under ``normal_param.train_path``, asks the data
    processor for ``(text, label)`` pairs, and separates them into two
    parallel lists.

    :return: tuple ``(x_texts, labels)`` where ``x_texts`` holds each text
        split on single spaces and ``labels`` holds the label of each text.
    """
    loader = process_data.process_data()
    loader.split_data_file(normal_param.train_path)
    pairs = loader.build_datas_and_labels(loader.all_text_path)

    # First pass: tokenise every text by splitting on a single space.
    tokenised = []
    for text, _ in pairs:
        tokenised.append(text.split(" "))

    # Second pass: collect the labels in the same order.
    targets = []
    for _, target in pairs:
        targets.append(target)

    return tokenised, targets
コード例 #3
0
def pro_data(texts_path):
    """Read the content of every text file that is to be tested.

    :param texts_path: iterable of paths to the txt files under test
    :return: list with the loaded content of each file, in input order
    """
    loader = process.process_data()
    return [loader.load_data(path) for path in texts_path]
コード例 #4
0
ファイル: train.py プロジェクト: aimasa/contact_classify
def train():
    """Train the CNN text classifier, logging summaries and saving checkpoints.

    Steps performed:

    1. Split the training files found under ``normal_param.train_path``.
    2. Build ``cnn.model_cnn`` from the hyper-parameters in ``normal_param``
       and attach an Adam optimizer (learning rate 1e-3).
    3. Write train/dev summaries below ``runs/<model_name>/summary`` and
       checkpoints below ``runs/<model_name>/checkpoint``; when a checkpoint
       already exists there, the variables are restored from it.
    4. Iterate over ``process_data.batch_iter``, which yields
       ``(batch, dev_data, dev_label, is_save)`` tuples. Each batch is one
       training step; the dev data is evaluated every
       ``normal_param.evaluate_every`` steps, and a checkpoint is written
       every ``normal_param.checkpoint_every`` steps or whenever the
       iterator sets ``is_save``.

    The session and both summary writers are closed when training finishes
    or fails, so buffered summaries are flushed to disk.

    :return: None
    """
    print("loading data……")
    process_data_init = process_data()
    process_data_init.split_data_file(normal_param.train_path)

    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=normal_param.allow_soft_placement,
                                      log_device_placement=normal_param.log_device_placement)
        # Entering the session installs it as the default session and
        # guarantees it is closed once the block is left.
        with tf.Session(config=session_conf) as sess:
            cnn_init = cnn.model_cnn(sequence_length=normal_param.sequence_length,
                                     num_classes=normal_param.num_class,
                                     vocab_size=normal_param.vocal_size_train,
                                     embedding_size=normal_param.embedding_dim,
                                     filters_size=list(map(int, normal_param.filter_sizes.split(","))),
                                     num_filters=normal_param.num_filters)
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn_init.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
            print(time.time())
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", normal_param.model_name))
            print("Writing to {}\n".format(out_dir))

            loss_summary = tf.summary.scalar("loss", cnn_init.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn_init.accuracy)

            # Pass the Graph itself: handing a GraphDef to FileWriter is
            # deprecated.
            train_summary_op = tf.summary.merge([loss_summary, acc_summary])
            train_summary_dir = os.path.join(out_dir, "summary", "train")
            train_summary_write = tf.summary.FileWriter(train_summary_dir, sess.graph)

            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_dir = os.path.join(out_dir, "summary", "dev")
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

            try:
                sess.run(tf.global_variables_initializer())

                checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoint"))
                checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                saver = tf.train.Saver(tf.global_variables())

                # Resume from the most recent checkpoint when one is present;
                # this overwrites the freshly initialised variables.
                ckpt = tf.train.latest_checkpoint(checkpoint_dir)
                if ckpt:
                    saver.restore(sess, ckpt)
                    print("CNN restore from the checkpoint {0}".format(ckpt))

                def train_step(x_batch, y_batch, writer):
                    """Run one optimisation step and log its loss/accuracy."""
                    feed_dic = {
                        cnn_init.input_x: x_batch,
                        cnn_init.input_y: y_batch,
                        cnn_init.dropout_keep_prob: normal_param.dropout_keep_prob,
                    }
                    _, step, summaries, loss, accuracy = sess.run(
                        [train_op, global_step, train_summary_op, cnn_init.loss, cnn_init.accuracy],
                        feed_dic)
                    time_str = datetime.datetime.now().isoformat()
                    print('{}: step {}, loss {:g}, acc {:g}'.format(
                        time_str, step, loss, accuracy))
                    if writer:
                        writer.add_summary(summaries, step)

                def dev_step(x_batch, y_batch, writer=None):
                    """Evaluate the model on the dev data without dropout."""
                    feed_dic = {
                        cnn_init.input_x: x_batch,
                        cnn_init.input_y: y_batch,
                        # Keep every unit during evaluation.
                        cnn_init.dropout_keep_prob: 1.0,
                    }
                    step, summaries, loss, accuracy = sess.run(
                        [global_step, dev_summary_op, cnn_init.loss, cnn_init.accuracy],
                        feed_dic)
                    time_str = datetime.datetime.now().isoformat()
                    print('{}: step {}, loss {:g}, acc {:g}'.format(
                        time_str, step, loss, accuracy))
                    if writer:
                        writer.add_summary(summaries, step)

                batches = process_data_init.batch_iter(normal_param.batch_size, normal_param.num_epochs)
                for batch, dev_data, dev_label, is_save in batches:
                    x_batch, y_batch = zip(*batch)
                    train_step(x_batch, y_batch, train_summary_write)
                    current_step = tf.train.global_step(sess, global_step)

                    if current_step % normal_param.evaluate_every == 0:
                        print("\nEvaluation:")
                        dev_step(dev_data, dev_label, writer=dev_summary_writer)

                    # Save once per step even when the periodic interval and
                    # the iterator's explicit save request coincide.
                    reached_interval = current_step % normal_param.checkpoint_every == 0
                    if reached_interval or is_save:
                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        print("Saved model checkpoint to {}\n".format(path))
            finally:
                # Flush buffered events and release the file handles.
                train_summary_write.close()
                dev_summary_writer.close()