Example #1
def load_model():
    """Build the inference graph and restore the latest checkpoint.

    Relies on the module-level imports shown in Examples #2 and #3
    (os, tensorflow as tf, NetworkAlbertTextCNN, hp, pwd).
    """
    with tf.Graph().as_default():
        sess = tf.Session()
        with sess.as_default():
            # Build the ALBERT + TextCNN network in inference mode.
            albert = NetworkAlbertTextCNN(is_training=False)
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            # Restore the most recent checkpoint from the configured directory.
            checkpoint_dir = os.path.abspath(os.path.join(pwd, hp.file_load_model))
            print(checkpoint_dir)
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
    return albert, sess
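
A minimal usage sketch for load_model(), assuming it lives in the same module as the imports shown in Examples #2 and #3, and assuming NetworkAlbertTextCNN exposes input_ids, input_masks, segment_ids and predictions tensors plus an hp.sequence_length setting (hypothetical names, not confirmed by this snippet):

# Hypothetical usage sketch: the attribute names on `albert` (input_ids,
# input_masks, segment_ids, predictions) and hp.sequence_length are assumptions.
import numpy as np

albert, sess = load_model()

# Dummy batch of one already-tokenized example.
ids = np.zeros((1, hp.sequence_length), dtype=np.int32)
masks = np.ones((1, hp.sequence_length), dtype=np.int32)
segments = np.zeros((1, hp.sequence_length), dtype=np.int32)

probs = sess.run(
    albert.predictions,
    feed_dict={
        albert.input_ids: ids,
        albert.input_masks: masks,
        albert.segment_ids: segments,
    },
)
print(probs.shape)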
Example #2
"""
Created on Thu May 30 21:42:07 2019

@author: cm
"""

import os
#os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
import numpy as np
import tensorflow as tf
from classifier_multi_label_textcnn.networks import NetworkAlbertTextCNN
from classifier_multi_label_textcnn.classifier_utils import get_features
from classifier_multi_label_textcnn.hyperparameters import Hyperparamters as hp
from classifier_multi_label_textcnn.utils import select, time_now_string

pwd = os.path.dirname(os.path.abspath(__file__))
MODEL = NetworkAlbertTextCNN(is_training=True)

# Get data features
input_ids, input_masks, segment_ids, label_ids = get_features()
num_train_samples = len(input_ids)
indexs = np.arange(num_train_samples)
num_batchs = int((num_train_samples - 1) / hp.batch_size) + 1  # ceil(num_train_samples / batch_size)
print('Number of batches:', num_batchs)

# Set up the graph
saver = tf.train.Saver(max_to_keep=hp.max_to_keep)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Load the previously saved model
MODEL_SAVE_PATH = os.path.join(pwd, hp.file_save_model)
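
The snippet is cut off here. A sketch of how such a TF1 training script typically continues, assuming MODEL exposes input_ids, input_masks, segment_ids, label_ids, loss and optimizer attributes and that hp defines num_train_epochs (all hypothetical names, not confirmed above):

# Hedged continuation sketch; the MODEL.* attribute names and hp.num_train_epochs
# are assumptions, not confirmed by the code above.
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

for epoch in range(hp.num_train_epochs):
    np.random.shuffle(indexs)
    for i in range(num_batchs):
        # Slice one shuffled mini-batch of features and labels.
        batch = indexs[i * hp.batch_size:(i + 1) * hp.batch_size]
        feed_dict = {
            MODEL.input_ids: np.array(input_ids)[batch],
            MODEL.input_masks: np.array(input_masks)[batch],
            MODEL.segment_ids: np.array(segment_ids)[batch],
            MODEL.label_ids: np.array(label_ids)[batch],
        }
        _, loss = sess.run([MODEL.optimizer, MODEL.loss], feed_dict=feed_dict)
    print(time_now_string(), 'epoch:', epoch, 'loss:', loss)
    saver.save(sess, os.path.join(MODEL_SAVE_PATH, 'model.ckpt'), global_step=epoch)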
Example #3
"""
Created on Thu May 30 21:42:07 2019

@author: cm
"""

import os
# os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
import numpy as np
import tensorflow as tf
from classifier_multi_label_textcnn.networks import NetworkAlbertTextCNN
from classifier_multi_label_textcnn.classifier_utils import get_features, get_features_test
from classifier_multi_label_textcnn.hyperparameters import Hyperparamters as hp
from classifier_multi_label_textcnn.utils import select, time_now_string

pwd = os.path.dirname(os.path.abspath(__file__))
MODEL = NetworkAlbertTextCNN(is_training=False)

# Get data features
input_ids, input_masks, segment_ids, label_ids = get_features_test()
num_train_samples = len(input_ids)
indexs = np.arange(num_train_samples)
num_batchs = int((num_train_samples - 1) / hp.batch_size) + 1  # e.g. ceil(800 / 64) = 13
print('Number of batches:', num_batchs)

# Set up the graph
saver = tf.train.Saver(max_to_keep=hp.max_to_keep)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Load the previously saved model
# MODEL_SAVE_PATH = os.path.join(pwd, hp.file_save_model)
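
This snippet is likewise truncated. A sketch of how the evaluation loop might restore the checkpoint and collect predictions over the test batches, again with the checkpoint location and the MODEL.* attribute names as assumptions:

# Hedged continuation sketch; the checkpoint location and the MODEL.* attribute
# names (input_ids, input_masks, segment_ids, predictions) are assumptions.
MODEL_SAVE_PATH = os.path.join(pwd, hp.file_save_model)
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
saver.restore(sess, ckpt.model_checkpoint_path)

all_predictions = []
for i in range(num_batchs):
    batch = indexs[i * hp.batch_size:(i + 1) * hp.batch_size]
    feed_dict = {
        MODEL.input_ids: np.array(input_ids)[batch],
        MODEL.input_masks: np.array(input_masks)[batch],
        MODEL.segment_ids: np.array(segment_ids)[batch],
    }
    all_predictions.append(sess.run(MODEL.predictions, feed_dict=feed_dict))

predictions = np.concatenate(all_predictions, axis=0)
print(time_now_string(), 'predictions shape:', predictions.shape)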