def get_image_path_from(source_dir, target_dir):
    """Return the image paths listed for ``source_dir`` and ``target_dir``.

    Each directory is passed on its own to ``extract_image_path``. The two
    resulting listings are returned as a ``(source_images, target_images)``
    tuple.

    Raises:
        AssertionError: if the two listings differ in length (this check is
            skipped when Python runs with ``-O``).
    """
    listings = tuple(
        extract_image_path([folder]) for folder in (source_dir, target_dir)
    )
    source_paths, target_paths = listings
    # The message is only formatted if the check fails, same as a plain assert.
    assert len(source_paths) == len(target_paths), (
        "Number of images in %r is not the same as %r" % (source_dir, target_dir)
    )
    return listings
def train(self, x_path_dir, y_path_dir, epochs, train_steps, learning_rate,
          epochs_to_reduce_lr, reduce_lr, output_model, output_log, b_size):
    """Train the network on paired input / target images.

    Args:
        x_path_dir: Directory of input images (fed to ``self.X``).
        y_path_dir: Directory of target images (fed to ``self.Y_clear``).
        epochs: Number of training epochs.
        train_steps: Optimizer steps per epoch.
        learning_rate: Initial learning rate.
        epochs_to_reduce_lr: The learning rate is decayed every this many
            epochs. A falsy value (0 / None) disables the decay.
        reduce_lr: Fraction removed from the learning rate at each decay,
            i.e. ``l_rate *= 1 - reduce_lr``.
        output_model: Checkpoint path prefix. "AE" or "MultiCNN" is appended.
        output_log: Log directory for the TensorBoard ``FileWriter``.
        b_size: Batch size.

    A checkpoint is saved at the end of every epoch. On ``KeyboardInterrupt``
    one more checkpoint (without a global step) is saved and training stops.
    The summary writer is always closed, flushing pending events, on exit.
    """
    # Suffix the checkpoint prefix to make the model type explicit.
    # NOTE(review): the branch keys off the truthiness of ``output_model``
    # itself, so every non-empty prefix gets "AE" and only "" gets
    # "MultiCNN". This looks like it was meant to test a model-type flag.
    # Confirm the intent.
    output_model += "AE" if output_model else "MultiCNN"
    check_and_create_dir(output_model)

    # Load data
    x_filenames = extract_image_path([x_path_dir])
    y_filenames = extract_image_path([y_path_dir])

    # TensorBoard summaries
    tf.summary.scalar('Learning rate', self.learning_rate)
    tf.summary.scalar('MSE', self.mse)
    tf.summary.scalar('MS SSIM', self.ssim)
    tf.summary.scalar('Loss', self.cost)
    tf.summary.image('BSE', self.Y)
    tf.summary.image('Ground truth', self.Y_clear)
    merged = tf.summary.merge_all()

    sess, saver = self.init_session()
    writer = tf.summary.FileWriter(output_log, sess.graph)

    l_rate = learning_rate
    try:
        for epoch_i in range(epochs):
            # Step decay of the learning rate. Guarded so that a period of 0
            # does not raise ZeroDivisionError.
            if epochs_to_reduce_lr and (epoch_i + 1) % epochs_to_reduce_lr == 0:
                l_rate *= 1 - reduce_lr
            if self.verbose:
                print("\n------------ Epoch : ", epoch_i + 1)
                print("Current learning rate {}".format(l_rate))

            # Training steps
            for i in range(train_steps):
                if self.verbose:
                    print_train_steps(i + 1, train_steps)
                x_batch, y_batch = get_batch(b_size, self.image_size,
                                             x_filenames, y_filenames)
                # Same feed is used for the optimizer step and the summary.
                feed = {
                    self.X: x_batch,
                    self.Y_clear: y_batch,
                    self.learning_rate: l_rate,
                    self.batch_size: b_size,
                }
                sess.run(self.optimizer, feed_dict=feed)
                if i % 50 == 0:
                    summary = sess.run(merged, feed_dict=feed)
                    writer.add_summary(summary, i + epoch_i * train_steps)

            if self.verbose:
                print("\nSave model to {}".format(output_model))
            saver.save(sess, output_model,
                       global_step=(epoch_i + 1) * train_steps)
    except KeyboardInterrupt:
        saver.save(sess, output_model)
    finally:
        # Without close() buffered summary events may never reach disk.
        writer.close()
def test(self, X_path, Y_path, save_output=False):
    """Evaluate the model on new inputs against their ground truth.

    Each input image is run through the network one at a time and compared
    pixel-wise with the ground-truth image at the same index.

    Args:
        X_path: Directory of input images.
        Y_path: Directory of ground-truth images. It must yield the same
            number of images as ``X_path``, paired by index.
            NOTE(review): this assumes ``extract_image_path`` lists both
            directories in matching order. Confirm.
        save_output: Falsy to skip writing images. Otherwise it is a path
            prefix: each prediction is written to ``<prefix><i>_true.png``,
            and the combined maps go to ``<prefix>comb_MSE.png`` and
            ``<prefix>comb_SignedError.png``.

    Returns:
        dict with:
            "cost": mean over images of ``alpha * SSIM + (1 - alpha) * MSE``.
            "MSE": mean per-image mean squared error.
            "MSE_IMAGE": per-pixel squared error averaged over images.
            "ABS_ERROR_IMAGE": per-pixel absolute error averaged over images.
            "SIGNED_ERROR_IMAGE": per-pixel ``truth - prediction`` averaged
                over images.

    NOTE(review): SSIM is not computed (the TF call was disabled), so it
    stays 0 and "cost" reduces to ``(1 - alpha) * MSE``.
    """
    # Paired path lists. get_image_path_from asserts the counts match.
    X, Y = get_image_path_from(X_path, Y_path)
    n_images = len(X)
    sess, _ = self.init_session()

    # Scalar errors for each image
    MSE = np.zeros(n_images)
    SSIM = np.zeros(n_images)  # never filled in, see docstring
    cost = np.zeros(n_images)

    # Per-pixel error maps accumulated over all images
    shape = (self.image_size, self.image_size)
    total_signed_error_image = np.zeros(shape)
    total_abs_error_image = np.zeros(shape)
    total_MSE_image = np.zeros(shape)

    # Predict one image at a time
    for i, (x_file, y_file) in enumerate(zip(X, Y)):
        input_image = extract_n_normalize_image(x_file)
        # Network input is (batch, height, width, channels) with 1 channel.
        x_image = np.reshape(np.array([input_image]),
                             (1, self.image_size, self.image_size, 1))
        output_image = sess.run(self.Y, feed_dict={self.X: x_image})
        y_truth = extract_n_normalize_image(y_file)
        # Drop the trivial batch/channel dimensions from the prediction.
        y_pred = np.reshape(np.array([output_image]), shape)

        error_image = y_truth - y_pred
        squared_error_image = np.square(error_image)
        total_signed_error_image += error_image
        total_abs_error_image += np.absolute(error_image)
        total_MSE_image += squared_error_image

        MSE[i] = np.average(squared_error_image)
        if self.verbose:
            print(MSE[i])
        cost[i] = self.alpha * SSIM[i] + (1 - self.alpha) * MSE[i]

        if save_output:
            # NOTE(review): this writes the *prediction* under a "_true"
            # file name. The name is kept for compatibility with existing
            # output.
            imsave(save_output + str(i) + "_true.png", y_pred)

    if save_output:
        imsave(save_output + "comb_MSE.png", total_MSE_image)
        imsave(save_output + "comb_SignedError.png",
               total_signed_error_image / n_images)

    return {
        "cost": np.average(cost),
        "MSE": np.average(MSE),
        "MSE_IMAGE": total_MSE_image / n_images,
        "ABS_ERROR_IMAGE": total_abs_error_image / n_images,
        "SIGNED_ERROR_IMAGE": total_signed_error_image / n_images,
    }