def _train_step(self, x, y, clw):
    """Run one two-phase training step on a batch of image patches.

    Phase 1 updates only the two encoders with the kernel-alignment
    loss; phase 2 updates encoders and decoders jointly with the
    cycle/cross/reconstruction losses. All running metrics are
    updated at the end.

    Input:
        x   - tensor of shape (bs, ps_h, ps_w, c_x)
        y   - tensor of shape (bs, ps_h, ps_w, c_y)
        clw - cross_loss_weight, tensor of shape (bs, ps_h, ps_w, 1)
    """
    # --- Phase 1: kernel-alignment update (encoders only) ---
    with tf.GradientTape() as tape:
        x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, ztz = self(
            [x, y], training=True)
        # Target affinity matrix computed on 20x20 sub-patches of the
        # raw inputs; ztz is the code-space affinity to align with it.
        Kern = 1.0 - Degree_matrix(
            image_in_patches(x, 20), image_in_patches(y, 20))
        kernels_loss = self.kernels_lambda * self.loss_object(Kern, ztz)
        # Regularization losses from the two encoders only.
        l2_loss_k = sum(self._enc_x.losses) + sum(self._enc_y.losses)
        targets_k = (
            self._enc_x.trainable_variables
            + self._enc_y.trainable_variables)

    gradients_k = tape.gradient(kernels_loss + l2_loss_k, targets_k)
    if self.clipnorm is not None:
        gradients_k, _ = tf.clip_by_global_norm(gradients_k, self.clipnorm)
    self._optimizer_k.apply_gradients(zip(gradients_k, targets_k))

    # --- Phase 2: joint update (encoders + decoders) ---
    # A fresh forward pass is required because phase 1 just changed
    # the encoder weights.
    with tf.GradientTape() as tape:
        x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, ztz = self(
            [x, y], training=True)
        l2_loss = (
            sum(self._enc_x.losses) + sum(self._enc_y.losses)
            + sum(self._dec_x.losses) + sum(self._dec_y.losses))
        cycle_x_loss = self.cycle_lambda * self.loss_object(x, x_dot)
        cross_x_loss = self.cross_lambda * self.loss_object(y, y_hat, clw)
        recon_x_loss = self.recon_lambda * self.loss_object(x, x_tilde)
        cycle_y_loss = self.cycle_lambda * self.loss_object(y, y_dot)
        cross_y_loss = self.cross_lambda * self.loss_object(x, x_hat, clw)
        recon_y_loss = self.recon_lambda * self.loss_object(y, y_tilde)
        total_loss = (
            cycle_x_loss + cross_x_loss + recon_x_loss
            + cycle_y_loss + cross_y_loss + recon_y_loss
            + l2_loss)
        targets_all = (
            self._enc_x.trainable_variables
            + self._enc_y.trainable_variables
            + self._dec_x.trainable_variables
            + self._dec_y.trainable_variables)

    gradients_all = tape.gradient(total_loss, targets_all)
    if self.clipnorm is not None:
        gradients_all, _ = tf.clip_by_global_norm(
            gradients_all, self.clipnorm)
    self._optimizer_all.apply_gradients(zip(gradients_all, targets_all))

    # Accumulate per-loss running metrics for logging.
    self.train_metrics["cycle_x"].update_state(cycle_x_loss)
    self.train_metrics["cross_x"].update_state(cross_x_loss)
    self.train_metrics["recon_x"].update_state(recon_x_loss)
    self.train_metrics["cycle_y"].update_state(cycle_y_loss)
    self.train_metrics["cross_y"].update_state(cross_y_loss)
    self.train_metrics["recon_y"].update_state(recon_y_loss)
    self.train_metrics["krnls"].update_state(kernels_loss)
    self.train_metrics["l2"].update_state(l2_loss)
    self.train_metrics["total"].update_state(total_loss)
def _train_step(self, x, y, clw):
    """Run one two-phase training step on a batch of image patches.

    Phase 1 updates only the two encoders with the kernel-alignment
    loss (computed on a region selected by ``self.align_option``);
    phase 2 updates encoders and decoders jointly with the
    cycle/cross/reconstruction losses. All running metrics are
    updated at the end.

    Input:
        x   - tensor of shape (bs, ps_h, ps_w, c_x)
        y   - tensor of shape (bs, ps_h, ps_w, c_y)
        clw - cross_loss_weight, tensor of shape (bs, ps_h, ps_w, 1)

    Raises:
        ValueError: if ``self.align_option`` is not one of
            "centre_crop"/"center_crop"/"full"/"no_crop". (Previously
            this surfaced later as a confusing NameError on ``Kern``.)
    """
    # --- Phase 1: kernel-alignment update (encoders only) ---
    with tf.GradientTape() as tape:
        x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, ztz = self(
            [x, y], training=True)

        # Select the spatial region used for code alignment. The two
        # decisions (region vs. kernel width) are independent, so they
        # are made separately instead of in four combined branches.
        if self.align_option in ("centre_crop", "center_crop"):
            # Crop a central fraction of each patch to limit the size
            # of the affinity matrices.
            x_align = tf.image.central_crop(x, self.centre_crop_frac)
            y_align = tf.image.central_crop(y, self.centre_crop_frac)
        elif self.align_option in ("full", "no_crop"):
            # Align codes of entire patches - will cause memory issues
            # if patches are too large.
            x_align, y_align = x, y
        else:
            raise ValueError(
                f"Unknown align_option: {self.align_option!r}")

        # Fixed (global) kernel widths are used only when both are
        # provided; otherwise the widths are estimated from the data.
        # The original 'elif both-not-None' condition was equivalent
        # to a plain 'else'.
        if self.krnl_width_x is None or self.krnl_width_y is None:
            Kern = 1.0 - Degree_matrix(x_align, y_align)
        else:
            Kern = 1.0 - Degree_matrix_fixed_krnl(
                x_align, y_align, self.krnl_width_x, self.krnl_width_y)

        kernels_loss = self.kernels_lambda * self.loss_object(Kern, ztz)
        # Regularization losses from the two encoders only.
        l2_loss_k = sum(self._enc_x.losses) + sum(self._enc_y.losses)
        targets_k = (
            self._enc_x.trainable_variables
            + self._enc_y.trainable_variables)

    gradients_k = tape.gradient(kernels_loss + l2_loss_k, targets_k)
    if self.clipnorm is not None:
        gradients_k, _ = tf.clip_by_global_norm(gradients_k, self.clipnorm)
    self._optimizer_k.apply_gradients(zip(gradients_k, targets_k))

    # --- Phase 2: joint update (encoders + decoders) ---
    # A fresh forward pass is required because phase 1 just changed
    # the encoder weights.
    with tf.GradientTape() as tape:
        x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, ztz = self(
            [x, y], training=True)
        l2_loss = (
            sum(self._enc_x.losses) + sum(self._enc_y.losses)
            + sum(self._dec_x.losses) + sum(self._dec_y.losses))
        cycle_x_loss = self.cycle_lambda * self.loss_object(x, x_dot)
        cross_x_loss = self.cross_lambda * self.loss_object(y, y_hat, clw)
        recon_x_loss = self.recon_lambda * self.loss_object(x, x_tilde)
        cycle_y_loss = self.cycle_lambda * self.loss_object(y, y_dot)
        cross_y_loss = self.cross_lambda * self.loss_object(x, x_hat, clw)
        recon_y_loss = self.recon_lambda * self.loss_object(y, y_tilde)
        total_loss = (
            cycle_x_loss + cross_x_loss + recon_x_loss
            + cycle_y_loss + cross_y_loss + recon_y_loss
            + l2_loss)
        targets_all = (
            self._enc_x.trainable_variables
            + self._enc_y.trainable_variables
            + self._dec_x.trainable_variables
            + self._dec_y.trainable_variables)

    gradients_all = tape.gradient(total_loss, targets_all)
    if self.clipnorm is not None:
        gradients_all, _ = tf.clip_by_global_norm(
            gradients_all, self.clipnorm)
    self._optimizer_all.apply_gradients(zip(gradients_all, targets_all))

    # Accumulate per-loss running metrics for logging.
    self.train_metrics["cycle_x"].update_state(cycle_x_loss)
    self.train_metrics["cross_x"].update_state(cross_x_loss)
    self.train_metrics["recon_x"].update_state(recon_x_loss)
    self.train_metrics["cycle_y"].update_state(cycle_y_loss)
    self.train_metrics["cross_y"].update_state(cross_y_loss)
    self.train_metrics["recon_y"].update_state(recon_y_loss)
    self.train_metrics["krnls"].update_state(kernels_loss)
    self.train_metrics["l2"].update_state(l2_loss)
    self.train_metrics["total"].update_state(total_loss)