Code example #1
def export_pb_model():
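    """Rebuild the BiGRUAttention graph, restore the latest checkpoint, and export it as a SavedModel (.pb) with a serving signature."""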
    vocabulary_word2index, vocabulary_index2word = create_term(
        FLAGS.term_index_path)
    vocabulary_char2index, vocabulary_index2char = create_term(
        FLAGS.char_index_path)
    vocab_size = len(vocabulary_word2index)
    char_size = len(vocabulary_char2index)
    with tf.Graph().as_default():
        model = BiGRUAttention(FLAGS.loss_number,
                               FLAGS.num_count,
                               FLAGS.num_classes,
                               FLAGS.learning_rate,
                               FLAGS.batch_size,
                               FLAGS.decay_steps,
                               FLAGS.decay_rate,
                               FLAGS.sequence_length,
                               FLAGS.chars_length,
                               FLAGS.num_sentences,
                               vocab_size,
                               char_size,
                               FLAGS.embed_size,
                               FLAGS.hidden_size,
                               FLAGS.is_training,
                               multi_label_flag=FLAGS.multi_label_flag,
                               char_attention_flag=FLAGS.char_attention_flag,
                               count_flag=FLAGS.count_flag)
        model_signature = signature_def_utils.build_signature_def(
            inputs={
                "terms_ids":
                utils.build_tensor_info(model.input_x),
                "keep_prob_hidden":
                utils.build_tensor_info(model.dropout_keep_prob)
            },
            outputs={"prediction": utils.build_tensor_info(model.predictions)},
            method_name=signature_constants.CLASSIFY_METHOD_NAME)

        session_conf = tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)

        sess = tf.Session(config=session_conf)
        saver = tf.train.Saver()
        #saver.restore(sess,_ckpt+'model.ckpt-6')
        print(tf.train.latest_checkpoint(FLAGS.ckpt))
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt))

        builder = saved_model_builder.SavedModelBuilder(FLAGS.pb_path)
        builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING],
                                             clear_devices=True,
                                             signature_def_map={
                                                 'cat_han_model_signature':
                                                 model_signature,
                                             })

        builder.save()
Code example #2
def main(_):
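    """Load the term vocabulary, build a TPUEstimator from model_fn_builder, and run training from TFRecord input when do_train is set."""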
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of do_train, do_eval or do_predict must be True")

    tpu_cluster_resolver = None
    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=None,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps)

    #1. load vocabulary
    vocabulary_word2index, vocabulary_index2word = create_term(
        FLAGS.term_index_path)
    vocab_size = len(vocabulary_index2word)
    print("vocab_size:", vocab_size)

    model_fn = model_fn_builder(vocab_size=vocab_size)

    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=None,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size,
        predict_batch_size=FLAGS.batch_size)

    if FLAGS.do_train:
        num_train_steps = int(
            FLAGS.train_sample_num / FLAGS.batch_size) * FLAGS.num_epochs
        print("*****all steps **************", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            num_cpu_threads=FLAGS.num_cpu_threads,
            input_file=FLAGS.train_sample_file,  # bare train_sample_file is undefined in this excerpt; assuming the FLAGS path
            batch_size=FLAGS.batch_size,
            seq_length=FLAGS.sequence_length,
            first_length=FLAGS.first_length,
            second_length=FLAGS.second_length,
            third_length=FLAGS.third_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
Code example #3
 def __init__(self):
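     """Load term/label vocabularies, build the BiGRUAttention graph, and restore the latest checkpoint."""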
     self.vocabulary_word2index, self.vocabulary_index2word = create_term(
         FLAGS.term_index_path)
     self.vocab_size = len(self.vocabulary_word2index)
     self.char_size = 100
     self.vocabulary_word2index_label, self.vocabulary_index2word_label = create_label(
         FLAGS.label_index_path)
     self.label_name = load_cid(FLAGS.label_index_path)
     FLAGS.num_classes = len(self.vocabulary_word2index_label)
     with tf.Graph().as_default():
         config = tf.ConfigProto()
         config.gpu_options.allow_growth = True
         self.sess = tf.Session(config=config)
         # 4.Instantiate Model
         self.model = BiGRUAttention(
             FLAGS.loss_number,
             FLAGS.num_count,
             FLAGS.num_classes,
             FLAGS.learning_rate,
             FLAGS.batch_size,
             FLAGS.decay_steps,
             FLAGS.decay_rate,
             FLAGS.sequence_length,
             FLAGS.chars_length,
             FLAGS.num_sentences,
             self.vocab_size,
             self.char_size,
             FLAGS.embed_size,
             FLAGS.hidden_size,
             FLAGS.is_training,
             multi_label_flag=FLAGS.multi_label_flag,
             char_attention_flag=FLAGS.char_attention_flag,
             count_flag=FLAGS.count_flag)
         self.saver = tf.train.Saver()
         if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
             print("Restoring Variables from Checkpoint")
             self.saver.restore(self.sess,
                                tf.train.latest_checkpoint(FLAGS.ckpt_dir))
         else:
             print("Can't find the checkpoint, going to stop")
Code example #4
File: train.py, Project: Tina-ZJ/Multi-Label
def main(_):
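    """Build the BiGRUAttention model, stream train/dev batches from TFRecords, run the epoch loop with periodic validation, and save checkpoints."""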
    print("embed_size:", FLAGS.embed_size)
    print("hidden_size:", FLAGS.hidden_size)
    print("cid3_num:", FLAGS.num_classes)
    print("train_sample_num:", FLAGS.train_sample_num)
    print("dev_sample_num:", FLAGS.dev_sample_num)
    print("ckpt_dir:", FLAGS.ckpt_dir)
    #1. load vocabulary
    vocabulary_word2index, vocabulary_index2word = create_term(
        FLAGS.term_index_path)
    vocab_size = len(vocabulary_index2word)
    char_size = FLAGS.char_size
    print("vocab_size:", vocab_size)
    if FLAGS.char_attention_flag:
        vocabulary_char2index, vocabulary_index2char = create_term(
            FLAGS.char_index_path)
        char_size = len(vocabulary_index2char)
        print("char_size:", char_size)
    #2.create session.
    with tf.Graph().as_default():
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        model = BiGRUAttention(FLAGS.loss_number,
                               FLAGS.num_count,
                               FLAGS.num_classes,
                               FLAGS.learning_rate,
                               FLAGS.batch_size,
                               FLAGS.decay_steps,
                               FLAGS.decay_rate,
                               FLAGS.sequence_length,
                               FLAGS.chars_length,
                               FLAGS.num_sentences,
                               vocab_size,
                               char_size,
                               FLAGS.embed_size,
                               FLAGS.hidden_size,
                               FLAGS.is_training,
                               multi_label_flag=FLAGS.multi_label_flag,
                               char_attention_flag=FLAGS.char_attention_flag,
                               count_flag=FLAGS.count_flag)
        train_batcher = batch_read_tfrecord.SegBatcher(
            FLAGS.train_sample_file,
            FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs)
        dev_batcher = batch_read_tfrecord.SegBatcher(
            FLAGS.dev_sample_file,
            FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs)
        global_init_op = tf.global_variables_initializer()
        local_init_op = tf.local_variables_initializer()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            #tensorboard
            train_writer = tf.summary.FileWriter(FLAGS.summary_dir + 'train/',
                                                 sess.graph)
            dev_writer = tf.summary.FileWriter(FLAGS.summary_dir + 'dev/')
            sess.run(global_init_op)
            sess.run(local_init_op)
            if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
                print("Restoring Variables from Checkpoint")
                saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
            else:
                print('Initializing Variables')
                if FLAGS.use_embedding:
                    assign_pretrained_word_embedding(
                        sess,
                        vocabulary_index2word,
                        vocab_size,
                        model,
                        word2vec_model_path=FLAGS.word2vec_model_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            curr_epoch = sess.run(model.epoch_step)
            #3.feed data & training
            train_sample_num = FLAGS.train_sample_num
            dev_sample_num = FLAGS.dev_sample_num
            best_eval_f1 = 0

            for epoch in range(curr_epoch, FLAGS.num_epochs):
                loss, acc, recall, precision, f1, counter = 0.0, 0.0, 0.0, 0.0, 0.0, 0
                eval_loss, eval_acc, eval_recall, eval_precision, eval_f1, counter_dev = 0.0, 0.0, 0.0, 0.0, 0.0, 0
                train_example_num = 0
                dev_example_num = 0
                while train_example_num < train_sample_num:
                    try:
                        train_batch_data = sess.run(
                            train_batcher.next_batch_op)
                        trainX, trainXChar, trainY = train_batch_data
                        if len(trainX) != FLAGS.batch_size:
                            continue
                        #trainX = pad_sequences(trainX, maxlen=FLAGS.sequence_length, value=0.)
                        trainY = common.get_one_hot_label(
                            trainY, FLAGS.num_classes)
                        if FLAGS.char_attention_flag:
                            #trainXChar = pad_sequences(trainXChar, maxlen=FLAGS.chars_length, value=0.)
                            feed_dict = {
                                model.input_x: trainX,
                                model.input_char: trainXChar,
                                model.dropout_keep_prob: 0.5
                            }
                        else:
                            feed_dict = {
                                model.input_x: trainX,
                                model.dropout_keep_prob: 0.5
                            }

                        if not FLAGS.multi_label_flag:
                            feed_dict[model.input_y] = trainY
                        else:
                            feed_dict[model.input_y_multilabel] = trainY
                        summary_merge, global_step, curr_loss, curr_acc, curr_f1, curr_recall, curr_precision, _ = sess.run(
                            [
                                model.summary_merge, model.global_step,
                                model.loss_val, model.accuracy, model.f1,
                                model.recall, model.precision, model.train_op
                            ], feed_dict)

                        train_writer.add_summary(summary_merge, global_step)

                        loss, counter, acc, recall, precision, f1 = loss + curr_loss, counter + 1, acc + curr_acc, recall + curr_recall, precision + curr_precision, f1 + curr_f1
                        train_example_num += FLAGS.batch_size
                        if counter % 2000 == 0:
                            print(
                                "Epoch %d\tBatch %d\tTrain Loss:%.3f\tTrain Accuracy:%.3f\tTrain precision:%.3f\tTrain recall:%.3f\tTrain f1:%.3f"
                                %
                                (epoch, counter, loss / float(counter), acc /
                                 float(counter), precision / float(counter),
                                 recall / float(counter), f1 / float(counter)))
                    except tf.errors.OutOfRangeError:
                        print("Done Training")
                        break

                # ----- validation part -----
                while dev_example_num < dev_sample_num:
                    counter_dev += 1
                    try:
                        dev_batch_data = sess.run(dev_batcher.next_batch_op)
                        testX, testXChar, testY = dev_batch_data
                        if len(testX) != FLAGS.batch_size:
                            continue
                        #testX = pad_sequences(testX, maxlen=FLAGS.sequence_length, value=0.)
                        testY = common.get_one_hot_label(
                            testY, FLAGS.num_classes)
                        if FLAGS.char_attention_flag:
                            testXChar = pad_sequences(
                                testXChar, maxlen=FLAGS.chars_length, value=0.)
                        else:
                            testXChar = []
                        cur_eval_loss, cur_eval_acc, cur_eval_precision, cur_eval_recall, cur_eval_f1 = do_eval(
                            dev_writer, sess, model, testX, testXChar, testY)
                        eval_loss, eval_acc, eval_recall, eval_precision, eval_f1 = eval_loss + cur_eval_loss, eval_acc + cur_eval_acc, eval_recall + cur_eval_recall, eval_precision + cur_eval_precision, eval_f1 + cur_eval_f1
                        dev_example_num += FLAGS.batch_size
                        if counter_dev % FLAGS.validate_step == 0:
                            print(
                                "Epoch %d \tBatch %d\tValidation Loss:%.3f\tValidation Accuracy: %.3f\tValidation precision:%.3f\tValidation recall:%.3f\tValidation f1:%.3f\t"
                                % (epoch, counter_dev,
                                   eval_loss / float(counter_dev),
                                   eval_acc / float(counter_dev),
                                   eval_precision / float(counter_dev),
                                   eval_recall / float(counter_dev),
                                   eval_f1 / float(counter_dev)))
                            if eval_f1 > best_eval_f1:
                                best_eval_f1 = eval_f1
                                print("Validation best f1: %f" % best_eval_f1)
                    except tf.errors.OutOfRangeError:
                        print("Done test")
                        break

                # ----- end of validation part -----

                # epoch increment
                print("going to increment epoch counter....")
                sess.run(model.epoch_increment)

                # 4.validation
                print(epoch, FLAGS.validate_every,
                      (epoch % FLAGS.validate_every == 0))
                if epoch % FLAGS.validate_every == 0:
                    #save model to checkpoint
                    if not os.path.exists(FLAGS.ckpt_dir):
                        os.makedirs(FLAGS.ckpt_dir)
                    save_path = FLAGS.ckpt_dir + "model.ckpt"
                    saver.save(sess, save_path, global_step=epoch)
            # stop the input pipeline threads while the session is still open
            coord.request_stop()
            coord.join(threads)
Code example #5
File: predict.py, Project: Tina-ZJ/Multi-Label
def main(_):
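    """Restore the BiGRUAttention checkpoint, run batch prediction over the test file, and write the predicted labels through get_label_using_logits_batch."""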
    # 1.load data with vocabulary of words and labels
    vocabulary_word2index, vocabulary_index2word = create_term(
        FLAGS.term_index_path)
    vocabulary_char2index, vocabulary_index2char = create_term(
        FLAGS.char_index_path)
    vocab_size = len(vocabulary_word2index)
    char_size = len(vocabulary_char2index)
    vocabulary_word2index_label, vocabulary_index2word_label = create_label(
        FLAGS.label_index_path)
    testX, testXChar, lines = load_test(FLAGS.predict_target_file,
                                        vocabulary_word2index,
                                        vocabulary_char2index)
    FLAGS.num_classes = len(vocabulary_word2index_label)
    print("start padding....")
    #testX = pad_sequences(testX, maxlen=FLAGS.sequence_length, value=0.)  # padding to max length
    if FLAGS.char_attention_flag:
        testXChar = pad_sequences(testXChar,
                                  maxlen=FLAGS.chars_length,
                                  value=0.)
    print("end padding...")
    # 2.create session.
    with tf.Graph().as_default():
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            model = BiGRUAttention(
                FLAGS.loss_number,
                FLAGS.num_count,
                FLAGS.num_classes,
                FLAGS.learning_rate,
                FLAGS.batch_size,
                FLAGS.decay_steps,
                FLAGS.decay_rate,
                FLAGS.sequence_length,
                FLAGS.chars_length,
                FLAGS.num_sentences,
                vocab_size,
                char_size,
                FLAGS.embed_size,
                FLAGS.hidden_size,
                FLAGS.is_training,
                multi_label_flag=FLAGS.multi_label_flag,
                char_attention_flag=FLAGS.char_attention_flag,
                count_flag=FLAGS.count_flag)
            saver = tf.train.Saver()
            if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
                print("Restoring Variables from Checkpoint")
                saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
                #saver.restore(sess,FLAGS.ckpt_dir+'model.ckpt-2')
            else:
                print("Can't find the checkpoint, going to stop")
                return
            # 5.feed data, to get logits
            number_of_training_data = len(testX)
            print("number_of_training_data:", number_of_training_data)
            predict_target_file_f = codecs.open(FLAGS.predict_source_file, 'a',
                                                'utf8')
            label_name = load_cid(FLAGS.label_index_path)
            t0 = time.time()
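            # walk the test set in fixed-size (start, end) index batches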
            for start, end in zip(
                    range(0, number_of_training_data, FLAGS.batch_size),
                    range(FLAGS.batch_size, number_of_training_data + 1,
                          FLAGS.batch_size)):
                if FLAGS.char_attention_flag:
                    # also fetch the attention weights so `attention` is defined for get_label_using_logits_batch below
                    attention, predictions = sess.run(
                        [model.attention, model.predictions],
                        feed_dict={
                            model.input_x: testX[start:end],
                            model.input_char: testXChar[start:end],
                            model.dropout_keep_prob: 1
                        })
                else:
                    attention, predictions = sess.run(
                        [model.attention, model.predictions],
                        feed_dict={
                            model.input_x: testX[start:end],
                            model.dropout_keep_prob: 1
                        })
                lines_sublist = lines[start:end]
                get_label_using_logits_batch(lines_sublist, attention,
                                             predictions,
                                             vocabulary_index2word_label,
                                             predict_target_file_f,
                                             FLAGS.threshold, label_name)
            t1 = time.time()
            print("all running time: %s " % str(t1 - t0))
            predict_target_file_f.close()
Code example #6
    return model_fn


def serving_input_receiver_fn():
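    """Expose a single int32 'input_x' placeholder as the serving input for the exported SavedModel."""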
    input_x = tf.placeholder(dtype=tf.int32, shape=[None, None], name='input_x')
    receiver_tensors = {'input_x': input_x}
    features = {'input_x': input_x}
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)


 
        
if __name__ == '__main__':
    
    #1. load vocabulary
    vocabulary_word2index, vocabulary_index2word = create_term(FLAGS.term_index_path)
    vocab_size = len(vocabulary_index2word)
    print("vocab_size:", vocab_size)
    cp_file = tf.train.latest_checkpoint(FLAGS.output_dir)
    model_fn = model_fn_builder(vocab_size=vocab_size)
    
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1.0
    config.log_device_placement = False
    batch_size = 1
    export_dir = FLAGS.output_dir 
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=FLAGS.output_dir,
                                       config=tf.estimator.RunConfig(session_config=config),
                                       params={'batch_size': batch_size})
    estimator.export_saved_model(export_dir, serving_input_receiver_fn, checkpoint_path=cp_file) 
Code example #7
def predict():
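    """Load the newest exported SavedModel via predictor.from_saved_model and write category, product and brand predictions for each test line."""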
    # 1.load data with vocabulary of words and labels
    vocabulary_word2index, vocabulary_index2word = create_term(
        FLAGS.term_index_path)
    vocab_size = len(vocabulary_word2index)
    testX, lines = load_test(FLAGS.predict_target_file, vocabulary_word2index)
    f = open(FLAGS.predict_source_file, 'w')

    # for tags index
    #tags_index2word = load_tag(FLAGS.tags_index_path)

    # id2name
    cid2name = load_cid(FLAGS.label_name_path)
    brand2name = load_cid(FLAGS.brand_name_path)
    product2name = load_cid(FLAGS.product_name_path)

    # id2index
    cid2index, index2cid = create_label(FLAGS.label_index_path)
    product2index, index2product = create_label(FLAGS.product_index_path)
    brand2index, index2brand = create_label(FLAGS.brand_index_path)

    # locate the newest exported SavedModel directory under ckpt_dir (skip temp dirs)
    subdirs = [
        x for x in Path(FLAGS.ckpt_dir).iterdir()
        if x.is_dir() and 'temp' not in str(x)
    ]
    model_pb = str(sorted(subdirs)[-1])
    predict_fn = predictor.from_saved_model(model_pb)

    # run the exported predictor on each test example and write the results
    for i, x in enumerate(testX):
        print(x)
        feed_dict = {'input_x': [x]}
        result = predict_fn(feed_dict)
        batch_predictions = result['predictions']
        batch_product_predictions = result['product_predictions']
        batch_brand_predictions = result['brand_predictions']
        for predictions, product_predictions, brand_predictions in zip(
                batch_predictions, batch_product_predictions,
                batch_brand_predictions):
            predictions_sorted = sorted(predictions, reverse=True)
            product_predictions_sorted = sorted(product_predictions,
                                                reverse=True)
            brand_predictions_sorted = sorted(brand_predictions, reverse=True)

            index_sorted = np.argsort(-predictions)
            product_index_sorted = np.argsort(-product_predictions)
            brand_index_sorted = np.argsort(-brand_predictions)

            label_list, label_name = get_result(index_sorted,
                                                predictions_sorted, index2cid,
                                                cid2name)
            product_list, product_name = get_result(
                product_index_sorted, product_predictions_sorted,
                index2product, product2name)
            brand_list, brand_name = get_result(brand_index_sorted,
                                                brand_predictions_sorted,
                                                index2brand, brand2name)

            f.write(lines[i] + '\t' + ','.join(label_list) + '\t' +
                    ','.join(label_name) + '\t' + ','.join(product_list) +
                    '\t' + ','.join(product_name) + '\t' +
                    ','.join(brand_list) + '\t' + ','.join(brand_name) + '\n')
    f.flush()
    f.close()