def train(self, hparams_string):
    """ Run training of the network

    Args:
        hparams_string: Hyper-parameter string parsed by
            `hparams_parser_train`. The parsed `batch_size` and `epoch_max`
            are stored on `self`, and the whole parsed configuration is saved
            to `self.dir_base` with `utils.save_model_configuration`.

    Returns:
        None. Loss summaries are written to `self.dir_logs` and a checkpoint
        named `<self.model>.model-<epoch>` is saved to
        `self.dir_checkpoints` after every epoch.
    """
    args_train = hparams_parser_train(hparams_string)
    self.batch_size = args_train.batch_size
    self.epoch_max = args_train.epoch_max

    utils.save_model_configuration(args_train, self.dir_base)

    # Use dataset for loading in datasamples from .tfrecord (https://www.tensorflow.org/programmers_guide/datasets#consuming_tfrecord_data)
    # The iterator will get a new batch from the dataset each time a sess.run() is executed on the graph.
    dataset = tf.data.TFRecordDataset(self.dateset_filenames)
    dataset = dataset.map(util_data.decode_image)  # decoding the tfrecord
    dataset = dataset.map(self._preProcessData)  # potential local preprocessing of data
    dataset = dataset.shuffle(buffer_size=10000, seed=None)
    dataset = dataset.batch(batch_size=self.batch_size)
    iterator = dataset.make_initializable_iterator()

    # The element structure depends on self._preProcessData (image, label)
    in_image, in_label = iterator.get_next()

    # define model, loss, optimizer and summaries.
    outputs = self._create_inference(in_image)
    loss = self._create_losses(outputs, in_label)
    optimizer_op = self._create_optimizer(loss)
    summary_op = self._create_summaries(loss)

    # show network architecture (after the model is built so its variables exist)
    utils.show_all_variables()

    with tf.Session() as sess:
        # Initialize all model Variables.
        sess.run(tf.global_variables_initializer())

        # Create Saver object for loading and storing checkpoints
        saver = tf.train.Saver()

        # Create Writer object for storing graph and summaries for TensorBoard
        writer = tf.summary.FileWriter(self.dir_logs, sess.graph)

        # Reload Tensor values from latest checkpoint
        ckpt = tf.train.get_checkpoint_state(self.dir_checkpoints)
        epoch_start = 0
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            # Checkpoints are tagged with the epoch that just finished, so
            # resume from the next epoch instead of repeating it.
            epoch_start = int(ckpt_name.split('-')[-1]) + 1

        interationCnt = 0
        try:
            # Do training loops
            for epoch_n in range(epoch_start, self.epoch_max):
                # Test model output before any training. The iterator is
                # re-initialised below, so the batch consumed here is not
                # lost from the first training epoch.
                if epoch_n == 0:
                    sess.run(iterator.initializer)
                    summary = sess.run(summary_op)
                    writer.add_summary(summary, global_step=-1)

                # Initiate or Re-initiate iterator
                sess.run(iterator.initializer)

                utils.show_message(
                    'Running training epoch no: {0}'.format(epoch_n))
                while True:
                    try:
                        _, summary = sess.run([optimizer_op, summary_op])
                        writer.add_summary(summary, global_step=interationCnt)
                        interationCnt += 1
                    except tf.errors.OutOfRangeError:
                        # End of the dataset for this epoch
                        break

                # Save model variables to checkpoint after every epoch
                saver.save(sess,
                           os.path.join(self.dir_checkpoints,
                                        self.model + '.model'),
                           global_step=epoch_n)
        finally:
            # Flush pending summaries to disk even if training is interrupted
            writer.close()
def train(self, hparams_string):
    """ Run training of the network

    Args:
        hparams_string: Hyper-parameter string parsed by
            `hparams_parser_train`. The parsed `batch_size`, `epoch_max`,
            `use_imagenet` and `model_version` are stored on `self`, and the
            whole parsed configuration is saved to `self.dir_base`.

    Returns:
        None. Loss summaries are written to `self.dir_logs` and a checkpoint
        named `<self.model>.model-<epoch>` is saved to
        `self.dir_checkpoints` after every epoch.

    Raises:
        ValueError: if `use_imagenet` is set but `model_version` is not one
            of the supported ImageNet models ('VGG16', 'VGG19').
    """
    args_train = hparams_parser_train(hparams_string)
    self.batch_size = args_train.batch_size
    self.epoch_max = args_train.epoch_max
    self.use_imagenet = args_train.use_imagenet
    self.model_version = args_train.model_version

    utils.save_model_configuration(args_train, self.dir_base)

    # Use dataset for loading in datasamples from .tfrecord (https://www.tensorflow.org/programmers_guide/datasets#consuming_tfrecord_data)
    # The iterator will get a new batch from the dataset each time a sess.run() is executed on the graph.
    dataset = tf.data.TFRecordDataset(self.dateset_filenames)
    dataset = dataset.map(util_data.decode_image)  # decoding the tfrecord
    dataset = dataset.map(self._preProcessData)  # potential local preprocessing of data
    dataset = dataset.shuffle(buffer_size=10000, seed=None)
    dataset = dataset.batch(batch_size=self.batch_size)
    iterator = dataset.make_initializable_iterator()
    input_getBatch = iterator.get_next()

    input_images = tf.placeholder(dtype=tf.float32,
                                  shape=[None] + self.image_dims,
                                  name='input_images')
    input_lbls = tf.placeholder(dtype=tf.float32,
                                shape=[None, self.lbls_dim],
                                name='input_lbls')

    # define model, loss, optimizer and summaries.
    output_logits = self._create_inference(input_images)
    loss = self._create_losses(output_logits, input_lbls)
    optimizer_op = self._create_optimizer(loss)
    summary_op = self._create_summaries(loss)

    # show network architecture
    utils.show_all_variables()

    # Optional initialisation from a pretrained ImageNet slim checkpoint.
    init_fn = None
    if self.use_imagenet:
        # model_version -> (slim variable scope, checkpoint archive URL)
        imagenet_models = {
            'VGG16': ('vgg_16', "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"),
            'VGG19': ('vgg_19', "http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz"),
        }
        if self.model_version not in imagenet_models:
            raise ValueError(
                "Unsupported model_version for ImageNet initialisation: {0!r} "
                "(expected 'VGG16' or 'VGG19')".format(self.model_version))
        scope, url_imagenet_model = imagenet_models[self.model_version]

        path_imagenet_ckpt = os.path.join(self.dir_checkpoints, scope + '.ckpt')
        if not tf.gfile.Exists(path_imagenet_ckpt):
            utils.download_and_uncompress_tarball(url_imagenet_model,
                                                  self.dir_checkpoints)

        # ignore fc layers (the last 6 model variables of the scope)
        variables_to_restore = slim.get_model_variables(scope)[:-6]
        init_fn = slim.assign_from_checkpoint_fn(path_imagenet_ckpt,
                                                 variables_to_restore)

    with tf.Session() as sess:
        # Initialize all model Variables.
        sess.run(tf.global_variables_initializer())
        if init_fn is not None:
            init_fn(sess)

        # Create Saver object for loading and storing checkpoints
        saver = tf.train.Saver()

        # Create Writer object for storing graph and summaries for TensorBoard
        writer = tf.summary.FileWriter(self.dir_logs, sess.graph)

        # Reload Tensor values from latest checkpoint (overrides ImageNet init)
        ckpt = tf.train.get_checkpoint_state(self.dir_checkpoints)
        epoch_start = 0
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            # Checkpoints are tagged with the epoch that just finished, so
            # resume from the next epoch instead of repeating it.
            epoch_start = int(ckpt_name.split('-')[-1]) + 1

        interationCnt = 0
        try:
            # Do training loops
            for epoch_n in range(epoch_start, self.epoch_max):
                # Initiate or Re-initiate iterator
                sess.run(iterator.initializer)

                utils.show_message(
                    'Running training epoch no: {0}'.format(epoch_n))
                while True:
                    try:
                        image_batch, lbl_batch = sess.run(input_getBatch)
                        _, summary_loss = sess.run(
                            [optimizer_op, summary_op],
                            feed_dict={input_images: image_batch,
                                       input_lbls: lbl_batch})
                        writer.add_summary(summary_loss,
                                           global_step=interationCnt)
                        interationCnt += 1
                    except tf.errors.OutOfRangeError:
                        # End of the dataset for this epoch
                        break

                # Save model variables to checkpoint after every epoch
                saver.save(sess,
                           os.path.join(self.dir_checkpoints,
                                        self.model + '.model'),
                           global_step=epoch_n)
        finally:
            # Flush pending summaries to disk even if training is interrupted
            writer.close()
def train(self, hparams_string):
    """ Run training of the network

    For every generator update the discriminator is updated `self.d_iter`
    times, each time on a fresh batch; the generator update reuses the last
    of those batches.

    Args:
        hparams_string: Hyper-parameter string parsed by
            `hparams_parser_train`. The parsed training settings are stored
            on `self` and the whole configuration is saved to
            `self.dir_base`.

    Returns:
        None. Loss and image summaries go to `self.dir_logs`, test images
        generated from a fixed input are saved to
        `<self.dir_results>/Training` as `Epoch_<n>` (`Epoch_-1` before any
        training), and checkpoints are saved to `self.dir_checkpoints`
        every `backup_frequency` epochs and after the last epoch.
    """
    args_train = hparams_parser_train(hparams_string)
    self.batch_size = args_train.batch_size
    self.epoch_max = args_train.epoch_max
    self.unstructured_noise_dim = args_train.unstructured_noise_dim
    self.d_learning_rate = args_train.lr_discriminator
    self.g_learning_rate = args_train.lr_generator
    self.d_iter = args_train.d_iter
    self.n_testsamples = args_train.n_testsamples
    self.class_scale_d = args_train.class_scale_d
    self.class_scale_g = args_train.class_scale_g
    self.backup_frequency = args_train.backup_frequency

    utils.save_model_configuration(args_train, self.dir_base)

    # Use dataset for loading in datasamples from .tfrecord (https://www.tensorflow.org/programmers_guide/datasets#consuming_tfrecord_data)
    # The iterator will get a new batch from the dataset each time a sess.run() is executed on the graph.
    dataset = tf.data.TFRecordDataset(self.dateset_filenames)
    dataset = dataset.map(util_data.decode_image)  # decoding the tfrecord
    dataset = dataset.map(self._genLatentCodes)
    dataset = dataset.shuffle(buffer_size=10000, seed=None)
    dataset = dataset.batch(batch_size=self.batch_size)
    iterator = dataset.make_initializable_iterator()
    input_getBatch = iterator.get_next()

    # Create input placeholders
    n_test = self.n_testsamples * self.lbls_dim
    input_images = tf.placeholder(
        dtype=tf.float32, shape=[None] + self.image_dims, name='input_images')
    input_lbls = tf.placeholder(
        dtype=tf.float32, shape=[None, self.lbls_dim], name='input_lbls')
    input_unstructured_noise = tf.placeholder(
        dtype=tf.float32, shape=[None, self.unstructured_noise_dim],
        name='input_unstructured_noise')
    input_test_lbls = tf.placeholder(
        dtype=tf.float32, shape=[n_test, self.lbls_dim],
        name='input_test_lbls')
    input_test_noise = tf.placeholder(
        dtype=tf.float32, shape=[n_test, self.unstructured_noise_dim],
        name='input_test_noise')

    # Define model, loss, optimizer and summaries.
    logits_source, logits_class, _ = self._create_inference(
        input_images, input_lbls, input_unstructured_noise)
    loss_discriminator, loss_generator = self._create_losses(
        logits_source, logits_class, input_lbls)
    train_op_discriminator, train_op_generator = self._create_optimizer(
        loss_discriminator, loss_generator)
    summary_op_dloss, summary_op_gloss, summary_op_img, summary_img = \
        self._create_summaries(loss_discriminator, loss_generator,
                               input_test_noise, input_test_lbls)

    # show network architecture
    utils.show_all_variables()

    # create constant test variable to inspect changes in the model
    test_noise, test_lbls = self._genTestInput(
        self.lbls_dim, n_samples=self.n_testsamples)
    test_feed = {input_test_noise: test_noise, input_test_lbls: test_lbls}

    dir_results_train = os.path.join(self.dir_results, 'Training')
    utils.checkfolder(dir_results_train)

    def _log_test_images(sess, writer, step):
        """Render the fixed test input, log it to TensorBoard and save it locally as 'Epoch_<step>'."""
        summaryImg_tb, summaryImg = sess.run(
            [summary_op_img, summary_img], feed_dict=test_feed)
        writer.add_summary(summaryImg_tb, global_step=step)
        utils.save_image_local(summaryImg, dir_results_train,
                               'Epoch_' + str(step))

    with tf.Session() as sess:
        # Initialize all model Variables.
        sess.run(tf.global_variables_initializer())

        # Create Saver object for loading and storing checkpoints
        saver = tf.train.Saver()

        # Create Writer object for storing graph and summaries for TensorBoard
        writer = tf.summary.FileWriter(self.dir_logs, sess.graph)

        # Reload Tensor values from latest checkpoint
        ckpt = tf.train.get_checkpoint_state(self.dir_checkpoints)
        epoch_start = 0
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            epoch_start = int(ckpt_name.split('-')[-1]) + 1

        interationCnt = 0
        try:
            for epoch_n in range(epoch_start, self.epoch_max):
                # Test model output before any training
                if epoch_n == 0:
                    _log_test_images(sess, writer, -1)

                # Initiate or Re-initiate iterator
                sess.run(iterator.initializer)

                ### ----------------------------------------------------------
                ### Update model
                print(datetime.datetime.now(), '- Running training epoch no:',
                      epoch_n)
                while True:
                    try:
                        for _ in range(self.d_iter):
                            image_batch, lbl_batch, unst_noise_batch = sess.run(
                                input_getBatch)
                            feed = {input_images: image_batch,
                                    input_lbls: lbl_batch,
                                    input_unstructured_noise: unst_noise_batch}
                            # sess.run on a 2-element list returns exactly 2 values
                            _, summary_dloss = sess.run(
                                [train_op_discriminator, summary_op_dloss],
                                feed_dict=feed)
                            writer.add_summary(summary_dloss,
                                               global_step=interationCnt)

                        _, summary_gloss = sess.run(
                            [train_op_generator, summary_op_gloss],
                            feed_dict=feed)
                        writer.add_summary(summary_gloss,
                                           global_step=interationCnt)
                        interationCnt += 1
                    except tf.errors.OutOfRangeError:
                        # Test current model
                        _log_test_images(sess, writer, epoch_n)
                        break

                # Save model variables to checkpoint; always keep the final epoch
                if ((epoch_n + 1) % self.backup_frequency == 0
                        or epoch_n + 1 == self.epoch_max):
                    saver.save(sess,
                               os.path.join(self.dir_checkpoints,
                                            self.model + '.model'),
                               global_step=epoch_n)
        finally:
            # Flush pending summaries to disk even if training is interrupted
            writer.close()
def train(self, hparams_string):
    """ Run training of the network

    For every generator update the discriminator is updated `self.d_iter`
    times, each on a fresh full-size batch; the generator update reuses the
    last of those batches. A final batch smaller than `self.batch_size`
    ends the epoch, because `input_images` has a fixed batch dimension.

    Args:
        hparams_string: Hyper-parameter string parsed by
            `hparams_parser_train`. The parsed training settings are stored
            on `self` and the whole configuration is saved to
            `self.dir_base`.

    Returns:
        None. Loss summaries go to `self.dir_logs`. For every class and
        every info-variable combination, test images are saved under
        `<self.dir_results>/Training/<class, zero-padded>`: once before any
        training as `Epoch_-1_*` and after every epoch as `Epoch_<n>_*`.
        Checkpoints are saved every `backup_frequency` epochs and after the
        last epoch.
    """
    args_train = hparams_parser_train(hparams_string)
    self.batch_size = args_train.batch_size
    self.epoch_max = args_train.epoch_max
    self.unstructured_noise_dim = args_train.unstructured_noise_dim
    self.info_var_dim = args_train.info_var_dim
    self.n_testsamples = args_train.n_testsamples
    self.d_learning_rate = args_train.lr_discriminator
    self.g_learning_rate = args_train.lr_generator
    self.d_iter = args_train.d_iter
    self.gp_lambda = args_train.gp_lambda
    self.class_scale_d = args_train.class_scale_d
    self.class_scale_g = args_train.class_scale_g
    self.info_scale_d = args_train.info_scale_d
    self.info_scale_g = args_train.info_scale_g
    self.backup_frequency = args_train.backup_frequency
    self.shards_idx_test = args_train.shards_idx_test

    utils.save_model_configuration(args_train, self.dir_base)

    dir_results_train = os.path.join(self.dir_results, 'Training')

    def _class_dir(class_n):
        """Return the per-class results folder, e.g. '<Training>/03'."""
        return os.path.join(dir_results_train, str(class_n).zfill(2))

    # Create folder for saving training results (one sub-folder per class)
    utils.checkfolder(dir_results_train)
    for class_n in range(self.lbls_dim):
        utils.checkfolder(_class_dir(class_n))

    # Select training shards. A 0 in shards_idx_test keeps every shard for
    # training; otherwise the indices are treated as 1-based and those
    # shards are held out.
    if 0 in self.shards_idx_test:
        dataset_filenames = self.dataset_filenames
    else:
        self.shards_idx_test = np.subtract(self.shards_idx_test, 1)
        shards_idx_training = np.delete(range(len(self.dataset_filenames)),
                                        self.shards_idx_test)
        dataset_filenames = [
            self.dataset_filenames[i] for i in shards_idx_training
        ]

    utils.show_message('Training Data:')
    print(dataset_filenames)

    # Setup preprocessing pipeline
    preprocessing = preprocess_factory.preprocess_factory()

    # Dataset specific preprocessing ('MNIST' and 'PSD_Nonsegmented' need none)
    if self.dataset == 'PSD_Segmented':
        preprocessing.prep_pipe_from_string(
            "pad_to_size;{'height': 566, 'width': 566, 'constant': -1.0};random_rotation;{};crop_to_size;{'height': 400, 'width': 400};resize;{'height': 128, 'width': 128}"
        )

    # Use dataset for loading in datasamples from .tfrecord (https://www.tensorflow.org/programmers_guide/datasets#consuming_tfrecord_data)
    # The iterator will get a new batch from the dataset each time a sess.run() is executed on the graph.
    dataset = tf.data.TFRecordDataset(dataset_filenames)
    dataset = dataset.shuffle(buffer_size=10000, seed=None)
    dataset = dataset.map(util_data.decode_image)  # decoding the tfrecord
    dataset = dataset.map(self._genLatentCodes)
    # preprocess data and perform augmentation
    dataset = dataset.map(preprocessing.pipe)
    dataset = dataset.batch(batch_size=self.batch_size)
    iterator = dataset.make_initializable_iterator()
    input_getBatch = iterator.get_next()

    # Number of test images rendered per (class, info combination)
    n_test_images = self.n_testsamples**np.minimum(2, self.info_var_dim)

    # Create input placeholders
    input_images = tf.placeholder(dtype=tf.float32,
                                  shape=[self.batch_size] + self.image_dims,
                                  name='input_images')
    input_lbls = tf.placeholder(dtype=tf.float32,
                                shape=[None, self.lbls_dim],
                                name='input_lbls')
    input_unstructured_noise = tf.placeholder(
        dtype=tf.float32,
        shape=[None, self.unstructured_noise_dim],
        name='input_unstructured_noise')
    input_info_noise = tf.placeholder(dtype=tf.float32,
                                      shape=[None, self.info_var_dim],
                                      name='input_info_noise')
    input_test_lbls = tf.placeholder(dtype=tf.float32,
                                     shape=[n_test_images, self.lbls_dim],
                                     name='input_test_lbls')
    input_test_noise = tf.placeholder(
        dtype=tf.float32,
        shape=[n_test_images, self.unstructured_noise_dim],
        name='input_test_noise')
    input_test_info_noise = tf.placeholder(
        dtype=tf.float32,
        shape=[n_test_images, self.info_var_dim],
        name='input_test_info_noise')

    # Define model, loss, optimizer and summaries.
    logits_source, logits_class, logits_info, artificial_images = self._create_inference(
        input_images, input_lbls, input_unstructured_noise, input_info_noise)
    loss_discriminator, loss_generator = self._create_losses(
        logits_source, logits_class, logits_info, artificial_images,
        input_lbls, input_info_noise)
    train_op_discriminator, train_op_generator = self._create_optimizer(
        loss_discriminator, loss_generator)
    summary_op_dloss, summary_op_gloss, summary_op_img, summary_img = self._create_summaries(
        loss_discriminator, loss_generator, input_test_noise,
        input_test_lbls, input_test_info_noise)

    # show network architecture
    utils.show_all_variables()

    # create constant test variable to inspect changes in the model
    self.combinations_info_var = list(
        itertools.combinations(range(self.info_var_dim), 2))
    test_noise, test_info = self._genTestInput()

    def _save_test_images(sess, epoch_n):
        """Render the fixed test inputs for every class and info combination and save them per class."""
        for class_n in range(self.lbls_dim):
            # Every test sample is labelled with class_n
            test_lbls = np.zeros([n_test_images, self.lbls_dim])
            test_lbls[:, class_n] = 1
            dir_result_train_class = _class_dir(class_n)
            for i, test_info_combi in enumerate(test_info):
                _, summaryImg = sess.run(
                    [summary_op_img, summary_img],
                    feed_dict={
                        input_test_noise: test_noise,
                        input_test_lbls: test_lbls,
                        input_test_info_noise: test_info_combi
                    })
                if self.info_var_dim < 2:
                    filename_temp = 'Epoch_{0}_LatentVar_1'.format(epoch_n)
                else:
                    filename_temp = 'Epoch_{0}_LatentCombi_{1}_{2}'.format(
                        epoch_n, self.combinations_info_var[i][0],
                        self.combinations_info_var[i][1])
                utils.save_image_local(summaryImg, dir_result_train_class,
                                       filename_temp)

    with tf.Session() as sess:
        # Initialize all model Variables.
        sess.run(tf.global_variables_initializer())

        # Create Saver object for loading and storing checkpoints
        saver = tf.train.Saver(max_to_keep=500)

        # Create Writer object for storing graph and summaries for TensorBoard
        writer = tf.summary.FileWriter(self.dir_logs, sess.graph)

        # Reload Tensor values from latest checkpoint
        ckpt = tf.train.get_checkpoint_state(self.dir_checkpoints)
        epoch_start = 0
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            epoch_start = int(ckpt_name.split('-')[-1]) + 1

        interationCnt = 0
        try:
            for epoch_n in range(epoch_start, self.epoch_max):
                # Test model output before any training. Tagged -1 so it is
                # not overwritten by the snapshot taken after epoch 0.
                if epoch_n == 0:
                    _save_test_images(sess, -1)

                # Initiate or Re-initiate iterator
                sess.run(iterator.initializer)

                ### ----------------------------------------------------------
                ### Update model
                if (np.mod(epoch_n, 100) == 0) or epoch_n < 25:
                    utils.show_message(
                        'Running training epoch no: {0}'.format(epoch_n))
                while True:
                    try:
                        for _ in range(self.d_iter):
                            image_batch, lbl_batch, unst_noise_batch, info_noise_batch = sess.run(
                                input_getBatch)
                            if image_batch.shape[0] != self.batch_size:
                                # input_images has a fixed batch dimension, so a
                                # short final batch cannot be fed: end the epoch.
                                raise tf.errors.OutOfRangeError(
                                    None, None, 'Incomplete final batch')
                            feed = {
                                input_images: image_batch,
                                input_lbls: lbl_batch,
                                input_unstructured_noise: unst_noise_batch,
                                input_info_noise: info_noise_batch
                            }
                            _, summary_dloss = sess.run(
                                [train_op_discriminator, summary_op_dloss],
                                feed_dict=feed)
                            writer.add_summary(summary_dloss,
                                               global_step=interationCnt)

                        _, summary_gloss = sess.run(
                            [train_op_generator, summary_op_gloss],
                            feed_dict=feed)
                        writer.add_summary(summary_gloss,
                                           global_step=interationCnt)
                        interationCnt += 1
                    except tf.errors.OutOfRangeError:
                        # Test current model
                        _save_test_images(sess, epoch_n)
                        break

                # Save model variables to checkpoint; always keep the final epoch
                if ((epoch_n + 1) % self.backup_frequency == 0
                        or epoch_n + 1 == self.epoch_max):
                    saver.save(sess,
                               os.path.join(self.dir_checkpoints,
                                            self.model + '.model'),
                               global_step=epoch_n)
        finally:
            # Flush pending summaries to disk even if training is interrupted
            writer.close()