def validation_step(model, data, losses):
    """Run one validation forward pass on a batch and update validation metrics.

    NOTE(review): a later ``def validation_step(model, clean_image, noisy_image, losses)``
    in this module rebinds the same name, so this version cannot be reached by name
    once the module has finished loading. Consider deleting or renaming one of them.
    This version now returns the denoised image so both variants share a contract.

    Parameters:
        model: Callable that, given the noisy batch, returns a pair
            ``(dnet_output, snet_output)``; only ``dnet_output`` is used here.
        data: Indexable batch where ``data[0]`` is the clean image and
            ``data[1]`` is the noisy image.
        losses: Mapping holding ``'validation_mse'``, ``'validation_ssim'`` and
            ``'validation_psnr'`` callables, each called with one metric value.

    Returns:
        The denoised image, clipped to the ``[0, 1]`` range.
    """
    clean_image = data[0]
    noisy_image = data[1]
    dnet_output, _ = model(noisy_image)
    # The first three entries of dnet_output's last axis (presumably RGB channels)
    # are subtracted from the noisy input; clip so values match the (0, 1) input range.
    denoised_image = tf.clip_by_value(noisy_image - dnet_output[:, :, :, :3], 0, 1)
    losses['validation_mse'](mse_function(denoised_image, clean_image))
    losses['validation_ssim'](ssim_function(denoised_image, clean_image))
    losses['validation_psnr'](psnr_function(denoised_image, clean_image))
    return denoised_image
def train_step(model, optimizer, data, losses, clip_norms, radious):
    """Run one optimisation step and update the training metrics.

    The model is run under a gradient tape. The loss comes from ``loss_function``,
    and the gradients are clipped by ``clip_gradients`` and applied by ``optimizer``.
    Training metrics are then computed on the denoised image from this same
    forward pass, i.e. before the weight update takes effect.

    Parameters:
        model: Callable returning ``(dnet_output, snet_output)`` for the noisy batch.
        optimizer: Object exposing ``apply_gradients`` over (gradient, variable) pairs.
        data: Indexable batch: ``data[0]`` clean image, ``data[1]`` noisy image,
            ``data[2]`` sigma, ``data[3]`` epsilon (both forwarded to ``loss_function``).
        losses: Mapping with ``'train_loss'``, ``'train_mse'``, ``'train_psnr'``
            and ``'train_ssim'`` callables, each called with one value.
        clip_norms: Forwarded unchanged to ``clip_gradients``.
        radious: Forwarded to ``loss_function`` as ``radius``.

    Returns:
        The second and fourth values returned by ``clip_gradients``
        (presumably the dnet and snet gradient norms).
    """
    clean_image = data[0]
    noisy_image = data[1]
    sigma, epsilon = data[2], data[3]

    with tf.GradientTape() as tape:
        dnet_output, snet_output = model(noisy_image)
        # loss_function yields eight values; only the final one (the loss) is used.
        (
            _log_alpha,
            _log_beta,
            _mean,
            _m2,
            _likelihood,
            _gaussian_kl,
            _inverse_gamma_kl,
            total_loss,
        ) = loss_function(
            dnet_output=dnet_output,
            snet_output=snet_output,
            noisy_image=noisy_image,
            clean_image=clean_image,
            sigma=sigma,
            epsilon=epsilon,
            radius=radious,
        )

    grads = tape.gradient(total_loss, model.trainable_variables)
    dnet_grads, dnet_norm, snet_grads, snet_norm = clip_gradients(grads, clip_norms)
    # NOTE(review): this pairing assumes clip_gradients returns the dnet gradients
    # followed by the snet gradients, in the same order as model.trainable_variables.
    merged_grads = dnet_grads + snet_grads
    optimizer.apply_gradients(zip(merged_grads, model.trainable_variables))

    # Residual formulation: subtract the first three entries of the last axis of
    # dnet_output from the noisy input, then clip to the (0, 1) input range.
    residual = dnet_output[:, :, :, :3]
    denoised_image = tf.clip_by_value(noisy_image - residual, 0, 1)

    updates = (
        ('train_loss', total_loss),
        ('train_mse', mse_function(denoised_image, clean_image)),
        ('train_psnr', psnr_function(denoised_image, clean_image)),
        ('train_ssim', ssim_function(denoised_image, clean_image)),
    )
    for metric_key, metric_value in updates:
        losses[metric_key](metric_value)

    return dnet_norm, snet_norm
def validation_step(model, clean_image, noisy_image, losses):
    """Denoise one validation batch and record its validation metrics.

    Runs ``model`` on the noisy batch. The first three entries of the last axis
    of the first model output are subtracted from the noisy input, and the result
    is clipped to ``[0, 1]``. MSE, SSIM and PSNR against the clean batch are then
    passed to the matching callables in ``losses``.

    Parameters:
        model: Callable that returns a pair ``(dnet_output, snet_output)`` for the
            noisy batch; only ``dnet_output`` is used.
        clean_image: Reference (clean) image batch, presumably floating point.
        noisy_image: Noisy input batch, presumably floating point in ``[0, 1]``.
        losses: Mapping holding ``'validation_mse'``, ``'validation_ssim'`` and
            ``'validation_psnr'`` callables, each called with one metric value.

    Returns:
        The clipped denoised image.
    """
    dnet_output, _unused_snet = model(noisy_image)
    predicted_residual = dnet_output[:, :, :, :3]
    # Clip so the restored values stay inside the (0, 1) range of the input.
    restored = tf.clip_by_value(noisy_image - predicted_residual, 0, 1)

    losses['validation_mse'](mse_function(restored, clean_image))
    losses['validation_ssim'](ssim_function(restored, clean_image))
    losses['validation_psnr'](psnr_function(restored, clean_image))
    return restored