def restore_model(self):
    """Restore the latest checkpoint saved under ``self.model_name``.

    Checkpoint artefacts are located by globbing on the absolute,
    extension-less form of ``self.model_name``: ``<base>*.meta`` for the
    graph definition and ``<base>*.data-00000-of-00001`` for the variable
    shard. When both are present and TensorFlow reports a latest
    checkpoint in the meta file's directory, the meta graph is imported
    and the variables are restored into ``self.session``.

    Whenever nothing can be restored (no meta file, no data shard, or no
    checkpoint reported), the files in ``self.logging_dir`` are deleted
    and a notice is printed instead of raising.
    """
    base_name = os.path.splitext(os.path.abspath(self.model_name))[0]

    # Glob once per pattern; newest-looking name first (reverse lexical).
    meta_files = sorted(glob.glob(base_name + '*.meta'), reverse=True)
    data_files = sorted(
        glob.glob(base_name + '*.data-00000-of-00001'), reverse=True)

    meta_file = None
    checkpoint_file = None
    checkpoint = None
    # Guard both lists: indexing an empty glob result used to raise
    # IndexError when a .meta file existed without its data shard.
    if meta_files and data_files:
        meta_file = os.path.abspath(meta_files[0])
        checkpoint_file = os.path.abspath(data_files[0])
        if os.path.exists(checkpoint_file):
            # Resolved once and reused for inspection and restore.
            checkpoint = tf.train.latest_checkpoint(
                os.path.dirname(meta_file))

    if checkpoint is None:
        # Nothing restorable: start from a clean logging directory.
        file_utils.delete_all_files_in_dir(self.logging_dir)
        print('Restoring cannot be done')
        return

    print('Restoring model from %s' % checkpoint_file)
    print('Loading: %s' % meta_file)
    saver = tf.train.import_meta_graph(meta_file)
    print('Loading: %s' % checkpoint_file)
    print('Checkpoint: ' + str(checkpoint))
    print('Tensors')
    # NOTE(review): the helper appears to print on its own; the outer
    # print is kept so the console output matches the previous behaviour.
    print(
        print_tensors_in_checkpoint_file(file_name=checkpoint,
                                         all_tensors='',
                                         tensor_name=''))
    saver.restore(self.session, checkpoint)
    print('Last epoch to restore: '
          + str(self.session.run(self.global_step)))
def create_graph(self, image_shape, num_classes):
    """Build the graph, open a session and prepare logging/checkpoints.

    Stores ``image_shape`` and ``num_classes`` on the instance, creates
    the epoch counter variable and its increment op, runs the ``make_*``
    construction steps in order, creates the variable initializer on
    ``self.device``, opens the session selected by ``self.session_type``
    and runs the initializer. Summary writing, checkpoint saving and
    model restoring are then set up according to ``self.logging``,
    ``self.save_checkpoint`` and ``self.restore``.

    Args:
        image_shape: stored as ``self.image_shape``.
        num_classes: stored as ``self.num_classes``.

    Returns:
        ``True`` once setup has completed.
    """
    started_at = time.time()

    self.image_shape = image_shape
    self.num_classes = num_classes
    self.num_layers = len(self.config.keys()) - 1

    # Epoch counter plus the op that bumps it by one.
    self.global_step = tf.Variable(
        0,
        name='last_successful_epoch',
        trainable=False,
        dtype=tf.int32,
    )
    self.last_epoch = tf.assign(
        self.global_step,
        self.global_step + 1,
        name='assign_updated_epoch',
    )

    # Steps 1-5: placeholders, parameters, predictions, optimization,
    # accuracies -- the order is kept exactly as before.
    self.make_placeholders_for_inputs()
    self.make_parameters()
    self.make_predictions()
    self.make_optimization()
    self.make_accuracy()

    # Step 6: initializer for every variable created above.
    with tf.device(self.device):
        self.init_var = tf.global_variables_initializer()

    # Step 7: session. GPU memory options (gpu_options.allow_growth,
    # per_process_gpu_memory_fraction) are not set here.
    session_config = tf.ConfigProto(
        log_device_placement=True,
        allow_soft_placement=True,
    )
    if self.session_type == 'default':
        self.session = tf.Session(config=session_config)
    elif self.session_type == 'interactive':
        self.session = tf.InteractiveSession(config=session_config)
    print('Session: ' + str(self.session))
    self.session.run(self.init_var)

    # Step 8: summaries and checkpoint saver.
    if self.logging is True:
        file_utils.mkdir_p(self.logging_dir)
        self.merged_summary_op = tf.summary.merge_all()
        if self.restore is False:
            # A fresh run starts with an empty log directory.
            file_utils.delete_all_files_in_dir(self.logging_dir)
        self.summary_writer = tf.summary.FileWriter(
            self.logging_dir, graph=self.session.graph)
    if self.save_checkpoint is True:
        self.model = tf.train.Saver(max_to_keep=1)

    # Step 9: optionally restore, then report the stored epoch counter.
    if self.restore is True:
        self.restore_model()
    trained_epochs = self.session.run(self.global_step)
    print('Model has been trained for %d iterations' % trained_epochs)

    finished_at = time.time()
    print('Tensorflow graph created in %.4f seconds'
          % (finished_at - started_at))
    return True
def start_session(self, graph):
    """Open a session on ``graph`` and set up logging, saving and restore.

    Creates the epoch counter variable, its increment op and the variable
    initializer inside ``graph``, opens the session selected by
    ``self.session_type`` bound to ``graph`` and runs the initializer.
    Summary writing, checkpoint saving and model restoring are then set
    up according to ``self.logging``, ``self.save_model`` and
    ``self.restore``.

    Everything that implicitly uses TensorFlow's current default graph
    (summary merging, ``Saver`` construction, and the meta-graph import
    done by ``restore_model``) is executed inside ``graph.as_default()``
    so that it targets the same graph the session runs.

    Args:
        graph: the graph the session is created for; must provide
            ``as_default()`` (e.g. a ``tf.Graph``).

    Returns:
        ``True`` once setup has completed.
    """
    session_config = tf.ConfigProto(
        log_device_placement=True,
        allow_soft_placement=True,
    )

    with graph.as_default():
        # Step 1: epoch counter, its increment op and the initializer.
        self.global_step = tf.Variable(0, name='last_successful_epoch',
                                       trainable=False, dtype=tf.int32)
        self.last_epoch = tf.assign(self.global_step,
                                    self.global_step + 1,
                                    name='assign_updated_epoch')
        with tf.device(self.device):
            self.init_var = tf.global_variables_initializer()

        # Step 2: session bound explicitly to ``graph``.
        if self.session_type == 'default':
            self.session = tf.Session(config=session_config, graph=graph)
        elif self.session_type == 'interactive':
            self.session = tf.InteractiveSession(config=session_config,
                                                 graph=graph)
        print('Session: ' + str(self.session))
        self.session.run(self.init_var)

        # Step 3: summaries and checkpoint saver.
        if self.logging is True:
            file_utils.mkdir_p(self.logging_dir)
            self.merged_summary_op = tf.summary.merge_all()
            if self.restore is False:
                # A fresh run starts with an empty log directory.
                file_utils.delete_all_files_in_dir(self.logging_dir)
            self.summary_writer = tf.summary.FileWriter(self.logging_dir,
                                                        graph=graph)
        if self.save_model is True:
            self.model = tf.train.Saver(max_to_keep=2)

        # Step 4: optionally restore, then report the stored epoch counter.
        if self.restore is True:
            self.restore_model()
        epoch = self.session.run(self.global_step)
        print('Model has been trained for %d iterations' % epoch)
    return True
def fit(self, data, labels, classes,
        test_data=None, test_labels=None, test_classes=None):
    """Open a session, set up logging/saving and run ``self.optimize``.

    A session of the kind named by ``self.session_type`` is created and
    ``self.init_var`` is run in it. If ``self.tensorboard_logs`` is
    ``True``, summary writers are created under
    ``self.tensorboard_log_dir`` (a single writer, or separate
    ``train`` / ``validate`` / ``test`` writers when
    ``self.separate_writer`` is not ``False``). If ``self.save_model`` is
    ``True`` a ``tf.train.Saver`` is stored on ``self.model``.

    When ``self.train_validate_split`` is not ``None`` the inputs are
    split with ``train_test_split`` (passed as ``train_size``) and the
    validation part is forwarded to ``self.optimize``; otherwise the
    inputs are used as-is. Test inputs are forwarded unless
    ``self.test_log`` is ``False``.

    Args:
        data: training inputs.
        labels: training labels.
        classes: training classes.
        test_data: optional test inputs, forwarded to ``optimize``.
        test_labels: optional test labels, forwarded to ``optimize``.
        test_classes: optional test classes, forwarded to ``optimize``.
    """
    # Both device branches used the very same session configuration, so
    # only the console message depends on the device.
    if self.device == 'cpu':
        print('Using CPU')
    else:
        print('Using GPU')
    config = tf.ConfigProto(
        log_device_placement=True,
        allow_soft_placement=True,
    )

    if self.session_type == 'default':
        self.session = tf.Session(config=config)
    elif self.session_type == 'interactive':
        self.session = tf.InteractiveSession(config=config)
    print('Session: ' + str(self.session))
    self.session.run(self.init_var)

    if self.tensorboard_logs is True:
        log_dir = self.tensorboard_log_dir
        file_utils.mkdir_p(log_dir)
        self.merged_summary_op = tf.summary.merge_all()
        if self.restore is False:
            # A fresh run starts with an empty log directory.
            file_utils.delete_all_files_in_dir(log_dir)
        session_graph = self.session.graph
        if self.separate_writer is False:
            self.summary_writer = tf.summary.FileWriter(
                log_dir, graph=session_graph)
        else:
            # os.path.join keeps the sub-directories inside the log dir
            # even when it is configured without a trailing separator
            # (plain concatenation produced e.g. "logstrain").
            self.train_writer = tf.summary.FileWriter(
                os.path.join(log_dir, 'train'), graph=session_graph)
            if self.train_validate_split is not None:
                self.validate_writer = tf.summary.FileWriter(
                    os.path.join(log_dir, 'validate'),
                    graph=session_graph)
            if self.test_log is True:
                self.test_writer = tf.summary.FileWriter(
                    os.path.join(log_dir, 'test'), graph=session_graph)

    if self.save_model is True:
        self.model = tf.train.Saver(max_to_keep=5)

    # Collect the optional keyword arguments once instead of spelling
    # out four near-identical optimize() calls.
    optimize_kwargs = {}
    if self.train_validate_split is not None:
        (train_data, validate_data,
         train_labels, validate_labels,
         train_classes, validate_classes) = train_test_split(
            data, labels, classes,
            train_size=self.train_validate_split)
        if self.verbose is True:
            print('Data shape: ' + str(data.shape))
            print('Labels shape: ' + str(labels.shape))
            print('Classes shape: ' + str(classes.shape))
            print('Train Data shape: ' + str(train_data.shape))
            print('Train Labels shape: ' + str(train_labels.shape))
            print('Train Classes shape: ' + str(train_classes.shape))
            print('Validate Data shape: ' + str(validate_data.shape))
            print('Validate Labels shape: ' + str(validate_labels.shape))
            print('Validate Classes shape: '
                  + str(validate_classes.shape))
        optimize_kwargs.update(
            validate_data=validate_data,
            validate_labels=validate_labels,
            validate_classes=validate_classes,
        )
    else:
        train_data, train_labels, train_classes = data, labels, classes

    # Mirrors the original "is False / else" test: anything other than
    # the literal False forwards the test inputs.
    if self.test_log is not False:
        optimize_kwargs.update(
            test_data=test_data,
            test_labels=test_labels,
            test_classes=test_classes,
        )

    self.optimize(train_data, train_labels, train_classes,
                  **optimize_kwargs)