예제 #1
0
    def _train_step(self, x, y, clw):
        """
        Run one two-phase optimization step.

        Phase 1 updates the encoders with the kernel (code-alignment) loss;
        phase 2 updates encoders and decoders with the combined cycle /
        cross / reconstruction losses. Running metrics are updated at the end.

        Input:
        x - tensor of shape (bs, ps_h, ps_w, c_x)
        y - tensor of shape (bs, ps_h, ps_w, c_y)
        clw - cross_loss_weight, tensor of shape (bs, ps_h, ps_w, 1)
        """
        # --- Phase 1: kernel alignment loss (encoders only) ---
        with tf.GradientTape() as tape:
            x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, ztz = self(
                [x, y], training=True)

            # Target affinity built from 20x20 sub-patches of the inputs.
            Kern = 1.0 - Degree_matrix(image_in_patches(x, 20),
                                       image_in_patches(y, 20))
            kernels_loss = self.kernels_lambda * self.loss_object(Kern, ztz)
            l2_loss_k = sum(self._enc_x.losses) + sum(self._enc_y.losses)

        # Gradient computation, clipping and application happen OUTSIDE the
        # tape context so the backward pass itself is not recorded.
        targets_k = (self._enc_x.trainable_variables +
                     self._enc_y.trainable_variables)
        gradients_k = tape.gradient(kernels_loss + l2_loss_k, targets_k)
        if self.clipnorm is not None:
            gradients_k, _ = tf.clip_by_global_norm(gradients_k, self.clipnorm)
        self._optimizer_k.apply_gradients(zip(gradients_k, targets_k))

        # --- Phase 2: translation losses (encoders + decoders) ---
        with tf.GradientTape() as tape:
            # Forward pass is repeated because phase 1 already changed the
            # encoder weights.
            x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, ztz = self(
                [x, y], training=True)
            l2_loss = (sum(self._enc_x.losses) + sum(self._enc_y.losses) +
                       sum(self._dec_x.losses) + sum(self._dec_y.losses))
            cycle_x_loss = self.cycle_lambda * self.loss_object(x, x_dot)
            cross_x_loss = self.cross_lambda * self.loss_object(y, y_hat, clw)
            recon_x_loss = self.recon_lambda * self.loss_object(x, x_tilde)
            cycle_y_loss = self.cycle_lambda * self.loss_object(y, y_dot)
            cross_y_loss = self.cross_lambda * self.loss_object(x, x_hat, clw)
            recon_y_loss = self.recon_lambda * self.loss_object(y, y_tilde)

            total_loss = (cycle_x_loss + cross_x_loss + recon_x_loss +
                          cycle_y_loss + cross_y_loss + recon_y_loss + l2_loss)

        targets_all = (self._enc_x.trainable_variables +
                       self._enc_y.trainable_variables +
                       self._dec_x.trainable_variables +
                       self._dec_y.trainable_variables)
        gradients_all = tape.gradient(total_loss, targets_all)
        if self.clipnorm is not None:
            gradients_all, _ = tf.clip_by_global_norm(
                gradients_all, self.clipnorm)
        self._optimizer_all.apply_gradients(zip(gradients_all, targets_all))

        # Accumulate per-loss running metrics for this step.
        self.train_metrics["cycle_x"].update_state(cycle_x_loss)
        self.train_metrics["cross_x"].update_state(cross_x_loss)
        self.train_metrics["recon_x"].update_state(recon_x_loss)
        self.train_metrics["cycle_y"].update_state(cycle_y_loss)
        self.train_metrics["cross_y"].update_state(cross_y_loss)
        self.train_metrics["recon_y"].update_state(recon_y_loss)
        self.train_metrics["krnls"].update_state(kernels_loss)
        self.train_metrics["l2"].update_state(l2_loss)
        self.train_metrics["total"].update_state(total_loss)
예제 #2
0
    def _train_step(self, x, y, clw):
        """
        Run one two-phase optimization step.

        Phase 1 updates the encoders with the kernel (code-alignment) loss,
        computed on either a central crop or the full patches depending on
        ``self.align_option``; phase 2 updates encoders and decoders with the
        combined cycle / cross / reconstruction losses.

        Input:
        x - tensor of shape (bs, ps_h, ps_w, c_x)
        y - tensor of shape (bs, ps_h, ps_w, c_y)
        clw - cross_loss_weight, tensor of shape (bs, ps_h, ps_w, 1)

        Raises:
        ValueError - if ``self.align_option`` is not one of the supported
                     options (previously this fell through and caused a
                     NameError on ``Kern``).
        """
        # --- Phase 1: kernel alignment loss (encoders only) ---
        with tf.GradientTape() as tape:
            x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, ztz = self(
                [x, y], training=True)

            # Select the inputs the alignment target is computed on.
            if self.align_option in ["centre_crop", "center_crop"]:
                # Crop a fraction of pixels from the centre of the patch.
                x_al = tf.image.central_crop(x, self.centre_crop_frac)
                y_al = tf.image.central_crop(y, self.centre_crop_frac)
            elif self.align_option in ["full", "no_crop"]:
                # Align code of entire patches - may exhaust memory if
                # patches are too large.
                x_al, y_al = x, y
            else:
                raise ValueError(
                    f"Unknown align_option: {self.align_option!r}")

            # Use the fixed (global) kernel widths only when both are set;
            # otherwise fall back to data-driven kernel widths.
            if self.krnl_width_x is None or self.krnl_width_y is None:
                Kern = 1.0 - Degree_matrix(x_al, y_al)
            else:
                Kern = 1.0 - Degree_matrix_fixed_krnl(
                    x_al, y_al, self.krnl_width_x, self.krnl_width_y)

            kernels_loss = self.kernels_lambda * self.loss_object(Kern, ztz)
            l2_loss_k = sum(self._enc_x.losses) + sum(self._enc_y.losses)

        # Gradient computation, clipping and application happen OUTSIDE the
        # tape context so the backward pass itself is not recorded.
        targets_k = (self._enc_x.trainable_variables +
                     self._enc_y.trainable_variables)
        gradients_k = tape.gradient(kernels_loss + l2_loss_k, targets_k)
        if self.clipnorm is not None:
            gradients_k, _ = tf.clip_by_global_norm(gradients_k, self.clipnorm)
        self._optimizer_k.apply_gradients(zip(gradients_k, targets_k))

        # --- Phase 2: translation losses (encoders + decoders) ---
        with tf.GradientTape() as tape:
            # Forward pass is repeated because phase 1 already changed the
            # encoder weights.
            x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, ztz = self(
                [x, y], training=True)
            l2_loss = (sum(self._enc_x.losses) + sum(self._enc_y.losses) +
                       sum(self._dec_x.losses) + sum(self._dec_y.losses))
            cycle_x_loss = self.cycle_lambda * self.loss_object(x, x_dot)
            cross_x_loss = self.cross_lambda * self.loss_object(y, y_hat, clw)
            recon_x_loss = self.recon_lambda * self.loss_object(x, x_tilde)
            cycle_y_loss = self.cycle_lambda * self.loss_object(y, y_dot)
            cross_y_loss = self.cross_lambda * self.loss_object(x, x_hat, clw)
            recon_y_loss = self.recon_lambda * self.loss_object(y, y_tilde)

            total_loss = (cycle_x_loss + cross_x_loss + recon_x_loss +
                          cycle_y_loss + cross_y_loss + recon_y_loss + l2_loss)

        targets_all = (self._enc_x.trainable_variables +
                       self._enc_y.trainable_variables +
                       self._dec_x.trainable_variables +
                       self._dec_y.trainable_variables)
        gradients_all = tape.gradient(total_loss, targets_all)
        if self.clipnorm is not None:
            gradients_all, _ = tf.clip_by_global_norm(
                gradients_all, self.clipnorm)
        self._optimizer_all.apply_gradients(zip(gradients_all, targets_all))

        # Accumulate per-loss running metrics for this step.
        self.train_metrics["cycle_x"].update_state(cycle_x_loss)
        self.train_metrics["cross_x"].update_state(cross_x_loss)
        self.train_metrics["recon_x"].update_state(recon_x_loss)
        self.train_metrics["cycle_y"].update_state(cycle_y_loss)
        self.train_metrics["cross_y"].update_state(cross_y_loss)
        self.train_metrics["recon_y"].update_state(recon_y_loss)
        self.train_metrics["krnls"].update_state(kernels_loss)
        self.train_metrics["l2"].update_state(l2_loss)
        self.train_metrics["total"].update_state(total_loss)