示例#1
0
 def create_graph(self, image_shape, num_classes):
     """Build the TensorFlow graph, start a session and initialise it.

     The graph receives a non-trainable counter ``self.global_step`` and the
     op ``self.last_epoch`` that increments it by one, followed by the input
     placeholders, parameters, predictions, optimizer and accuracy ops. All
     global variables are then initialised in a new session, and summary
     logging, checkpoint saving and model restoration are set up according
     to ``self.logging``, ``self.save_checkpoint`` and ``self.restore``.

     :param image_shape: input image shape, stored as ``self.image_shape``.
     :param num_classes: number of output classes, stored as
         ``self.num_classes``.
     :return: True once the graph and session are ready.
     :raises ValueError: if ``self.session_type`` is neither 'default' nor
         'interactive'.
     """
     # Fail fast on a bad session type: otherwise self.session is never
     # assigned and the error only shows up later (or a stale session from
     # an earlier call is silently reused).
     if self.session_type not in ('default', 'interactive'):
         raise ValueError("session_type must be 'default' or 'interactive', "
                          "got %r" % (self.session_type,))
     start = time.time()
     self.image_shape = image_shape
     self.num_classes = num_classes
     self.num_layers = len(self.config.keys()) - 1
     # Counter read back after a restore; incremented by running last_epoch.
     self.global_step = tf.Variable(0,
                                    name='last_successful_epoch',
                                    trainable=False,
                                    dtype=tf.int32)
     self.last_epoch = tf.assign(self.global_step,
                                 self.global_step + 1,
                                 name='assign_updated_epoch')
     # Step 1: Creating placeholders for inputs
     self.make_placeholders_for_inputs()
     # Step 2: Creating initial parameters for the variables
     self.make_parameters()
     # Step 3: Make predictions for the data
     self.make_predictions()
     # Step 4: Perform optimization operation
     self.make_optimization()
     # Step 5: Calculate accuracies
     self.make_accuracy()
     # Step 6: Initialize all the required variables
     with tf.device(self.device):
         self.init_var = tf.global_variables_initializer()
     # Step 7: Initiate Session
     config = tf.ConfigProto(
         log_device_placement=True,
         allow_soft_placement=True,
     )
     if self.session_type == 'default':
         self.session = tf.Session(config=config)
     else:
         self.session = tf.InteractiveSession(config=config)
     print('Session: ' + str(self.session))
     self.session.run(self.init_var)
     # Step 8: Initiate logs (the checkpoint Saver is only created here)
     if self.logging is True:
         file_utils.mkdir_p(self.logging_dir)
         self.merged_summary_op = tf.summary.merge_all()
         if self.restore is False:
             file_utils.delete_all_files_in_dir(self.logging_dir)
         self.summary_writer = \
             tf.summary.FileWriter(self.logging_dir, graph=self.session.graph)
         if self.save_checkpoint is True:
             self.model = tf.train.Saver(max_to_keep=1)
     # Step 9: Restore model
     if self.restore is True:
         self.restore_model()
     epoch = self.session.run(self.global_step)
     print('Model has been trained for %d iterations' % epoch)
     end = time.time()
     print('Tensorflow graph created in %.4f seconds' % (end - start))
     return True
示例#2
0
 def start_session(self, graph):
     """Add the epoch counter to ``graph``, open a session on it and initialise it.

     Creates ``self.global_step`` / ``self.last_epoch`` and the variable
     initializer inside ``graph``, opens a session bound to ``graph``, runs
     the initializer and, when ``self.logging`` is True, sets up summary
     logging and (if ``self.save_model``) a checkpoint Saver. Finally
     restores the model when ``self.restore`` is True.

     :param graph: the ``tf.Graph`` holding the model.
     :return: True once the session is ready.
     :raises ValueError: if ``self.session_type`` is neither 'default' nor
         'interactive'.
     """
     if self.session_type not in ('default', 'interactive'):
         raise ValueError("session_type must be 'default' or 'interactive', "
                          "got %r" % (self.session_type,))
     # Step 1: Initialize all the required variables
     with graph.as_default():
         self.global_step = tf.Variable(0,
                                        name='last_successful_epoch',
                                        trainable=False,
                                        dtype=tf.int32)
         self.last_epoch = tf.assign(self.global_step,
                                     self.global_step + 1,
                                     name='assign_updated_epoch')
         with tf.device(self.device):
             self.init_var = tf.global_variables_initializer()
     # Step 2: Initiate Session
     config = tf.ConfigProto(
         log_device_placement=True,
         allow_soft_placement=True,
     )
     if self.session_type == 'default':
         self.session = tf.Session(config=config, graph=graph)
     else:
         self.session = tf.InteractiveSession(config=config, graph=graph)
     print('Session: ' + str(self.session))
     self.session.run(self.init_var)
     # Step 3: Initiate logs
     if self.logging is True:
         file_utils.mkdir_p(self.logging_dir)
         # tf.summary.merge_all() and tf.train.Saver() operate on the
         # *default* graph; enter `graph` so they collect this model's
         # summaries/variables rather than whatever graph is default.
         with graph.as_default():
             self.merged_summary_op = tf.summary.merge_all()
             if self.restore is False:
                 file_utils.delete_all_files_in_dir(self.logging_dir)
             self.summary_writer = \
                 tf.summary.FileWriter(self.logging_dir, graph=graph)
             if self.save_model is True:
                 self.model = tf.train.Saver(max_to_keep=2)
     # Step 4: Restore model
     if self.restore is True:
         self.restore_model()
     epoch = self.session.run(self.global_step)
     print('Model has been trained for %d iterations' % epoch)
     return True
示例#3
0
 def download_and_extract_data(self, data_directory):
     """Download, verify and extract the CIFAR-100 archive.

     The tarball is stored as ``cifar-100.tar.gz`` inside ``data_directory``
     and unpacked into ``<data_directory>/cifar-100-batches``. An existing
     tarball is reused when its MD5 matches ``self.file_md5``. An existing
     batches directory is reused when it holds the expected number of files.

     :param data_directory: directory for the tarball and the extracted
         data; created if missing.
     :return: True on success; False if the download, the MD5 check or the
         extraction fails.
     :raises FileNotFoundError: when ``self.verbose`` is True and either the
         download or its MD5 check fails (non-verbose mode returns False).
     """
     print('Downloading and extracting CIFAR 100 file')
     # Step 1: Make sure data_directory exists.
     if not os.path.exists(data_directory):
         if self.verbose is True:
             print('Creating the directory \'%s\'' % data_directory)
         file_utils.mkdir_p(data_directory)
     elif self.verbose is True:
         print('Directory \'%s\' already exists' % data_directory)
     # Step 2: Keep an existing tarball only if its MD5 sum matches.
     # os.path.join also works when data_directory has no trailing slash;
     # plain concatenation would produce e.g. 'dircifar-100.tar.gz'.
     tar_file = os.path.join(data_directory, 'cifar-100.tar.gz')
     make_tar = False
     if not os.path.exists(tar_file):
         make_tar = True
     elif not file_utils.verify_md5(tar_file, self.file_md5):
         if self.verbose is True:
             print('Removing the wrong file \'%s\'' % tar_file)
         os.remove(tar_file)
         make_tar = True
     elif self.verbose is True:
         print('CIFAR 100 tarfile exists and MD5 sum is verified')
     # Step 3: Download from self.file_url and verify the MD5 sum.
     if make_tar is True:
         result = file_utils.download(self.file_url,
                                      tar_file,
                                      verbose=self.verbose)
         if result is False:
             # NOTE(review): failures raise only in verbose mode and return
             # False otherwise; kept as-is for caller compatibility.
             if self.verbose is True:
                 raise FileNotFoundError(
                     'Download of CIFAR 100 dataset failed')
             return False
         result = file_utils.verify_md5(tar_file,
                                        self.file_md5,
                                        verbose=self.verbose)
         if result is False:
             if self.verbose is True:
                 raise FileNotFoundError(
                     'Downloaded CIFAR 100 dataset failed md5sum check')
             return False
     # Step 4: Extract unless a complete batches directory already exists.
     make_extract = False
     batches_directory = os.path.join(data_directory, 'cifar-100-batches')
     if not os.path.exists(batches_directory):
         make_extract = True
     else:
         num_files = sum(
             os.path.isfile(os.path.join(batches_directory, f))
             for f in os.listdir(batches_directory))
         # NOTE(review): 8 matches the CIFAR-10 python archive layout; the
         # CIFAR-100 python archive is believed to hold fewer files (meta,
         # train, test, file.txt~). If so, this check always fails and the
         # data is re-extracted on every call — confirm before changing.
         if num_files != 8:
             shutil.rmtree(batches_directory)
             make_extract = True
         elif self.verbose is True:
             print('Directory %s already exists' % batches_directory)
     if make_extract is True:
         print('Extracting file %s to %s' % (tar_file, batches_directory))
         result = file_utils.extract(tar_file)
         # Check the result before moving anything: after a failed
         # extraction './cifar-100-python' may not exist, and shutil.move
         # would raise instead of this method returning False.
         if result is False:
             if self.verbose is True:
                 print('Extraction of CIFAR 100 dataset failed')
             return False
         # Presumably file_utils.extract unpacks into the current working
         # directory — TODO confirm.
         shutil.move('./cifar-100-python', batches_directory)
         if self.verbose is True:
             print('Extraction of CIFAR 100 dataset success')
     return True
 def run(self,
         train_data,
         train_labels,
         train_classes,
         validate_data=None,
         validate_labels=None,
         validate_classes=None,
         test_data=None,
         test_labels=None,
         test_classes=None):
     """Train in mini-batches until ``self.max_iterations`` epochs or convergence.

     Each epoch:

     * runs the optimizer over every mini-batch, including a final partial
       batch;
     * increments the epoch counter ``self.global_step`` once;
     * evaluates validation accuracy (when ``self.train_validate_split`` is
       set) and test accuracy (when ``self.test_log`` is True);
     * writes the summaries and prints progress every ``self.display_step``
       epochs;
     * saves a checkpoint when ``self.save_model`` is True.

     Training stops early once the training cost of two consecutive epochs
     differs by less than ``self.tolerance``.

     :param train_data: training inputs, sliced as ``[a:b, :]``.
     :param train_labels: training labels, sliced as ``[a:b, :]``.
     :param train_classes: training classes, sliced as ``[a:b]``.
     :param validate_data/validate_labels/validate_classes: validation set,
         used only when ``self.train_validate_split`` is not None.
     :param test_data/test_labels/test_classes: test set, used only when
         ``self.test_log`` is True.
     :raises ValueError: if ``train_data`` has no rows.
     """
     validating = self.train_validate_split is not None
     testing = self.test_log is True
     if validating:
         feed_dict_validate = {
             self.x: validate_data,
             self.y_true: validate_labels,
             self.y_true_cls: validate_classes
         }
     if testing:
         feed_dict_test = {
             self.x: test_data,
             self.y_true: test_labels,
             self.y_true_cls: test_classes
         }
     num_samples = train_data.shape[0]
     if num_samples == 0:
         raise ValueError('train_data must contain at least one sample')
     epoch = self.session.run(self.global_step)
     print('Last successful epoch: ' + str(epoch))
     width = len(str(self.max_iterations))
     # Fetch the summaries in the same step as the optimizer rather than in
     # two extra forward passes per batch; they then match the printed cost.
     train_fetches = [self.optimizer, self.loss, self.train_accuracy,
                      self.train_loss_summary, self.train_acc_summary]
     converged = False
     prev_cost = None
     start = time.time()
     # `<` rather than `!=`: a restored counter past max_iterations must not
     # loop forever.
     while epoch < self.max_iterations and not converged:
         # range() yields ceil(n / batch_size) batches; slicing past the end
         # is clamped, so the last (possibly partial) batch is trained too.
         for batch_start in range(0, num_samples, self.batch_size):
             batch_end = batch_start + self.batch_size
             feed_dict_train = {
                 self.x: train_data[batch_start:batch_end, :],
                 self.y_true: train_labels[batch_start:batch_end, :],
                 self.y_true_cls: train_classes[batch_start:batch_end]
             }
             _, cost, train_acc, train_loss_summary, train_acc_summary = \
                 self.session.run(train_fetches, feed_dict=feed_dict_train)
         # Increment once per epoch (not per batch): the saved global_step is
         # what a restore resumes from.
         self.session.run(self.last_epoch)
         if validating:
             validate_acc, validate_summary = self.session.run(
                 [self.validate_accuracy, self.validate_acc_summary],
                 feed_dict=feed_dict_validate)
         if testing:
             test_acc, test_summary = self.session.run(
                 [self.test_accuracy, self.test_acc_summary],
                 feed_dict=feed_dict_test)
         shared_writer = self.separate_writer is False
         train_writer = self.summary_writer if shared_writer else self.train_writer
         train_writer.add_summary(train_loss_summary, epoch)
         train_writer.add_summary(train_acc_summary, epoch)
         # Validate/test summaries only exist when those sets are in use.
         if validating:
             validate_writer = (self.summary_writer if shared_writer
                                else self.validate_writer)
             validate_writer.add_summary(validate_summary, epoch)
         if testing:
             test_writer = self.summary_writer if shared_writer else self.test_writer
             test_writer.add_summary(test_summary, epoch)
         if epoch % self.display_step == 0:
             duration = time.time() - start
             parts = ['>>> Epoch [%*d/%*d]' % (width, epoch, width,
                                               self.max_iterations),
                      'Error: %.4f' % cost,
                      'Train Acc.: %.4f' % train_acc]
             if validating:
                 parts.append('Validate Acc.: %.4f' % validate_acc)
             if testing:
                 parts.append('Test Acc.: %.4f' % test_acc)
             if validating or testing:
                 parts.append('Duration: %.4f seconds' % duration)
             else:
                 parts.append('Duration of run: %.4f seconds' % duration)
             print(' | '.join(parts))
         start = time.time()
         if self.save_model is True:
             model_directory = os.path.dirname(self.model_name)
             file_utils.mkdir_p(model_directory)
             self.model.save(self.session,
                             self.model_name,
                             global_step=epoch)
         # Converged when the cost barely moved since the previous epoch.
         if prev_cost is not None and math.fabs(cost - prev_cost) < self.tolerance:
             converged = True
         prev_cost = cost
         epoch += 1
 def fit(self,
         data,
         labels,
         classes,
         test_data=None,
         test_labels=None,
         test_classes=None):
     """Start a session, configure TensorBoard logging and train the model.

     When ``self.train_validate_split`` is set, ``data``/``labels``/``classes``
     are split with ``train_test_split`` and the held-out part is passed to
     ``self.optimize`` as validation data. The test set is forwarded to
     ``self.optimize`` unless ``self.test_log`` is False.

     :param data: full training inputs.
     :param labels: labels matching ``data``.
     :param classes: classes matching ``data``.
     :param test_data/test_labels/test_classes: optional test set.
     :raises ValueError: if ``self.session_type`` is neither 'default' nor
         'interactive'.
     """
     if self.session_type not in ('default', 'interactive'):
         raise ValueError("session_type must be 'default' or 'interactive', "
                          "got %r" % (self.session_type,))
     # The CPU and GPU paths built identical configs; only the message differs.
     print('Using CPU' if self.device == 'cpu' else 'Using GPU')
     config = tf.ConfigProto(
         log_device_placement=True,
         allow_soft_placement=True,
     )
     if self.session_type == 'default':
         self.session = tf.Session(config=config)
     else:
         self.session = tf.InteractiveSession(config=config)
     print('Session: ' + str(self.session))
     self.session.run(self.init_var)
     if self.tensorboard_logs is True:
         file_utils.mkdir_p(self.tensorboard_log_dir)
         self.merged_summary_op = tf.summary.merge_all()
         if self.restore is False:
             file_utils.delete_all_files_in_dir(self.tensorboard_log_dir)
         if self.separate_writer is False:
             self.summary_writer = tf.summary.FileWriter(
                 self.tensorboard_log_dir, graph=self.session.graph)
         else:
             self.train_writer = tf.summary.FileWriter(
                 self.tensorboard_log_dir + 'train',
                 graph=self.session.graph)
             if self.train_validate_split is not None:
                 self.validate_writer = tf.summary.FileWriter(
                     self.tensorboard_log_dir + 'validate',
                     graph=self.session.graph)
             if self.test_log is True:
                 self.test_writer = tf.summary.FileWriter(
                     self.tensorboard_log_dir + 'test',
                     graph=self.session.graph)
         if self.save_model is True:
             self.model = tf.train.Saver(max_to_keep=5)
     # Collect the optional keyword arguments once instead of spelling out
     # four near-identical optimize() calls.
     optimize_kwargs = {}
     if self.train_validate_split is not None:
         (train_data, validate_data, train_labels, validate_labels,
          train_classes, validate_classes) = train_test_split(
              data, labels, classes, train_size=self.train_validate_split)
         if self.verbose is True:
             for title, array in (('Data', data),
                                  ('Labels', labels),
                                  ('Classes', classes),
                                  ('Train Data', train_data),
                                  ('Train Labels', train_labels),
                                  ('Train Classes', train_classes),
                                  ('Validate Data', validate_data),
                                  ('Validate Labels', validate_labels),
                                  ('Validate Classes', validate_classes)):
                 print(title + ' shape: ' + str(array.shape))
         optimize_kwargs.update(validate_data=validate_data,
                                validate_labels=validate_labels,
                                validate_classes=validate_classes)
     else:
         train_data, train_labels, train_classes = data, labels, classes
     if self.test_log is not False:
         optimize_kwargs.update(test_data=test_data,
                                test_labels=test_labels,
                                test_classes=test_classes)
     self.optimize(train_data, train_labels, train_classes, **optimize_kwargs)
示例#6
0
 def learn(self,
           train_data,
           train_labels,
           train_classes,
           validate_data=None,
           validate_labels=None,
           validate_classes=None,
           test_data=None,
           test_labels=None,
           test_classes=None):
     """Train for up to ``self.max_iterations`` epochs while tracking metric history.

     The per-epoch histories stored in the graph variables
     ``self.params['var_*']`` are first read back into ``self.out_params``,
     so a restored model continues its history. Each epoch then:

     * runs the optimizer over mini-batches of the training data, and over
       ``self.make_data_augmentation(...)`` output when
       ``self.augmentation`` is set;
     * increments the epoch counter ``self.global_step`` exactly once;
     * evaluates validation (and, if ``self.test_log`` is True, test)
       metrics, appends them to the histories and writes the summaries;
     * prints progress every ``self.display_step`` epochs and saves a
       checkpoint when ``self.save_checkpoint`` is True.

     Training stops early when the training loss of two consecutive epochs
     differs by less than ``self.err_tolerance``. When ``self.save_model``
     is True, the graph is frozen into ``self.logging_dir`` at the end.

     :param train_data: training inputs, sliced as ``[a:b, :]``.
     :param train_labels: one-hot labels, sliced as ``[a:b, :]``.
     :param train_classes: true classes, sliced as ``[a:b]``.
     :param validate_data/validate_labels/validate_classes: validation set,
         fed every epoch (needed despite the None defaults).
     :param test_data/test_labels/test_classes: test set, used only when
         ``self.test_log`` is True.
     :raises ValueError: if ``train_data`` has no rows.
     """
     # Separate from the per-epoch `start` so the final report covers the
     # whole run rather than only the last epoch.
     fit_start = time.time()
     if train_data.shape[0] == 0:
         raise ValueError('train_data must contain at least one sample')
     testing = self.test_log is True
     feed_dict_validate = {
         self.predict_params['input']: validate_data,
         self.predict_params['true_one_hot']: validate_labels,
         self.predict_params['true_class']: validate_classes
     }
     epoch = self.session.run(self.global_step)
     self.out_params['list_train_loss'] = self.session.run(
         self.params['var_train_loss']).tolist()
     self.out_params['list_train_acc'] = self.session.run(
         self.params['var_train_acc']).tolist()
     print('Length of train loss          : %d' %
           len(self.out_params['list_train_loss']))
     print('Length of train accuracy      : %d' %
           len(self.out_params['list_train_acc']))
     self.out_params['list_val_loss'] = self.session.run(
         self.params['var_val_loss']).tolist()
     self.out_params['list_val_acc'] = self.session.run(
         self.params['var_val_acc']).tolist()
     print('Length of validate loss       : %d' %
           len(self.out_params['list_val_loss']))
     print('Length of validate accuracy   : %d' %
           len(self.out_params['list_val_acc']))
     # NOTE(review): list_learn_rate is not reloaded from var_learn_rate like
     # the other histories, so after a restore it only holds the new epochs.
     if testing:
         self.out_params['list_test_acc'] = \
             self.session.run(self.params['var_test_acc']).tolist()
         print('Length of test accuracy       : %d' %
               len(self.out_params['list_test_acc']))
     print('Restoring training from epoch :', epoch)
     width = len(str(self.max_iterations))
     converged = False
     prev_cost = None
     # `<` rather than `!=`: a restored counter past max_iterations must not
     # loop forever.
     while epoch < self.max_iterations and not converged:
         start = time.time()
         batch_results = self._learn_train_batches(
             train_data, train_labels, train_classes, 'original')
         if self.augmentation is not None:
             aug_data, aug_labels, aug_classes = \
                 self.make_data_augmentation(train_data, train_labels, train_classes)
             aug_results = self._learn_train_batches(
                 aug_data, aug_labels, aug_classes, 'augmented')
             # The last batch trained (augmented if any) provides the metrics.
             if aug_results is not None:
                 batch_results = aug_results
         _, train_loss, train_loss_summary, train_acc, train_acc_summary = \
             batch_results
         # Increment once per epoch (not per batch): the saved global_step is
         # what a restore resumes from.
         self.session.run(self.last_epoch)
         val_loss, val_loss_summary, val_acc, val_acc_summary = \
             self.session.run([self.model_params['val_loss'], self.summary_params['val_loss'],
                               self.model_params['val_acc'], self.summary_params['val_acc']],
                              feed_dict=feed_dict_validate)
         learn_rate_summary = self.session.run(
             self.summary_params['learn_rate'])
         self._learn_record_history('train_loss', train_loss)
         self._learn_record_history('train_acc', train_acc)
         self._learn_record_history('learn_rate', self.current_learn_rate)
         self._learn_record_history('val_loss', val_loss)
         self._learn_record_history('val_acc', val_acc)
         self.summary_writer.add_summary(train_loss_summary, epoch)
         self.summary_writer.add_summary(train_acc_summary, epoch)
         self.summary_writer.add_summary(val_loss_summary, epoch)
         self.summary_writer.add_summary(val_acc_summary, epoch)
         self.summary_writer.add_summary(learn_rate_summary, epoch)
         if testing:
             feed_dict_test = {
                 self.predict_params['input']: test_data,
                 self.predict_params['true_one_hot']: test_labels,
                 self.predict_params['true_class']: test_classes
             }
             test_acc, test_acc_summary = \
                 self.session.run([self.model_params['test_acc'],
                                   self.summary_params['test_acc']], feed_dict=feed_dict_test)
             self._learn_record_history('test_acc', test_acc)
             self.summary_writer.add_summary(test_acc_summary, epoch)
         if epoch % self.display_step == 0:
             duration = time.time() - start
             print('>>> Epoch [%*d/%*d]' %
                   (width, epoch, width, self.max_iterations))
             if testing:
                 print(
                     'train_loss: %.4f | train_acc: %.4f | val_loss: %.4f | val_acc: %.4f | '
                     'test_acc: %.4f | Time: %.4f s' %
                     (train_loss, train_acc, val_loss, val_acc, test_acc,
                      duration))
             else:
                 print(
                     'train_loss: %.4f | train_acc: %.4f | val_loss: %.4f | val_acc: %.4f | '
                     'Time: %.4f s' %
                     (train_loss, train_acc, val_loss, val_acc, duration))
         if self.save_checkpoint is True:
             model_directory = os.path.dirname(self.checkpoint_filename)
             file_utils.mkdir_p(model_directory)
             self.model.save(self.session,
                             self.checkpoint_filename,
                             global_step=epoch)
         # Converged when the loss barely moved since the previous epoch.
         if prev_cost is not None and \
                 math.fabs(train_loss - prev_cost) < self.err_tolerance:
             converged = True
         prev_cost = train_loss
         epoch += 1
     end = time.time()
     print('Fit completed in %.4f seconds' % (end - fit_start))
     if self.save_model is True:
         print('Saving the graph to %s' %
               (self.logging_dir + self.model_name.split('/')[-1]))
         self.freeze_graph(self.logging_dir)

 def _learn_train_batches(self, data, labels, classes, description):
     """Run one optimizer pass over ``data`` in mini-batches of ``self.batch_size``.

     :param data: inputs, sliced as ``[a:b, :]``; the last batch may be partial.
     :param labels: one-hot labels, sliced as ``[a:b, :]``.
     :param classes: true classes, sliced as ``[a:b]``.
     :param description: word used in the progress message
         ('original' or 'augmented').
     :return: the values fetched for the last batch, as
         ``[optimizer, train_loss, train_loss_summary, train_acc,
         train_acc_summary]``, or None if ``data`` has no rows.
     """
     num_samples = data.shape[0]
     num_batches = int(math.ceil(num_samples / self.batch_size))
     print('Training using %s train data using batch size of %d '
           'and total batches of %d' % (description, self.batch_size, num_batches))
     fetches = [self.params['optimizer'],
                self.model_params['train_loss'], self.summary_params['train_loss'],
                self.model_params['train_acc'], self.summary_params['train_acc']]
     last_results = None
     # Slicing past the end is clamped, which covers the final partial batch.
     for batch_start in range(0, num_samples, self.batch_size):
         batch_end = batch_start + self.batch_size
         feed_dict_train = {
             self.predict_params['input']: data[batch_start:batch_end, :],
             self.predict_params['true_one_hot']: labels[batch_start:batch_end, :],
             self.predict_params['true_class']: classes[batch_start:batch_end]
         }
         last_results = self.session.run(fetches, feed_dict=feed_dict_train)
     return last_results

 def _learn_record_history(self, name, value):
     """Append ``value`` to ``out_params['list_<name>']`` and store the full list in the graph.

     The list is assigned to the variable ``params['var_<name>']``, and the
     assigned value is kept in ``params['update_<name>']``. The
     placeholder/assign op pair is built once per variable and cached, so
     repeated calls do not keep adding ops to the graph. It is executed
     with ``self.session`` rather than relying on a default session.
     """
     history = self.out_params['list_' + name]
     history.append(value)
     variable = self.params['var_' + name]
     cache = getattr(self, '_history_assign_ops', None)
     if cache is None:
         cache = {}
         self._history_assign_ops = cache
     entry = cache.get(name)
     # Rebuild if the variable object changed (e.g. a new graph was built).
     if entry is None or entry[0] is not variable:
         with variable.graph.as_default():
             value_ph = tf.placeholder(variable.dtype.base_dtype, shape=None)
             assign_op = tf.assign(variable, value_ph, validate_shape=False)
         entry = (variable, value_ph, assign_op)
         cache[name] = entry
     _, value_ph, assign_op = entry
     self.params['update_' + name] = self.session.run(
         assign_op, feed_dict={value_ph: history})
示例#7
0
 def learn(self, train_data, train_labels, train_classes,
           validate_data=None, validate_labels=None, validate_classes=None,
           test_data=None, test_labels=None, test_classes=None):
     start = time.time()
     feed_dict_train = {self.predict_params['input']: train_data,
                        self.predict_params['true_one_hot']: train_labels,
                        self.predict_params['true_class']: train_classes}
     feed_dict_validate = {self.predict_params['input']: validate_data,
                           self.predict_params['true_one_hot']: validate_labels,
                           self.predict_params['true_class']: validate_classes}
     epoch = self.session.run(self.global_step)
     self.out_params['list_train_loss'] = self.session.run(self.params['var_train_loss']).tolist()
     self.out_params['list_train_acc'] = self.session.run(self.params['var_train_acc']).tolist()
     print('Length of train loss          : %d' % len(self.out_params['list_train_loss']))
     print('Length of train accuracy      : %d' % len(self.out_params['list_train_acc']))
     self.out_params['list_val_loss'] = self.session.run(self.params['var_val_loss']).tolist()
     self.out_params['list_val_acc'] = self.session.run(self.params['var_val_acc']).tolist()
     print('Length of validate loss       : %d' % len(self.out_params['list_val_loss']))
     print('Length of validate accuracy   : %d' % len(self.out_params['list_val_acc']))
     if self.test_log is True:
         self.out_params['list_test_acc'] = \
             self.session.run(self.params['var_test_acc']).tolist()
     print('Length of test accuracy       : %d' % len(self.out_params['list_test_acc']))
     print('Restoring training from epoch :', epoch)
     converged = False
     prev_cost = 0
     while (epoch != self.max_iterations) and converged is False:
         start = time.time()
         _, train_loss, train_loss_summary, \
         train_acc, train_acc_summary, curr_epoch \
             = self.session.run([self.params['optimizer'],
                                 self.model_params['train_loss'], self.summary_params['train_loss'],
                                 self.model_params['train_acc'], self.summary_params['train_acc'],
                                 self.last_epoch],
                                feed_dict=feed_dict_train)
         val_loss, val_loss_summary, val_acc, val_acc_summary = \
             self.session.run([self.model_params['val_loss'], self.summary_params['val_loss'],
                               self.model_params['val_acc'], self.summary_params['val_acc']],
                              feed_dict=feed_dict_validate)
         learn_rate_summary = self.session.run(self.summary_params['learn_rate'])
         self.out_params['list_train_loss'].append(train_loss)
         self.params['update_train_loss'] = tf.assign(self.params['var_train_loss'],
                                                      self.out_params['list_train_loss'],
                                                      validate_shape=False).eval()
         self.out_params['list_train_acc'].append(train_acc)
         self.params['update_train_acc'] = \
             tf.assign(self.params['var_train_acc'], self.out_params['list_train_acc'],
                       validate_shape=False).eval()
         self.out_params['list_learn_rate'].append(self.current_learn_rate)
         self.params['update_learn_rate'] = \
             tf.assign(self.params['var_learn_rate'], self.out_params['list_learn_rate'],
                       validate_shape=False).eval()
         self.out_params['list_val_loss'].append(val_loss)
         self.params['update_val_loss'] = \
             tf.assign(self.params['var_val_loss'], self.out_params['list_val_loss'],
                       validate_shape=False).eval()
         self.out_params['list_val_acc'].append(val_acc)
         self.params['update_val_acc'] = \
             tf.assign(self.params['var_val_acc'], self.out_params['list_val_acc'],
                       validate_shape=False).eval()
         self.summary_writer.add_summary(train_loss_summary, epoch)
         self.summary_writer.add_summary(train_acc_summary, epoch)
         self.summary_writer.add_summary(val_loss_summary, epoch)
         self.summary_writer.add_summary(val_acc_summary, epoch)
         self.summary_writer.add_summary(learn_rate_summary, epoch)
         if self.test_log is True:
             feed_dict_test = {self.predict_params['input']: test_data,
                               self.predict_params['true_one_hot']: test_labels,
                               self.predict_params['true_class']: test_classes}
             test_acc, test_acc_summary = \
                 self.session.run([self.model_params['test_acc'],
                                   self.summary_params['test_acc']], feed_dict=feed_dict_test)
             self.out_params['list_test_acc'].append(test_acc)
             self.params['update_test_acc'] = tf.assign(self.params['var_test_acc'],
                                                        self.out_params['list_test_acc'],
                                                        validate_shape=False).eval()
             self.summary_writer.add_summary(test_acc_summary, epoch)
         if epoch % self.display_step == 0:
             duration = time.time() - start
             if self.test_log is False:
                 print('>>> Epoch [%*d/%*d]'
                       % (int(len(str(self.max_iterations))), epoch,
                          int(len(str(self.max_iterations))), self.max_iterations))
                 print('train_loss: %.4f | train_acc: %.4f | val_loss: %.4f | val_acc: %.4f | '
                       'Time: %.4f s' % (train_loss, train_acc, val_loss, val_acc, duration))
             else:
                 print('>>> Epoch [%*d/%*d]'
                       % (int(len(str(self.max_iterations))), epoch,
                          int(len(str(self.max_iterations))), self.max_iterations))
                 print('train_loss: %.4f | train_acc: %.4f | val_loss: %.4f | val_acc: %.4f | '
                       'test_acc: %.4f | Time: %.4f s'
                       % (train_loss, train_acc, val_loss, val_acc, test_acc, duration))
         if self.save_checkpoint is True:
             model_directory = os.path.dirname(self.checkpoint_filename)
             file_utils.mkdir_p(model_directory)
             self.model.save(self.session, self.checkpoint_filename, global_step=epoch)
         if epoch == 0:
             prev_cost = train_loss
         else:
             if math.fabs(train_loss - prev_cost) < self.err_tolerance:
                 converged = False
         epoch += 1
     end = time.time()
     print('Fit completed in %.4f seconds' % (end-start))
     if self.save_model is True:
         print('Saving the graph to %s' % (self.logging_dir+self.model_name.split('/')[-1]))
         self.freeze_graph(self.logging_dir)
示例#8
0
 def run(self,
         train_data,
         train_labels,
         train_classes,
         validate_data=None,
         validate_labels=None,
         validate_classes=None,
         test_data=None,
         test_labels=None,
         test_classes=None):
     feed_dict_train = {
         self.x: train_data,
         self.y_true: train_labels,
         self.y_true_cls: train_classes
     }
     if self.train_validate_split is not None:
         feed_dict_validate = {
             self.x: validate_data,
             self.y_true: validate_labels,
             self.y_true_cls: validate_classes
         }
     if self.test_log is True:
         feed_dict_test = {
             self.x: test_data,
             self.y_true: test_labels,
             self.y_true_cls: test_classes
         }
     epoch = self.session.run(self.global_step)
     self.list_train_loss = self.session.run(self.var_train_loss)
     self.list_train_loss = self.list_train_loss.tolist()
     self.list_train_acc = self.session.run(self.var_train_acc)
     self.list_train_acc = self.list_train_acc.tolist()
     print('Length of train loss          : %d' % len(self.list_train_loss))
     print('Length of train accuracy      : %d' % len(self.list_train_acc))
     if self.train_validate_split is not None:
         self.list_validate_loss = self.session.run(self.var_validate_loss)
         self.list_validate_loss = self.list_validate_loss.tolist()
         self.list_validate_acc = self.session.run(self.var_validate_acc)
         self.list_validate_acc = self.list_validate_acc.tolist()
     print('Length of validate loss       : %d' %
           len(self.list_validate_loss))
     print('Length of validate accuracy   : %d' %
           len(self.list_validate_acc))
     if self.test_log is True:
         self.list_test_acc = self.session.run(self.test_var_acc)
         self.list_test_acc = self.list_test_acc.tolist()
     print('Length of test accuracy       : %d' % len(self.list_test_acc))
     print('Restoring training from epoch :', epoch)
     converged = False
     prev_cost = 0
     num_batches = int(train_data.shape[0] / self.batch_size)
     while (epoch != self.max_iterations) and converged is False:
         start = time.time()
         start_batch_index = 0
         for batch in range(num_batches):
             # print('Training on batch %d' %batch)
             end_batch_index = start_batch_index + self.batch_size
             if end_batch_index < train_data.shape[0]:
                 train_batch_data = train_data[
                     start_batch_index:end_batch_index, :]
                 train_batch_labels = train_labels[
                     start_batch_index:end_batch_index, :]
                 train_batch_classes = train_classes[
                     start_batch_index:end_batch_index]
             else:
                 train_batch_data = train_data[start_batch_index:, :]
                 train_batch_labels = train_labels[start_batch_index:, :]
                 train_batch_classes = train_classes[start_batch_index:]
             feed_dict_train = {
                 self.x: train_batch_data,
                 self.y_true: train_batch_labels,
                 self.y_true_cls: train_batch_classes
             }
             _, train_loss, train_acc \
                 = self.session.run([self.optimizer, self.train_loss, self.train_accuracy], feed_dict=feed_dict_train)
             train_loss_summary = self.session.run(
                 self.train_loss_summary, feed_dict=feed_dict_train)
             train_acc_summary = self.session.run(self.train_acc_summary,
                                                  feed_dict=feed_dict_train)
             start_batch_index += self.batch_size
         curr_epoch = self.session.run(self.last_epoch)
         validate_loss_summary = self.session.run(
             self.validate_loss_summary, feed_dict=feed_dict_validate)
         learning_rate_summary = self.session.run(
             self.learning_rate_summary)
         self.list_train_loss.append(train_loss)
         self.update_train_loss = tf.assign(self.var_train_loss,
                                            self.list_train_loss,
                                            validate_shape=False)
         self.update_train_loss.eval()
         self.list_train_acc.append(train_acc)
         self.update_train_acc = tf.assign(self.var_train_acc,
                                           self.list_train_acc,
                                           validate_shape=False)
         self.update_train_acc.eval()
         self.list_learning_rate.append(self.current_learning_rate)
         self.update_learning_rate = tf.assign(self.var_learning_rate,
                                               self.list_learning_rate,
                                               validate_shape=False)
         self.update_learning_rate.eval()
         # w_hist = self.session.run(self.w_hist, feed_dict=feed_dict_train)
         # self.summary_writer.add_summary(w_hist, epoch)
         # w_im = self.session.run(self.w_im, feed_dict=feed_dict_train)
         # self.summary_writer.add_summary(w_im, epoch)
         if self.train_validate_split is not None:
             validate_loss, validate_acc, validate_acc_summary = \
             self.session.run([self.validate_loss, self.validate_accuracy, self.validate_acc_summary],
                              feed_dict=feed_dict_validate)
             self.list_validate_loss.append(validate_loss)
             self.update_validate_loss = tf.assign(self.var_validate_loss,
                                                   self.list_validate_loss,
                                                   validate_shape=False)
             self.update_validate_loss.eval()
             self.list_validate_acc.append(validate_acc)
             self.update_validate_acc = tf.assign(self.var_validate_acc,
                                                  self.list_validate_acc,
                                                  validate_shape=False)
             self.update_validate_acc.eval()
         if self.test_log is True:
             test_acc, test_acc_summary = \
                 self.session.run([self.test_accuracy, self.test_acc_summary], feed_dict=feed_dict_test)
             self.list_test_acc.append(test_acc)
             self.update_test_acc = tf.assign(self.test_var_acc,
                                              self.list_test_acc,
                                              validate_shape=False)
             self.update_test_acc.eval()
         if self.separate_writer is False:
             self.summary_writer.add_summary(train_loss_summary, epoch)
             self.summary_writer.add_summary(train_acc_summary, epoch)
             self.summary_writer.add_summary(validate_loss_summary, epoch)
             self.summary_writer.add_summary(validate_acc_summary, epoch)
             self.summary_writer.add_summary(test_acc_summary, epoch)
             self.summary_writer.add_summary(learning_rate_summary, epoch)
         else:
             self.train_writer.add_summary(train_loss_summary, epoch)
             self.train_writer.add_summary(train_acc_summary, epoch)
             if self.train_validate_split is not None:
                 self.validate_writer.add_summary(validate_loss_summary,
                                                  epoch)
                 self.validate_writer.add_summary(validate_acc_summary,
                                                  epoch)
             if self.test_log is True:
                 self.test_writer.add_summary(test_acc_summary, epoch)
         if epoch % self.display_step == 0:
             duration = time.time() - start
             if self.train_validate_split is not None and self.test_log is False:
                 print('>>> Epoch [%*d/%*d]' %
                       (int(len(str(self.max_iterations))), epoch,
                        int(len(str(
                            self.max_iterations))), self.max_iterations))
                 print(
                     'train_loss: %.4f | train_acc: %.4f | val_loss: %.4f | val_acc: %.4f | '
                     'Time: %.4f s' % (train_loss, train_acc, validate_loss,
                                       validate_acc, duration))
             elif self.train_validate_split is not None and self.test_log is True:
                 print('>>> Epoch [%*d/%*d]' %
                       (int(len(str(self.max_iterations))), epoch,
                        int(len(str(
                            self.max_iterations))), self.max_iterations))
                 print(
                     'train_loss: %.4f | train_acc: %.4f | val_loss: %.4f | val_acc: %.4f | '
                     'test_acc: %.4f | Time: %.4f s' %
                     (train_loss, train_acc, validate_loss, validate_acc,
                      test_acc, duration))
             elif self.train_validate_split is None and self.test_log is True:
                 print('>>> Epoch [%*d/%*d]' %
                       (int(len(str(self.max_iterations))), epoch,
                        int(len(str(
                            self.max_iterations))), self.max_iterations))
                 print(
                     'train_loss: %.4f | train_acc: %.4f | test_acc: %.4f | Time: %.4f s'
                     % (train_loss, train_acc, test_acc, duration))
             else:
                 print('>>> Epoch [%*d/%*d]' %
                       (int(len(str(self.max_iterations))), epoch,
                        int(len(str(
                            self.max_iterations))), self.max_iterations))
                 print('train_loss: %.4f | train_acc: %.4f | Time: %.4f s' %
                       (train_loss, train_acc, duration))
         if self.save_model is True:
             model_directory = os.path.dirname(self.model_name)
             file_utils.mkdir_p(model_directory)
             self.model.save(self.session,
                             self.model_name,
                             global_step=epoch)
         if epoch == 0:
             prev_cost = train_loss
         else:
             if math.fabs(train_loss - prev_cost) < self.tolerance:
                 converged = False
         epoch += 1