def load_model():
    """Build the inference network in a fresh graph and restore its weights.

    A new ``tf.Graph`` is created and a ``NetworkAlbertTextCNN`` is built in
    it with ``is_training=False``. A session bound to that graph is opened,
    and all variables tracked by a default ``tf.train.Saver`` are restored
    from the checkpoint recorded in
    ``os.path.join(pwd, hp.file_load_model)`` (made absolute).

    Returns:
        tuple: ``(albert, sess)`` -- the network object and the open session
        on its graph. The session is not closed here; the caller owns it.

    Raises:
        FileNotFoundError: if no checkpoint state/path is found in the
            checkpoint directory.
    """
    with tf.Graph().as_default():
        sess = tf.Session()
        with sess.as_default():
            albert = NetworkAlbertTextCNN(is_training=False)
            # The Saver must be created after the network so it picks up
            # the network's variables.
            saver = tf.train.Saver()
            # NOTE(review): restore() below overwrites every variable this
            # Saver tracks, so this initializer looks redundant; kept to
            # preserve original behavior.
            sess.run(tf.global_variables_initializer())
            checkpoint_dir = os.path.abspath(os.path.join(pwd, hp.file_load_model))
            print(checkpoint_dir)
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            # get_checkpoint_state returns None when the directory has no
            # checkpoint file; fail loudly and don't leak the session.
            if ckpt is None or not ckpt.model_checkpoint_path:
                sess.close()
                raise FileNotFoundError(
                    'No checkpoint found in directory: %s' % checkpoint_dir
                )
            saver.restore(sess, ckpt.model_checkpoint_path)
    return albert, sess
Created on Thu May 30 21:42:07 2019 @author: cm """ import os #os.environ["CUDA_VISIBLE_DEVICES"] = '-1' import numpy as np import tensorflow as tf from classifier_multi_label_textcnn.networks import NetworkAlbertTextCNN from classifier_multi_label_textcnn.classifier_utils import get_features from classifier_multi_label_textcnn.hyperparameters import Hyperparamters as hp from classifier_multi_label_textcnn.utils import select, time_now_string pwd = os.path.dirname(os.path.abspath(__file__)) MODEL = NetworkAlbertTextCNN(is_training=True) # Get data features input_ids, input_masks, segment_ids, label_ids = get_features() num_train_samples = len(input_ids) indexs = np.arange(num_train_samples) num_batchs = int((num_train_samples - 1) / hp.batch_size) + 1 print('Number of batch:', num_batchs) # Set up the graph saver = tf.train.Saver(max_to_keep=hp.max_to_keep) sess = tf.Session() sess.run(tf.global_variables_initializer()) # Load model saved before MODEL_SAVE_PATH = os.path.join(pwd, hp.file_save_model)
Created on Thu May 30 21:42:07 2019 @author: cm """ import os # os.environ["CUDA_VISIBLE_DEVICES"] = '-1' import numpy as np import tensorflow as tf from classifier_multi_label_textcnn.networks import NetworkAlbertTextCNN from classifier_multi_label_textcnn.classifier_utils import get_features,get_features_test from classifier_multi_label_textcnn.hyperparameters import Hyperparamters as hp from classifier_multi_label_textcnn.utils import select, time_now_string pwd = os.path.dirname(os.path.abspath(__file__)) MODEL = NetworkAlbertTextCNN(is_training=False) # Get data features input_ids, input_masks, segment_ids, label_ids = get_features_test() num_train_samples = len(input_ids) indexs = np.arange(num_train_samples) num_batchs = int((num_train_samples - 1) / hp.batch_size) + 1 # 800 / 64 = 13 print('Number of batch:', num_batchs) # Set up the graph saver = tf.train.Saver(max_to_keep=hp.max_to_keep) sess = tf.Session() sess.run(tf.global_variables_initializer()) # Load model saved before # MODEL_SAVE_PATH = os.path.join(pwd, hp.file_save_model)