def setup_tensorflow():
    """Create the model and a TensorFlow session, then try to restore weights.

    The checkpoint path is read from ``config.tf_checkpoint_file``. If reading
    that attribute fails, the path is taken from
    ``config.load('last_tf_model')`` instead.

    Restoring is best-effort: when ``saver.restore`` fails, an error message
    is printed and the session is returned anyway (without restored weights).

    Returns:
        tuple: ``(sess, net_model)`` -- the ``tf.Session`` and the ``NNModel``
        instance created by this function.
    """
    # The model is instantiated before the Saver is created (presumably so
    # that its variables already exist when the Saver is built).
    net_model = NNModel()
    tf_config = tf.ConfigProto(device_count={'GPU': config.should_use_gpu})
    sess = tf.Session(config=tf_config)

    # Add ops to save and restore all of the variables
    saver = tf.train.Saver()

    # Pick the model checkpoint file: explicit config entry first, cached
    # "last trained model" as the fallback.
    try:
        tmp_file = config.tf_checkpoint_file
        print("Loading model from config: {}".format(tmp_file))
    except Exception:
        # `except Exception` (not a bare except) so Ctrl-C / SystemExit
        # still propagate.
        tmp_file = config.load('last_tf_model')
        print("loading latest trained model: " + str(tmp_file))

    # Try to restore a session. Failure is deliberately non-fatal: the caller
    # still receives a usable (but un-restored) session.
    try:
        saver.restore(sess, tmp_file)
    except Exception:
        print("Error restoring TF model: {}".format(tmp_file))

    return sess, net_model
import sys, os import matplotlib import numpy as np from matplotlib.pylab import * matplotlib.use('Agg') sys.path.append('..') import tensorflow as tf import config from NeuralNet.convnetshared1 import NNModel from NeuralNet.data_model import TrainingData if __name__ == '__main__': net_model = NNModel() tf_config = tf.ConfigProto(device_count={'GPU': config.should_use_gpu}) sess = tf.Session(config=tf_config) # Add ops to save and restore all of the variables saver = tf.train.Saver() # Load the model checkpoint file try: tmp_file = config.tf_checkpoint_file print("Loading model from config: {}".format(tmp_file)) except: tmp_file = config.load( 'last_tf_model') #gets the cached last tf trained model print "loading latest trained model: " + str(tmp_file) # print("CAN'T FIND THE GOOD MODEL") # sys.exit(-1)