def up_block(lower_res_inputs,
             same_res_inputs,
             output_channels,
             kernel_shape,
             stride=1,
             rate=1,
             num_convs=2,
             initializers=None,
             regularizers=None,
             nonlinearity=tf.nn.relu,
             up_sampling_op=lambda x, size: tf.image.resize_images(
                 x, size, method=tf.image.ResizeMethod.BILINEAR,
                 align_corners=True),
             data_format='NCHW',
             name='up_block'):
    """A block made up of an up-sampling step followed by several convolutional layers.

    `lower_res_inputs` is up-sampled to the spatial size of `same_res_inputs`,
    concatenated with it along the channel axis and passed through `num_convs`
    convolution + nonlinearity layers.

    :param lower_res_inputs: 4D feature tensor to be up-sampled
    :param same_res_inputs: 4D skip-connection tensor at the target resolution
    :param output_channels: number of output channels of each convolution
    :param kernel_shape: convolution kernel shape
    :param stride: convolution stride
    :param rate: convolution dilation rate
    :param num_convs: number of convolutions applied after the concatenation
    :param initializers: dict of sonnet initializers; None selects the
        he-normal / truncated-normal defaults
    :param regularizers: dict of sonnet regularizers or None
    :param nonlinearity: activation applied after every convolution
    :param up_sampling_op: callable (nhwc_features, spatial_size) -> features
    :param data_format: 'NCHW' or 'NHWC'
    :param name: variable-scope name
    :return: 4D feature tensor
    """
    if initializers is None:
        # Build the default dict per call instead of using a shared mutable
        # default argument.
        initializers = {'w': he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)}
    with tf.variable_scope(name):
        if data_format == 'NHWC':
            # Bug fix: for an NHWC tensor the spatial axes are [1:3]; the old
            # code used [2:], which handed (W, C) to the resize op.
            spatial_shape = same_res_inputs.get_shape()[1:3]
            features = up_sampling_op(lower_res_inputs, spatial_shape)
            features = tf.concat([features, same_res_inputs], axis=-1)
        else:
            spatial_shape = same_res_inputs.get_shape()[2:]
            # tf.image.resize_images only supports NHWC, so transpose around
            # the up-sampling step.
            lower_res_inputs = tf.transpose(lower_res_inputs, perm=[0, 2, 3, 1])
            features = up_sampling_op(lower_res_inputs, spatial_shape)
            features = tf.transpose(features, perm=[0, 3, 1, 2])
            features = tf.concat([features, same_res_inputs], axis=1)
        for _ in range(num_convs):
            features = snt.Conv2D(output_channels,
                                  kernel_shape,
                                  stride,
                                  rate,
                                  data_format=data_format,
                                  initializers=initializers,
                                  regularizers=regularizers)(features)
            features = nonlinearity(features)
        return features
def load(self, experiment_dir):
    """Rebuild the ProbUNet graph from a stored experiment and restore its weights.

    The experiment's config file is executed as a module to recover the
    hyper-parameters, the network and its posterior-/prior-based inference ops
    are constructed, and the newest checkpoint in `experiment_dir` is restored.

    :param experiment_dir: directory holding the config file and checkpoints
    """
    config_filename = os.path.join(experiment_dir, CONFIG_FILENAME)
    # Load the saved config as a throw-away module named 'cf'.
    cf = SourceFileLoader('cf', config_filename).load_module()
    self._punet = ProbUNet(
        latent_dim=cf.latent_dim,
        num_channels=cf.num_channels,
        num_1x1_convs=cf.num_1x1_convs,
        num_classes=cf.num_classes,
        num_convs_per_block=cf.num_convs_per_block,
        initializers={
            'w': training_utils.he_normal(),
            'b': tf.truncated_normal_initializer(stddev=0.001)
        },
        regularizers={
            'w': tf.contrib.layers.l2_regularizer(1.0),
            'b': tf.contrib.layers.l2_regularizer(1.0)
        })
    self._x = tf.placeholder(tf.float32, shape=cf.network_input_shape)
    self._y = tf.placeholder(tf.uint8, shape=cf.label_shape)
    # Per-dimension factor used to offset the latent mean by a scaled stddev
    # in the 'external sample' ops below.
    # NOTE(review): the hard-coded shape (6,) presumably equals the latent
    # dimensionality used in the experiments — confirm against cf.latent_dim.
    self._sigma_multiplier = tf.placeholder(tf.float32, shape=(6, ))
    # is_training=True to get posterior_net as well
    self._punet(self._x, self._y, is_training=True,
                one_hot_labels=cf.one_hot_labels)
    # posterior-based inference
    self._posterior_latent_mu_op = self._punet._q.mean()
    self._posterior_latent_stddev_op = self._punet._q.stddev()
    self._posterior_latent_sample_op = self._punet._q.sample()
    # deterministic reconstruction from the posterior mean
    self._posterior_sample_det_op = self._punet.reconstruct(
        z_q=self._posterior_latent_mu_op, softmax=True)
    # stochastic reconstruction from a posterior sample
    self._posterior_sample_op = self._punet.reconstruct(
        z_q=self._posterior_latent_sample_op, softmax=True)
    # reconstruction at mean + sigma_multiplier * stddev
    self._posterior_external_sample_op = self._punet.reconstruct(
        z_q=self._posterior_latent_mu_op +
        self._sigma_multiplier * self._posterior_latent_stddev_op,
        softmax=True)
    # prior-based inference
    self._prior_latent_mu_op = self._punet._p.mean()
    self._prior_latent_stddev_op = self._punet._p.stddev()
    self._prior_latent_sample_op = self._punet._p.sample()
    self._prior_sample_det_op = self._punet.reconstruct(
        z_q=self._prior_latent_mu_op, softmax=True)
    self._prior_sample_op = self._punet.reconstruct(
        z_q=self._prior_latent_sample_op, softmax=True)
    self._prior_external_sample_op = self._punet.reconstruct(
        z_q=self._prior_latent_mu_op +
        self._sigma_multiplier * self._prior_latent_stddev_op,
        softmax=True)
    self._sampled_logits_op = self._punet.sample()
    saver = tf.train.Saver(save_relative_paths=True)
    # NOTE(review): restoring through an external Saver into a
    # MonitoredTrainingSession bypasses the session's own checkpoint
    # handling — confirm this is intended.
    self._session = tf.train.MonitoredTrainingSession()
    print("Experiment dir:", experiment_dir)
    latest_ckpt_path = tf.train.latest_checkpoint(experiment_dir)
    print("Loading model from:", latest_ckpt_path)
    saver.restore(self._session, latest_ckpt_path)
def down_block(features,
               output_channels,
               kernel_shape,
               stride=1,
               rate=1,
               num_convs=2,
               initializers=None,
               regularizers=None,
               nonlinearity=tf.nn.relu,
               down_sample_input=True,
               down_sampling_op=lambda x, df: tf.nn.avg_pool(
                   x, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
                   padding='SAME', data_format=df),
               data_format='NCHW',
               name='down_block'):
    """A block made up of a down-sampling step followed by several convolutional layers.

    :param features: 4D input feature tensor
    :param output_channels: number of output channels of each convolution
    :param kernel_shape: convolution kernel shape
    :param stride: convolution stride
    :param rate: convolution dilation rate
    :param num_convs: number of convolution + nonlinearity layers
    :param initializers: dict of sonnet initializers; None selects the
        he-normal / truncated-normal defaults
    :param regularizers: dict of sonnet regularizers or None
    :param nonlinearity: activation applied after every convolution
    :param down_sample_input: whether to down-sample before the convolutions
    :param down_sampling_op: callable (features, data_format) -> features
    :param data_format: 'NCHW' or 'NHWC'
    :param name: variable-scope name
    :return: 4D feature tensor
    """
    if initializers is None:
        # Build the default dict per call instead of using a shared mutable
        # default argument.
        initializers = {'w': he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)}
    with tf.variable_scope(name):
        if down_sample_input:
            features = down_sampling_op(features, data_format)
        for _ in range(num_convs):
            features = snt.Conv2D(output_channels,
                                  kernel_shape,
                                  stride,
                                  rate,
                                  data_format=data_format,
                                  initializers=initializers,
                                  regularizers=regularizers)(features)
            features = nonlinearity(features)
        return features
def __init__(self, latent_dim, num_channels, nonlinearity=tf.nn.relu,
             num_convs_per_block=3,
             initializers=None,
             regularizers=None,
             data_format='NCHW',
             down_sampling_op=lambda x, df: tf.nn.avg_pool(
                 x, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
                 padding='SAME', data_format=df),
             name="conv_dist"):
    """Set up the encoder that parameterizes an axis-aligned Gaussian.

    :param latent_dim: dimensionality of the latent space
    :param num_channels: per-scale channel counts for the VGG encoder
    :param nonlinearity: activation used throughout the encoder
    :param num_convs_per_block: convolutions per encoder block
    :param initializers: dict of sonnet initializers; None selects the
        he-normal / truncated-normal defaults
    :param regularizers: dict of sonnet regularizers; None selects L2 on
        weights and biases
    :param data_format: 'NCHW' or 'NHWC'
    :param down_sampling_op: callable (features, data_format) -> features
    :param name: sonnet module name
    """
    # Build the default dicts per call instead of using shared mutable
    # default arguments.
    if initializers is None:
        initializers = {'w': he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)}
    if regularizers is None:
        regularizers = {'w': tf.contrib.layers.l2_regularizer(1.0),
                        'b': tf.contrib.layers.l2_regularizer(1.0)}
    self._latent_dim = latent_dim
    self._initializers = initializers
    self._regularizers = regularizers
    self._data_format = data_format
    # Remember which axes are channels vs. spatial for the chosen layout.
    if data_format == 'NCHW':
        self._channel_axis = 1
        self._spatial_axes = [2, 3]
    else:
        self._channel_axis = -1
        self._spatial_axes = [1, 2]
    super(AxisAlignedConvGaussian, self).__init__(name=name)
    with self._enter_variable_scope():
        tf.logging.info('Building ConvGaussian.')
        self._encoder = VGG_Encoder(num_channels,
                                    nonlinearity,
                                    num_convs_per_block,
                                    initializers,
                                    regularizers,
                                    data_format=data_format,
                                    down_sampling_op=down_sampling_op)
def __init__(self, num_channels, num_classes, nonlinearity=tf.nn.relu,
             num_convs_per_block=3,
             initializers=None,
             regularizers=None,
             data_format='NCHW',
             up_sampling_op=lambda x, size: tf.image.resize_images(
                 x, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                 align_corners=True),
             name="vgg_dec"):
    """Store the configuration of a VGG-style decoder.

    :param num_channels: per-scale channel counts (mirrors the encoder)
    :param num_classes: number of output classes
    :param nonlinearity: activation used throughout the decoder
    :param num_convs_per_block: convolutions per decoder block
    :param initializers: dict of sonnet initializers; None selects the
        he-normal / truncated-normal defaults
    :param regularizers: dict of sonnet regularizers; None selects L2 on
        weights and biases
    :param data_format: 'NCHW' or 'NHWC'
    :param up_sampling_op: callable (nhwc_features, spatial_size) -> features
    :param name: sonnet module name
    """
    super(VGG_Decoder, self).__init__(name=name)
    # Build the default dicts per call instead of using shared mutable
    # default arguments.
    if initializers is None:
        initializers = {'w': he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)}
    if regularizers is None:
        regularizers = {'w': tf.contrib.layers.l2_regularizer(1.0),
                        'b': tf.contrib.layers.l2_regularizer(1.0)}
    self._num_channels = num_channels
    self._num_classes = num_classes
    self._nonlinearity = nonlinearity
    self._num_convs = num_convs_per_block
    self._initializers = initializers
    self._regularizers = regularizers
    self._data_format = data_format
    self._up_sampling_op = up_sampling_op
def __init__(self, num_channels, nonlinearity=tf.nn.relu,
             num_convs_per_block=3,
             initializers=None,
             regularizers=None,
             data_format='NCHW',
             down_sampling_op=lambda x, df: tf.nn.avg_pool(
                 x, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
                 padding='SAME', data_format=df),
             name="vgg_enc"):
    """Store the configuration of a VGG-style encoder.

    :param num_channels: per-scale channel counts
    :param nonlinearity: activation used throughout the encoder
    :param num_convs_per_block: convolutions per encoder block
    :param initializers: dict of sonnet initializers; None selects the
        he-normal / truncated-normal defaults
    :param regularizers: dict of sonnet regularizers; None selects L2 on
        weights and biases
    :param data_format: 'NCHW' or 'NHWC'
    :param down_sampling_op: callable (features, data_format) -> features
    :param name: sonnet module name
    """
    super(VGG_Encoder, self).__init__(name=name)
    # Build the default dicts per call instead of using shared mutable
    # default arguments.
    if initializers is None:
        initializers = {'w': he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)}
    if regularizers is None:
        regularizers = {'w': tf.contrib.layers.l2_regularizer(1.0),
                        'b': tf.contrib.layers.l2_regularizer(1.0)}
    self._num_channels = num_channels
    self._nonlinearity = nonlinearity
    self._num_convs = num_convs_per_block
    self._initializers = initializers
    self._regularizers = regularizers
    self._data_format = data_format
    self._down_sampling_op = down_sampling_op
def down_block(features,
               output_channels,
               kernel_shape,
               stride=1,
               rate=1,
               num_convs=2,
               initializers=None,
               regularizers=None,
               nonlinearity=tf.nn.relu,
               down_sample_input=True,
               down_sampling_op=lambda x, df: tf.nn.avg_pool(
                   x, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
                   padding='SAME', data_format=df),
               data_format='NCHW',
               name='down_block'):
    """A block made up of a down-sampling step followed by several convolutional layers.

    Bug fix: the previous default `down_sampling_op` was a
    `tf.image.resize_images` lambda with signature (x, size), but the op is
    invoked below as `down_sampling_op(features, data_format)` — with the
    default, the string 'NCHW'/'NHWC' was passed as the resize size and the
    call could not work. The default now matches the avg-pool operator used
    by the sibling `down_block` implementation, whose (x, data_format)
    signature agrees with the call site.

    :param features: 4D input feature tensor
    :param output_channels: number of output channels of each convolution
    :param kernel_shape: convolution kernel shape
    :param stride: convolution stride
    :param rate: convolution dilation rate
    :param num_convs: number of convolution + nonlinearity layers
    :param initializers: dict of sonnet initializers; None selects the
        he-normal / truncated-normal defaults
    :param regularizers: dict of sonnet regularizers or None
    :param nonlinearity: activation applied after every convolution
    :param down_sample_input: whether to down-sample before the convolutions
    :param down_sampling_op: callable (features, data_format) -> features
    :param data_format: 'NCHW' or 'NHWC'
    :param name: variable-scope name
    :return: 4D feature tensor
    """
    if initializers is None:
        # Build the default dict per call instead of using a shared mutable
        # default argument.
        initializers = {'w': he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)}
    with tf.variable_scope(name):
        if down_sample_input:
            features = down_sampling_op(features, data_format)
        for _ in range(num_convs):
            features = snt.Conv2D(output_channels,
                                  kernel_shape,
                                  stride,
                                  rate,
                                  data_format=data_format,
                                  initializers=initializers,
                                  regularizers=regularizers)(features)
            features = nonlinearity(features)
        return features
def __init__(self, latent_dim, num_channels, num_classes, num_1x1_convs=3,
             nonlinearity=tf.nn.relu, num_convs_per_block=3,
             initializers=None,
             regularizers=None,
             data_format='NCHW',
             down_sampling_op=lambda x, df: tf.nn.avg_pool(
                 x, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
                 padding='SAME', data_format=df),
             up_sampling_op=lambda x, size: tf.image.resize_images(
                 x, size, method=tf.image.ResizeMethod.BILINEAR,
                 align_corners=True),
             name='prob_unet'):
    """Assemble the sub-networks of a Probabilistic U-Net.

    Builds the deterministic U-Net, the 1x1-convolution combiner and the
    prior and posterior latent-space networks.

    :param latent_dim: dimensionality of the latent space
    :param num_channels: per-scale channel counts of the U-Net
    :param num_classes: number of segmentation classes
    :param num_1x1_convs: number of 1x1 convolutions in the combiner
    :param nonlinearity: activation used throughout the model
    :param num_convs_per_block: convolutions per encoder/decoder block
    :param initializers: dict of sonnet initializers; None selects the
        he-normal / truncated-normal defaults
    :param regularizers: dict of sonnet regularizers; None selects L2 on
        weights and biases
    :param data_format: 'NCHW' or 'NHWC'
    :param down_sampling_op: callable (features, data_format) -> features
    :param up_sampling_op: callable (nhwc_features, spatial_size) -> features
    :param name: sonnet module name
    """
    super(ProbUNet, self).__init__(name=name)
    # Build the default dicts per call instead of using shared mutable
    # default arguments.
    if initializers is None:
        initializers = {'w': he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)}
    if regularizers is None:
        regularizers = {'w': tf.contrib.layers.l2_regularizer(1.0),
                        'b': tf.contrib.layers.l2_regularizer(1.0)}
    self._data_format = data_format
    self._num_classes = num_classes
    with self._enter_variable_scope():
        self._unet = UNet(num_channels=num_channels,
                          num_classes=num_classes,
                          nonlinearity=nonlinearity,
                          num_convs_per_block=num_convs_per_block,
                          initializers=initializers,
                          regularizers=regularizers,
                          data_format=data_format,
                          down_sampling_op=down_sampling_op,
                          up_sampling_op=up_sampling_op)
        # Combines U-Net features with a broadcast latent sample.
        self._f_comb = Conv1x1Decoder(num_classes=num_classes,
                                      num_1x1_convs=num_1x1_convs,
                                      num_channels=num_channels[0],
                                      nonlinearity=nonlinearity,
                                      data_format=data_format,
                                      initializers=initializers,
                                      regularizers=regularizers)
        self._prior =\
            AxisAlignedConvGaussian(latent_dim=latent_dim,
                                    num_channels=num_channels,
                                    nonlinearity=nonlinearity,
                                    num_convs_per_block=num_convs_per_block,
                                    initializers=initializers,
                                    regularizers=regularizers,
                                    name='prior')
        self._posterior =\
            AxisAlignedConvGaussian(latent_dim=latent_dim,
                                    num_channels=num_channels,
                                    nonlinearity=nonlinearity,
                                    num_convs_per_block=num_convs_per_block,
                                    initializers=initializers,
                                    regularizers=regularizers,
                                    name='posterior')
def __init__(self, num_channels, num_classes, nonlinearity=tf.nn.relu,
             num_convs_per_block=3,
             initializers=None,
             regularizers=None,
             data_format='NCHW',
             down_sampling_op=lambda x, df: tf.nn.avg_pool(
                 x, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
                 padding='SAME', data_format=df),
             up_sampling_op=lambda x, size: tf.image.resize_images(
                 x, size, method=tf.image.ResizeMethod.BILINEAR,
                 align_corners=True),
             name="unet"):
    """Assemble a deterministic U-Net from a VGG encoder and decoder.

    :param num_channels: per-scale channel counts
    :param num_classes: number of segmentation classes
    :param nonlinearity: activation used throughout the network
    :param num_convs_per_block: convolutions per encoder/decoder block
    :param initializers: dict of sonnet initializers; None selects the
        he-normal / truncated-normal defaults
    :param regularizers: dict of sonnet regularizers or None (no
        regularization)
    :param data_format: 'NCHW' or 'NHWC'
    :param down_sampling_op: callable (features, data_format) -> features
    :param up_sampling_op: callable (nhwc_features, spatial_size) -> features
    :param name: sonnet module name
    """
    super(UNet, self).__init__(name=name)
    # Build the default dict per call instead of using a shared mutable
    # default argument. `regularizers` already defaults to None.
    if initializers is None:
        initializers = {'w': he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)}
    with self._enter_variable_scope():
        tf.logging.info('Building U-Net.')
        self._encoder = VGG_Encoder(num_channels,
                                    nonlinearity,
                                    num_convs_per_block,
                                    initializers,
                                    regularizers,
                                    data_format=data_format,
                                    down_sampling_op=down_sampling_op)
        self._decoder = VGG_Decoder(num_channels,
                                    num_classes,
                                    nonlinearity,
                                    num_convs_per_block,
                                    initializers,
                                    regularizers,
                                    data_format=data_format,
                                    up_sampling_op=up_sampling_op)
def write_test_predictions(cf): """ Write samples as numpy arrays. :param cf: config module :return: """ # do not use all gpus os.environ["CUDA_VISIBLE_DEVICES"] = cf.cuda_visible_devices data_dir = os.path.join(cf.data_dir, cf.resolution) data_dict = loadFiles(label_density=cf.label_density, split='val', input_path=data_dir, cities=None, instance=False) # prepare out_dir if not os.path.isdir(cf.out_dir): os.mkdir(cf.out_dir) logging.info('Writing to {}'.format(cf.out_dir)) # initialize computation graph prob_unet = ProbUNet(latent_dim=cf.latent_dim, num_channels=cf.num_channels, num_1x1_convs=cf.num_1x1_convs, num_classes=cf.num_classes, num_convs_per_block=cf.num_convs_per_block, initializers={ 'w': training_utils.he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001) }, regularizers={ 'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0) }) x = tf.placeholder(tf.float32, shape=cf.network_input_shape) with tf.device(cf.gpu_device): prob_unet(x, is_training=False, one_hot_labels=cf.one_hot_labels) sampled_logits = prob_unet.sample() saver = tf.train.Saver(save_relative_paths=True) with tf.train.MonitoredTrainingSession() as sess: print('EXP DIR', cf.exp_dir) latest_ckpt_path = tf.train.latest_checkpoint(cf.exp_dir) print('CKPT PATH', latest_ckpt_path) saver.restore(sess, latest_ckpt_path) for k, v in tqdm(data_dict.items()): img = np.load(v['data']) / 255. # add batch dimensions img = img[np.newaxis] for i in range(cf.num_samples): sample = sess.run(sampled_logits, feed_dict={x: img}) sample = np.argmax(sample, axis=1)[:, np.newaxis] sample = sample.astype(np.uint8) sample_path = os.path.join( cf.out_dir, '{}_sample{}_labelIds.npy'.format(k, i)) np.save(sample_path, sample)
def train(cf):
    """Perform training from scratch.

    Builds the Probabilistic U-Net training graph (negative ELBO plus
    weighted L2 regularization), then runs `cf.n_training_batches` Adam steps
    with periodic validation, reconstruction/sample plotting and TensorBoard
    summaries. Checkpoints are written every `cf.save_every_n_steps` steps.

    :param cf: config module
    """
    # do not use all gpus
    os.environ["CUDA_VISIBLE_DEVICES"] = cf.cuda_visible_devices
    # initialize data providers
    data_provider = get_train_generators(cf)
    train_provider = data_provider['train']
    val_provider = data_provider['val']
    prob_unet = ProbUNet(
        latent_dim=cf.latent_dim,
        num_channels=cf.num_channels,
        num_1x1_convs=cf.num_1x1_convs,
        num_classes=cf.num_classes,
        num_convs_per_block=cf.num_convs_per_block,
        initializers={
            'w': training_utils.he_normal(),
            'b': tf.truncated_normal_initializer(stddev=0.001)
        },
        regularizers={'w': tf.contrib.layers.l2_regularizer(1.0)})
    x = tf.placeholder(tf.float32, shape=cf.network_input_shape)
    y = tf.placeholder(tf.uint8, shape=cf.label_shape)
    mask = tf.placeholder(tf.uint8, shape=cf.loss_mask_shape)
    global_step = tf.train.get_or_create_global_step()
    if cf.learning_rate_schedule == 'piecewise_constant':
        learning_rate = tf.train.piecewise_constant(x=global_step,
                                                    **cf.learning_rate_kwargs)
    else:
        learning_rate = tf.train.exponential_decay(
            learning_rate=cf.initial_learning_rate,
            global_step=global_step,
            **cf.learning_rate_kwargs)
    with tf.device(cf.gpu_device):
        # is_training=True builds prior AND posterior nets
        prob_unet(x, y, is_training=True, one_hot_labels=cf.one_hot_labels)
        elbo = prob_unet.elbo(y,
                              reconstruct_posterior_mean=cf.use_posterior_mean,
                              beta=cf.beta,
                              loss_mask=mask,
                              analytic_kl=cf.analytic_kl,
                              one_hot_labels=cf.one_hot_labels)
        reconstructed_logits = prob_unet._rec_logits
        sampled_logits = prob_unet.sample()
        # NOTE(review): 'regularizarion_weight' is misspelled but must match
        # the attribute name used in the config files — do not 'fix' here alone.
        reg_loss = cf.regularizarion_weight * tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        # minimize the negative ELBO plus weighted regularization
        loss = -elbo + reg_loss
        rec_loss = prob_unet._rec_loss_mean
        kl = prob_unet._kl
    # fed with running means accumulated over the validation batches
    mean_val_rec_loss = tf.placeholder(tf.float32,
                                       shape=(),
                                       name="mean_val_rec_loss")
    mean_val_kl = tf.placeholder(tf.float32, shape=(), name="mean_val_kl")
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        loss, global_step=global_step)
    # prepare tf summaries
    train_elbo_summary = tf.summary.scalar('train_elbo', elbo)
    train_kl_summary = tf.summary.scalar('train_kl', kl)
    train_rec_loss_summary = tf.summary.scalar('rec_loss', rec_loss)
    train_loss_summary = tf.summary.scalar('train_loss', loss)
    reg_loss_summary = tf.summary.scalar('train_reg_loss', reg_loss)
    lr_summary = tf.summary.scalar('learning_rate', learning_rate)
    beta_summary = tf.summary.scalar('beta', cf.beta)
    training_summary_op = tf.summary.merge([
        train_loss_summary, reg_loss_summary, lr_summary, train_elbo_summary,
        train_kl_summary, train_rec_loss_summary, beta_summary
    ])
    batches_per_second = tf.placeholder(tf.float32,
                                        shape=(),
                                        name="batches_per_sec_placeholder")
    timing_summary = tf.summary.scalar('batches_per_sec', batches_per_second)
    val_rec_loss_summary = tf.summary.scalar('val_loss', mean_val_rec_loss)
    val_kl_summary = tf.summary.scalar('val_kl', mean_val_kl)
    validation_summary_op = tf.summary.merge(
        [val_rec_loss_summary, val_kl_summary])
    # NOTE(review): the init op is created but never run — the
    # MonitoredTrainingSession below handles initialization itself.
    tf.global_variables_initializer()
    # Add ops to save and restore all the variables.
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir=cf.exp_dir,
        save_steps=cf.save_every_n_steps,
        saver=tf.train.Saver(save_relative_paths=True))
    # save config
    shutil.copyfile(cf.config_path, os.path.join(cf.exp_dir, 'used_config.py'))
    with tf.train.MonitoredTrainingSession(hooks=[saver_hook]) as sess:
        summary_writer = tf.summary.FileWriter(cf.exp_dir, sess.graph)
        logging.info('Model: {}'.format(cf.exp_dir))
        for i in tqdm(range(cf.n_training_batches),
                      disable=cf.disable_progress_bar):
            start_time = time.time()
            train_batch = next(train_provider)
            _, train_summary = sess.run(
                [optimizer, training_summary_op],
                feed_dict={
                    x: train_batch['data'],
                    y: train_batch['seg'],
                    mask: train_batch['loss_mask']
                })
            summary_writer.add_summary(train_summary, i)
            time_delta = time.time() - start_time
            train_speed = sess.run(
                timing_summary,
                feed_dict={batches_per_second: 1. / time_delta})
            summary_writer.add_summary(train_speed, i)
            # validation
            if i % cf.validation['every_n_batches'] == 0:
                # plot posterior reconstructions of the current train batch
                train_rec = sess.run(reconstructed_logits,
                                     feed_dict={
                                         x: train_batch['data'],
                                         y: train_batch['seg']
                                     })
                image_path = os.path.join(
                    cf.exp_dir,
                    'batch_{}_train_reconstructions.png'.format(
                        i // cf.validation['every_n_batches']))
                training_utils.plot_batch(train_batch,
                                          train_rec,
                                          num_classes=cf.num_classes,
                                          cmap=cf.color_map,
                                          out_dir=image_path)
                running_mean_val_rec_loss = 0.
                running_mean_val_kl = 0.
                for j in range(cf.validation['n_batches']):
                    val_batch = next(val_provider)
                    val_rec, val_sample, val_rec_loss, val_kl =\
                        sess.run([reconstructed_logits, sampled_logits,
                                  rec_loss, kl],
                                 feed_dict={x: val_batch['data'],
                                            y: val_batch['seg'],
                                            mask: val_batch['loss_mask']})
                    # accumulate running means over the validation batches
                    running_mean_val_rec_loss += val_rec_loss / cf.validation[
                        'n_batches']
                    running_mean_val_kl += val_kl / cf.validation['n_batches']
                    # plot reconstructions and samples of the first val batch
                    if j == 0:
                        image_path = os.path.join(
                            cf.exp_dir,
                            'batch_{}_val_reconstructions.png'.format(
                                i // cf.validation['every_n_batches']))
                        training_utils.plot_batch(val_batch,
                                                  val_rec,
                                                  num_classes=cf.num_classes,
                                                  cmap=cf.color_map,
                                                  out_dir=image_path)
                        image_path = os.path.join(
                            cf.exp_dir,
                            'batch_{}_val_samples.png'.format(
                                i // cf.validation['every_n_batches']))
                        # draw 3 more samples and stack them along the
                        # channel axis for plotting
                        for _ in range(3):
                            val_sample_ = sess.run(sampled_logits,
                                                   feed_dict={
                                                       x: val_batch['data'],
                                                       y: val_batch['seg']
                                                   })
                            val_sample = np.concatenate(
                                [val_sample, val_sample_], axis=1)
                        training_utils.plot_batch(val_batch,
                                                  val_sample,
                                                  num_classes=cf.num_classes,
                                                  cmap=cf.color_map,
                                                  out_dir=image_path)
                val_summary = sess.run(validation_summary_op,
                                       feed_dict={
                                           mean_val_rec_loss:
                                           running_mean_val_rec_loss,
                                           mean_val_kl: running_mean_val_kl
                                       })
                summary_writer.add_summary(val_summary, i)
                if cf.disable_progress_bar:
                    logging.info('Evaluating epoch {}/{}: validation loss={}, kl={}'\
                                 .format(i, cf.n_training_batches,
                                         running_mean_val_rec_loss,
                                         running_mean_val_kl))
                sess.run(global_step)
def sample(cf, args):
    """Sampling from the learnt conditional distribution.

    Restores the PRL model checkpointed under `args.time_stamp`, draws
    `args.sample_size` restored images for every validation image and writes
    the samples plus the squeezed noisy/clean validation data to the sample
    directory as .npy files.

    :param cf: config module
    :param args: parsed command-line args providing `sample_size` and
        `time_stamp`
    """
    sample_size = args.sample_size
    time_stamp = args.time_stamp
    ckpt_dir = os.path.join(cf.project_dir, 'experiments', time_stamp)
    sample_dir = os.path.join(cf.project_dir, 'samples', time_stamp)
    if not os.path.exists(sample_dir):
        os.mkdir(sample_dir)
    log_path = os.path.join(sample_dir, 'sampling_stat.log')
    logging.basicConfig(filename=log_path, level=logging.DEBUG, filemode='a')
    prl_dncnn = PRL(latent_dim=cf.latent_dim,
                    output_channels=cf.output_channels,
                    num_channels=cf.num_channels,
                    det_net_depth=cf.det_net_depth,
                    merging_depth=cf.merging_depth,
                    num_convs_per_block=cf.num_convs_per_block,
                    initializers={
                        'w': training_utils.he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)
                    },
                    regularizers={'w': tf.contrib.layers.l2_regularizer(1.0)},
                    data_format=cf.data_format,
                    name='prl_dncnn')
    x = tf.placeholder(tf.float32,
                       shape=cf.network_input_shape,
                       name='observation')
    y = tf.placeholder(tf.float32,
                       shape=cf.network_output_shape,
                       name='ground_truth')
    is_training = tf.placeholder(tf.bool)
    # build the graph in inference mode
    prl_dncnn(x, y, is_training, is_inference=True)
    sampled_imgs = prl_dncnn.inference_sample(x)
    saver = tf.train.Saver()
    [val_data_noisy_list, val_data_clean_list] = test_data_list(
        img_dir=cf.validation_data_dir,
        noise_type=cf.noise_type,
        noise_param=cf.noise_param,
        data_format=cf.data_format)
    num_data = len(val_data_clean_list)
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        for i in tqdm(range(num_data)):
            restored_samples = []
            sampling_start_time = time.time()
            for j in range(sample_size):
                smpl_img = sess.run(sampled_imgs,
                                    feed_dict={
                                        x: val_data_noisy_list[i],
                                        is_training: False
                                    })
                restored_samples.append(smpl_img)
            sampling_time_delta = time.time() - sampling_start_time
            restored_samples = np.asarray(restored_samples)
            # drop the singleton batch and channel axes of each sample
            if cf.data_format == 'NCHW':
                restored_samples = np.squeeze(restored_samples, axis=(1, 2))
            else:
                restored_samples = np.squeeze(restored_samples, axis=(1, 4))
            save_path = os.path.join(
                sample_dir,
                '{}_img{}_t{}_s{}.npy'.format(cf.validation_data_name, i,
                                              time_stamp, sample_size))
            np.save(save_path, restored_samples)
            logging.info(
                '{}s used for image {} of sample size {}, average time: {} s/sample'
                .format(sampling_time_delta, i, sample_size,
                        sampling_time_delta / sample_size))
    # also store the (squeezed) noisy/clean validation images for reference
    val_noisy_path = os.path.join(
        sample_dir, '{}_val_noisy.npy'.format(cf.validation_data_name))
    val_clean_path = os.path.join(
        sample_dir, '{}_val_clean.npy'.format(cf.validation_data_name))
    for i in range(len(val_data_noisy_list)):
        if cf.data_format == 'NCHW':
            val_data_noisy_list[i] = np.squeeze(val_data_noisy_list[i],
                                                axis=(0, 1))
            val_data_clean_list[i] = np.squeeze(val_data_clean_list[i],
                                                axis=(0, 1))
        else:
            val_data_noisy_list[i] = np.squeeze(val_data_noisy_list[i],
                                                axis=(0, 3))
            val_data_clean_list[i] = np.squeeze(val_data_clean_list[i],
                                                axis=(0, 3))
    np.save(val_noisy_path, np.asarray(val_data_noisy_list))
    np.save(val_clean_path, np.asarray(val_data_clean_list))
def train(cf):
    """Perform training of PRL with DnCNN from scratch.

    Builds the PRL training graph (model loss plus weighted L2
    regularization), then runs `cf.num_training_batches` Adam steps with a
    full-validation pass at the start of every epoch, image dumps and
    TensorBoard summaries.

    :param cf: config module
    """
    if cf.use_single_gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = cf.cuda_visible_devices
    train_dataset = train_generator(img_dir=cf.training_data_dir,
                                    data_format=cf.data_format,
                                    shuffle_every_n_epochs=cf.shuffle_every_n_epochs,
                                    batch_size=cf.batch_size,
                                    noise_type=cf.noise_type,
                                    noise_param=cf.noise_param)
    prl_dncnn = PRL(latent_dim=cf.latent_dim,
                    output_channels=cf.output_channels,
                    num_channels=cf.num_channels,
                    det_net_depth=cf.det_net_depth,
                    merging_depth=cf.merging_depth,
                    num_convs_per_block=cf.num_convs_per_block,
                    initializers={'w': training_utils.he_normal(),
                                  'b': tf.truncated_normal_initializer(stddev=0.001)},
                    regularizers={'w': tf.contrib.layers.l2_regularizer(1.0)},
                    data_format=cf.data_format,
                    name='prl_dncnn')
    x = tf.placeholder(tf.float32, shape=cf.network_input_shape,
                       name='observation')
    y = tf.placeholder(tf.float32, shape=cf.network_output_shape,
                       name='ground_truth')
    # KL weight, fed per step so schedules are possible
    beta = tf.placeholder(tf.float32, shape=(), name='beta')
    is_training = tf.placeholder(tf.bool)
    global_step = tf.train.get_or_create_global_step()
    if cf.learning_rate_schedule == 'piecewise_constant':
        learning_rate = tf.train.piecewise_constant(x=global_step,
                                                    **cf.learning_rate_kwargs)
    else:
        learning_rate = tf.train.exponential_decay(learning_rate=cf.initial_learning_rate,
                                                   global_step=global_step,
                                                   **cf.learning_rate_kwargs)
    prl_dncnn(x, y, is_training, is_inference=False)
    model_loss = prl_dncnn.loss_mini_batch(x, y, beta=beta,
                                           analytic_kl=cf.analytic_kl,
                                           use_ref_mean=cf.use_ref_mean)
    reg_loss = cf.regularization_weight * tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    loss = model_loss + reg_loss
    # prepare for summaries
    ref_rec_loss = prl_dncnn._rec_loss
    kl = prl_dncnn._kl_val
    ref_sample = prl_dncnn._ref_sample
    inf_sample = prl_dncnn.inference_sample(x)
    # map network outputs back to image value range for metrics/plots
    gt_unnormed = training_utils.img_unnorm(y, cf.noise_type, cf.noise_param)
    ref_unnormed = training_utils.img_unnorm(ref_sample, cf.noise_type,
                                             cf.noise_param)
    inf_unnormed = training_utils.img_unnorm(inf_sample, cf.noise_type,
                                             cf.noise_param)
    if cf.data_format == 'NCHW':
        # tf.image.ssim/psnr expect NHWC
        gt_unnormed = tf.transpose(gt_unnormed, perm=(0, 2, 3, 1))
        ref_unnormed = tf.transpose(ref_unnormed, perm=(0, 2, 3, 1))
        inf_unnormed = tf.transpose(inf_unnormed, perm=(0, 2, 3, 1))
    # can be used for both training and validation
    # ref_avg_mse = tf.reduce_mean(tf.metrics.mean_squared_error(gt_unnormed, ref_unnormed))
    # ref_avg_ssim = tf.reduce_mean(tf.image.ssim(gt_unnormed, ref_unnormed, max_val=cf.peak_val))
    # ref_avg_psnr = tf.reduce_mean(tf.image.psnr(gt_unnormed, ref_unnormed, max_val=cf.peak_val))
    inf_avg_mse = tf.reduce_mean(tf.metrics.mean_squared_error(gt_unnormed,
                                                               inf_unnormed))
    inf_avg_ssim = tf.reduce_mean(tf.image.ssim(gt_unnormed, inf_unnormed,
                                                max_val=cf.peak_val))
    inf_avg_psnr = tf.reduce_mean(tf.image.psnr(gt_unnormed, inf_unnormed,
                                                max_val=cf.peak_val))
    # ---------------------------- training summaries
    # loss summaries
    train_ref_rec_loss_summary = tf.summary.scalar('train_ref_rec_loss',
                                                   ref_rec_loss)
    train_kl_summary = tf.summary.scalar('train_kl', kl)
    train_model_loss_summary = tf.summary.scalar('train_elbo', model_loss)
    train_reg_loss_summary = tf.summary.scalar('train_reg_loss', reg_loss)
    train_loss_summary = tf.summary.scalar('train_loss', loss)
    # quantitative indicator summaries (not in use during training, to save time)
    # train_ref_avg_mse_summary = tf.summary.scalar('train_reference_avg_mse', ref_avg_mse)
    # train_ref_avg_ssim_summary = tf.summary.scalar('train_reference_avg_ssim', ref_avg_ssim)
    # train_ref_avg_psnr_summary = tf.summary.scalar('train_reference_avg_psnr', ref_avg_psnr)
    # train_inf_avg_mse_summary = tf.summary.scalar('train_reference_avg_mse', inf_avg_mse)
    # train_inf_avg_ssim_summary = tf.summary.scalar('train_reference_avg_ssim', inf_avg_ssim)
    # train_inf_avg_psnr_summary = tf.summary.scalar('train_reference_avg_psnr', inf_avg_psnr)
    # hyper-parameter summaries
    lr_summary = tf.summary.scalar('learning_rate', learning_rate)
    beta_summary = tf.summary.scalar('kl_beta', beta)
    # merging summaries
    train_summary_op = tf.summary.merge([lr_summary, beta_summary,
                                         train_loss_summary,
                                         train_model_loss_summary,
                                         train_reg_loss_summary,
                                         train_ref_rec_loss_summary,
                                         train_kl_summary])
    # ---------------------------- timing summaries
    batches_per_second = tf.placeholder(tf.float32, shape=(),
                                        name='batches_per_second')
    timing_summary = tf.summary.scalar('batches_per_sec', batches_per_second)
    # ---------------------------- validation summaries
    val_avg_ref_rec_loss = tf.placeholder(tf.float32, shape=(),
                                          name='mean_val_ref_rec_loss')
    val_avg_kl = tf.placeholder(tf.float32, shape=(), name='mean_val_kl')
    # val_avg_ref_mse = tf.placeholder(tf.float32, shape=(), name='mean_val_ref_mse')
    # val_avg_ref_ssim = tf.placeholder(tf.float32, shape=(), name='mean_val_ref_ssim')
    # val_avg_ref_psnr = tf.placeholder(tf.float32, shape=(), name='mean_val_ref_psnr')
    val_avg_inf_mse = tf.placeholder(tf.float32, shape=(),
                                     name='mean_val_inf_mse')
    val_avg_inf_ssim = tf.placeholder(tf.float32, shape=(),
                                      name='mean_val_inf_ssim')
    val_avg_inf_psnr = tf.placeholder(tf.float32, shape=(),
                                      name='mean_val_inf_psnr')
    val_ref_rec_loss_summary = tf.summary.scalar('validation_ref_rec_loss',
                                                 val_avg_ref_rec_loss)
    # NOTE(review): 'valiation_kl' is misspelled; kept because renaming the
    # tag would break comparisons against existing TensorBoard logs.
    val_kl_summary = tf.summary.scalar('valiation_kl', val_avg_kl)
    val_avg_inf_mse_summary = tf.summary.scalar('validation_avg_inf_mse',
                                                val_avg_inf_mse)
    val_avg_inf_ssim_summary = tf.summary.scalar('validation_avg_inf_ssim',
                                                 val_avg_inf_ssim)
    val_avg_inf_psnr_summary = tf.summary.scalar('validation_avg_inf_psnr',
                                                 val_avg_inf_psnr)
    validation_summary_op = tf.summary.merge([val_ref_rec_loss_summary,
                                              val_kl_summary,
                                              val_avg_inf_mse_summary,
                                              val_avg_inf_ssim_summary,
                                              val_avg_inf_psnr_summary])
    # run collected update ops (e.g. batch-norm moving averages) before each
    # optimizer step
    update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_op):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
            loss, global_step=global_step)
    # NOTE(review): the init op is created but never run — the
    # MonitoredTrainingSession below handles initialization itself.
    tf.global_variables_initializer()
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir=cf.experiment_dir,
        save_steps=cf.save_every_n_steps,
        saver=tf.train.Saver(save_relative_paths=True))
    # save the config next to the checkpoints for reproducibility
    shutil.copyfile(cf.config_path,
                    os.path.join(cf.experiment_dir, 'used_config.py'))
    with tf.train.MonitoredTrainingSession(hooks=[saver_hook]) as sess:
        summary_writer = tf.summary.FileWriter(cf.experiment_dir, sess.graph)
        logging.info('Model: {}'.format(cf.experiment_dir))
        training_start_time = time.time()
        for i in tqdm(range(cf.num_training_batches),
                      disable=cf.disable_progress_bar):
            sess_start_time = time.time()
            [train_data_noisy, train_data_clean] = next(train_dataset)
            _, train_summary = sess.run([optimizer, train_summary_op],
                                        feed_dict={x: train_data_noisy,
                                                   y: train_data_clean,
                                                   beta: cf.kl_weight,
                                                   is_training: True})
            summary_writer.add_summary(train_summary, i)
            sess_time_delta = time.time() - sess_start_time
            train_speed = sess.run(timing_summary,
                                   feed_dict={batches_per_second: 1. / sess_time_delta})
            summary_writer.add_summary(train_speed, i)
            # full validation pass at the start of every epoch
            if i % cf.batches_per_epoch == 0:
                running_avg_val_ref_rec_loss = 0.
                running_avg_val_kl = 0.
                running_avg_val_inf_mse = 0.
                running_avg_val_inf_ssim = 0.
                running_avg_val_inf_psnr = 0.
                [val_data_noisy_list, val_data_clean_list] = test_data_list(
                    img_dir=cf.validation_data_dir,
                    noise_type=cf.noise_type,
                    noise_param=cf.noise_param,
                    data_format=cf.data_format)
                num_val_data = len(val_data_clean_list)
                val_ref_img_list = []
                val_inf_img_list = []
                for j in range(num_val_data):
                    val_ref_img, val_inf_img, val_ref_rec_loss, val_kl, val_inf_mse, val_inf_ssim, val_inf_psnr = \
                        sess.run([ref_unnormed, inf_unnormed, ref_rec_loss,
                                  kl, inf_avg_mse, inf_avg_ssim, inf_avg_psnr],
                                 feed_dict={x: val_data_noisy_list[j],
                                            y: val_data_clean_list[j],
                                            beta: cf.kl_weight,
                                            is_training: False})
                    # accumulate running means over the validation images
                    running_avg_val_ref_rec_loss += val_ref_rec_loss / num_val_data
                    running_avg_val_kl += val_kl / num_val_data
                    running_avg_val_inf_mse += val_inf_mse / num_val_data
                    running_avg_val_inf_ssim += val_inf_ssim / num_val_data
                    running_avg_val_inf_psnr += val_inf_psnr / num_val_data
                    val_ref_img_list.append(val_ref_img)
                    val_inf_img_list.append(val_inf_img)
                image_path = os.path.join(
                    cf.experiment_image_dir,
                    'epoch_{}_val_samples.png'.format(i // cf.batches_per_epoch))
                training_utils.save_sample_img(val_data_clean_list,
                                               val_ref_img_list,
                                               val_inf_img_list,
                                               img_path=image_path,
                                               noise_type=cf.noise_type,
                                               noise_param=cf.noise_param,
                                               colormap=cf.colormap)
                val_summary = sess.run(validation_summary_op,
                                       feed_dict={val_avg_ref_rec_loss: running_avg_val_ref_rec_loss,
                                                  val_avg_kl: running_avg_val_kl,
                                                  val_avg_inf_mse: running_avg_val_inf_mse,
                                                  val_avg_inf_ssim: running_avg_val_inf_ssim,
                                                  val_avg_inf_psnr: running_avg_val_inf_psnr})
                summary_writer.add_summary(val_summary, i)
                if cf.disable_progress_bar:
                    logging.info('Evaluating epoch {}/{}: validation loss={}, kl={}'
                                 .format(i, cf.num_training_batches,
                                         running_avg_val_ref_rec_loss,
                                         running_avg_val_kl))
                sess.run(global_step)
        training_time_delta = time.time() - training_start_time
        logging.info('Total training time (with running time validations) is: %f' % training_time_delta)