Пример #1
0
 def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
     """Create the generator and discriminator optimizers.

     Both are ``Adam`` optimizers with learning rate 1e-4, ``beta_1=0.0``
     and ``beta_2=0.9``. When ``mixed_precision`` is true, each optimizer is
     replaced by whatever its ``get_mixed_precision()`` method returns.

     NOTE(review): ``learning_rate`` is accepted but never used; the rate is
     hard-coded to 1e-4. Confirm whether that is intended.

     Returns:
         ``[self.G_optimizer, self.D_optimizer]``.
     """
     hyper = {"learning_rate": 1e-4, "beta_1": 0.0, "beta_2": 0.9}
     self.G_optimizer = Adam(**hyper)
     self.D_optimizer = Adam(**hyper)
     if not mixed_precision:
         return [self.G_optimizer, self.D_optimizer]
     # Swap in the mixed-precision variants produced by the optimizers.
     self.G_optimizer = self.G_optimizer.get_mixed_precision()
     self.D_optimizer = self.D_optimizer.get_mixed_precision()
     return [self.G_optimizer, self.D_optimizer]
Пример #2
0
 def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
     """Create the generator and discriminator optimizers.

     Args:
         mixed_precision: if true, each optimizer is replaced by the object
             returned from its ``get_mixed_precision()`` method.
         learning_rate: learning rate given to both ``Adam`` optimizers.
             The default (2e-4) matches the value that used to be
             hard-coded, so default callers see the same optimizers.

     Returns:
         ``[self.G_optimizer, self.D_optimizer]``.
     """
     # Honour the caller's learning rate instead of a hard-coded constant.
     self.G_optimizer = Adam(learning_rate)
     self.D_optimizer = Adam(learning_rate)
     if mixed_precision:
         self.G_optimizer = self.G_optimizer.get_mixed_precision()
         self.D_optimizer = self.D_optimizer.get_mixed_precision()
     return [self.G_optimizer, self.D_optimizer]
Пример #3
0
class CycleGAN(tf.keras.Model):
    """
    CycleGAN model. It is only responsible for the operations performed once
    a training set and a test (validation) set have been supplied.

    Sub-networks: generators ``G`` (X -> Y) and ``F`` (Y -> X), and
    discriminators ``Dy`` (scores Y-domain images) and ``Dx`` (scores
    X-domain images). The instance also owns their optimizers, PSNR/SSIM
    mean metrics, a TensorBoard summary writer, a checkpoint manager and a
    fixed "seed" batch taken from ``test_set`` for periodic visual summaries.
    """
    def __init__(self,
                 train_set,
                 test_set,
                 loss_name="WGAN-GP",
                 mixed_precision=False,
                 learning_rate=2e-4,
                 tmp_path=None,
                 out_path=None):
        """
        Build networks, loss, optimizers, metrics and checkpointing.

        Args:
            train_set: iterable of ``(trainX, trainY, maskX, maskY)`` batches
                (consumed in ``train``); it is also stored in the checkpoint.
            test_set: iterable of ``(testX, testY, maskX, maskY)`` batches
                (consumed in ``get_seed`` and ``test``).
            loss_name: one of "WGAN", "WGAN-GP", "WGAN-SN", "WGAN-GP-SN",
                "Vanilla", "LSGAN". The "-SN" variants build discriminators
                with ``sn=True`` (presumably spectral normalisation) and have
                the "-SN" suffix stripped before the name reaches ``GanLoss``.
                WGAN-style losses build discriminators without a sigmoid;
                "Vanilla"/"LSGAN" build them with one.
            mixed_precision: wrap the optimizers via ``get_mixed_precision()``
                and scale losses in the WGAN train steps.
            learning_rate: forwarded to ``optimizers_config``, which
                currently ignores it.
            tmp_path: directory for logs, checkpoints and saved weights.
                NOTE(review): must be a string; the default ``None`` makes
                ``matrics_config`` fail on ``None + "/logs/"``.
            out_path: directory where summary/test images are written.

        Raises:
            ValueError: if ``loss_name`` is not supported.

        Side effects: ``get_seed`` pulls the first batch of ``test_set`` and
        displays it with ``plt.show()``.
        """
        super(CycleGAN, self).__init__()
        # Store the datasets and related parameters
        self.train_set = train_set
        self.test_set = test_set
        self.tmp_path = tmp_path
        self.out_path = out_path
        # Define the models
        self.G = networks.Generator(name="G_X2Y")
        self.F = networks.Generator(name="G_Y2X")
        if loss_name in ["WGAN-SN", "WGAN-GP-SN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=False,
                                             sn=True)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=False,
                                             sn=True)
            # Drop the trailing "-SN": "WGAN-SN" -> "WGAN", "WGAN-GP-SN" -> "WGAN-GP".
            self.loss_name = loss_name[:-3]
        elif loss_name in ["WGAN", "WGAN-GP"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=False,
                                             sn=False)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=False,
                                             sn=False)
            self.loss_name = loss_name
        elif loss_name in ["Vanilla", "LSGAN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=True,
                                             sn=False)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=True,
                                             sn=False)
            self.loss_name = loss_name
        else:
            raise ValueError("Do not support the loss " + loss_name)

        # Order matters: this list is stored in the checkpoint.
        self.model_list = [self.G, self.F, self.Dy, self.Dx]
        # Define the loss function, optimizers, metrics/logging, etc.
        self.gan_loss = GanLoss(self.loss_name)
        self.optimizers_list = self.optimizers_config(
            mixed_precision=mixed_precision, learning_rate=learning_rate)
        self.mixed_precision = mixed_precision
        self.matrics_list = self.matrics_config()
        self.checkpoint_config()
        self.get_seed()

    def build(self, X_shape, Y_shape):
        """
        Build the four sub-networks from the X-domain and Y-domain shapes.

        input_shape must be passed as a slice (copy), because the underlying
        code treats it as each layer's output shape and modifies it.
        """
        self.G.build(input_shape=X_shape[:])  #G X->Y
        self.Dy.build(input_shape=Y_shape[:])  #Dy Y or != Y
        self.F.build(input_shape=Y_shape[:])  #F Y->X
        self.Dx.build(input_shape=X_shape[:])  #Dx X or != X
        self.built = True

    def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
        """
        Create one Adam optimizer per sub-network, all with beta_1=0.0 and
        beta_2=0.9: 1e-4 for the generators (G, F) and 4e-4 for the
        discriminators (Dy, Dx).

        When ``mixed_precision`` is true, each optimizer is replaced by the
        object returned from its ``get_mixed_precision()`` method. The train
        steps later call ``get_scaled_loss`` and ``get_unscaled_gradients``
        on that object.

        NOTE(review): ``learning_rate`` is accepted but unused; the rates
        above are hard-coded. Confirm whether that is intended.

        Returns:
            ``[G_optimizer, Dy_optimizer, F_optimizer, Dx_optimizer]``. The
            checkpoint stores this list, so reordering it would presumably
            break restoring existing checkpoints.
        """
        self.G_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        self.Dy_optimizer = Adam(learning_rate=4e-4, beta_1=0.0, beta_2=0.9)
        self.F_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        self.Dx_optimizer = Adam(learning_rate=4e-4, beta_1=0.0, beta_2=0.9)
        if mixed_precision:
            self.G_optimizer = self.G_optimizer.get_mixed_precision()
            self.Dy_optimizer = self.Dy_optimizer.get_mixed_precision()
            self.F_optimizer = self.F_optimizer.get_mixed_precision()
            self.Dx_optimizer = self.Dx_optimizer.get_mixed_precision()
        return [
            self.G_optimizer, self.Dy_optimizer, self.F_optimizer,
            self.Dx_optimizer
        ]

    def matrics_config(self):
        """
        Create the TensorBoard writer and the PSNR/SSIM mean metrics.

        The writer logs to ``tmp_path + "/logs/" + <YYYYmmdd-HHMMSS>``.

        Returns:
            ``[m_psnr_X2Y, m_psnr_Y2X, m_ssim_X2Y, m_ssim_Y2X]``.
        """
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_logdir = self.tmp_path + "/logs/" + current_time
        self.train_summary_writer = tf.summary.create_file_writer(train_logdir)
        self.m_psnr_X2Y = tf.keras.metrics.Mean('psnr_y', dtype=tf.float32)
        self.m_psnr_Y2X = tf.keras.metrics.Mean('psnr_x', dtype=tf.float32)
        self.m_ssim_X2Y = tf.keras.metrics.Mean('ssim_y', dtype=tf.float32)
        self.m_ssim_Y2X = tf.keras.metrics.Mean('ssim_x', dtype=tf.float32)
        return [
            self.m_psnr_X2Y, self.m_psnr_Y2X, self.m_ssim_X2Y, self.m_ssim_Y2X
        ]
        # return None
    def checkpoint_config(self):
        """
        Create ``self.ckpt`` and ``self.manager``.

        ``self.ckpt`` tracks a step counter (initial value 1), the optimizer
        list, the model list and the training dataset. ``self.manager``
        keeps the 3 most recent checkpoints under ``tmp_path + '/tf_ckpts'``.
        """
        self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                        optimizer=self.optimizers_list,
                                        model=self.model_list,
                                        dataset=self.train_set)
        self.manager = tf.train.CheckpointManager(self.ckpt,
                                                  self.tmp_path + '/tf_ckpts',
                                                  max_to_keep=3)

    def pix_gradient(self, x):
        """
        Return ``(dx, dy)`` image gradients of ``x`` using
        ``tf.image.image_gradients``.

        ``x`` is reshaped to ``[1, 64, 64, 1]``, so it must contain exactly
        64 * 64 elements. This method is not called by the rest of this class.
        """
        x = tf.reshape(x, shape=[1, 64, 64,
                                 1])  #pixel gradients per batch and channel; for 2D single-channel input the reshape is not really needed
        dx, dy = tf.image.image_gradients(x)
        return dx, dy

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_D(self, trainX, trainY, maskX, maskY, wgp_shape, step):
        """
        Run one discriminator update; ``train`` uses it for "WGAN"/"WGAN-GP".

        Generates fake Y = G(X) * maskY and fake X = F(Y) * maskX, and scores
        real and fake images with Dy and Dx. Each discriminator loss gets a
        gradient penalty, 10 * (||grad D(x_hat)||_2 - 1)^2, averaged over the
        batch. ``x_hat`` is a per-sample random interpolation between real
        and generated images. Only Dy and Dx are updated.

        Args:
            trainX, trainY: real X-domain / Y-domain batches.
            maskX, maskY: masks multiplied into generated X / generated Y.
            wgp_shape: int32 ``[batch, 1, 1, 1]`` shape for the per-sample
                interpolation coefficients (built in ``train``).
            step: current step as a ``[1]`` uint32 tensor (unused here).

        Returns:
            ``(Dy_loss, Dx_loss)``, both including the gradient penalty.

        NOTE(review): the penalty is added for "WGAN" as well as "WGAN-GP";
        confirm that plain WGAN is meant to use it.
        """
        # persistent=True: two gradients (Dy and Dx) are taken from this tape.
        with tf.GradientTape(persistent=True) as D_tape:
            # Fake Y from real X (masked); score real and fake Y with Dy.
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)

            # Fake X from real Y (masked); score real and fake X with Dx.
            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)

            # Gradient penalty for Dy on random real/fake Y interpolations;
            # the norm is taken per sample over all flattened pixels.
            e = tf.random.uniform(shape=wgp_shape, minval=0.0, maxval=1.0)
            mid_Y = e * trainY + (1 - e) * GeneratedY
            with tf.GradientTape() as gradient_penaltyY:
                gradient_penaltyY.watch(mid_Y)
                inner_loss = self.Dy(mid_Y)
            penalty = gradient_penaltyY.gradient(inner_loss, mid_Y)
            penalty_normY = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, shape=[wgp_shape[0], -1]),
                        ord=2,
                        axis=-1) - 1)

            # Same gradient penalty for Dx on real/fake X interpolations.
            e = tf.random.uniform(shape=wgp_shape, minval=0.0, maxval=1.0)
            mid_X = e * trainX + (1 - e) * GeneratedX
            with tf.GradientTape() as gradient_penaltyX:
                gradient_penaltyX.watch(mid_X)
                inner_loss = self.Dx(mid_X)
            penalty = gradient_penaltyX.gradient(inner_loss, mid_X)
            penalty_normX = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, shape=[wgp_shape[0], -1]),
                        ord=2,
                        axis=-1) - 1)

            Dy_loss = self.gan_loss.DiscriminatorLoss(
                Dy_real_out, Dy_fake_out) + tf.reduce_mean(penalty_normY)
            Dx_loss = self.gan_loss.DiscriminatorLoss(
                Dx_real_out, Dx_fake_out) + tf.reduce_mean(penalty_normX)

            # Mixed precision: scale the losses inside the tape; the
            # gradients are unscaled below before being applied.
            if self.mixed_precision:
                scaled_Dy_loss = self.Dy_optimizer.get_scaled_loss(Dy_loss)
                scaled_Dx_loss = self.Dx_optimizer.get_scaled_loss(Dx_loss)

        if self.mixed_precision:
            scaled_gradients_of_Dy = D_tape.gradient(
                scaled_Dy_loss, self.Dy.trainable_variables)
            scaled_gradients_of_Dx = D_tape.gradient(
                scaled_Dx_loss, self.Dx.trainable_variables)
            gradients_of_Dy = self.Dy_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dy)
            gradients_of_Dx = self.Dx_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dx)
        else:
            gradients_of_Dy = D_tape.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
            gradients_of_Dx = D_tape.gradient(Dx_loss,
                                              self.Dx.trainable_variables)

        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        return Dy_loss, Dx_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_G(self, trainX, trainY, maskX, maskY, wgp_shape, step):
        """
        Run one generator update; ``train`` uses it for "WGAN"/"WGAN-GP".

        ``G_loss`` is the adversarial term on Dy(fake Y) and ``F_loss`` the
        adversarial term on Dx(fake X). Each adds the same total
        cycle-consistency loss times ``cycle_l`` (10.0). That loss is the
        mean absolute error between F(fake Y) and X plus the mean absolute
        error between G(fake X) and Y. Only G and F are updated.

        Args:
            Same as ``train_step_D``; ``wgp_shape`` is unused here.

        Returns:
            ``(G_loss, F_loss)``.
        """
        # persistent=True: two gradients (G and F) are taken from this tape.
        with tf.GradientTape(persistent=True) as G_tape:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            # Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)

            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            # Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)

            # Cycle consistency: X -> Y -> X and Y -> X -> Y should reconstruct.
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X

            # NOTE(review): both branches set cycle_l = 10.0, so this condition has no effect.
            if step >= 0:  #pixel-gradient and reconstruction losses are not used yet
                cycle_l = 10.0
            else:
                cycle_l = 10.0
            G_loss = self.gan_loss.GeneratorLoss(
                Dy_fake_out) + cycle_l * (cycle_consistent)
            F_loss = self.gan_loss.GeneratorLoss(
                Dx_fake_out) + cycle_l * (cycle_consistent)

            # Mixed precision: scale inside the tape, unscale gradients below.
            if self.mixed_precision:
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
                scaled_F_loss = self.F_optimizer.get_scaled_loss(F_loss)
        if self.mixed_precision:
            scaled_gradients_of_G = G_tape.gradient(scaled_G_loss,
                                                    self.G.trainable_variables)
            scaled_gradients_of_F = G_tape.gradient(scaled_F_loss,
                                                    self.F.trainable_variables)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
            gradients_of_F = self.F_optimizer.get_unscaled_gradients(
                scaled_gradients_of_F)

        else:
            gradients_of_G = G_tape.gradient(G_loss,
                                             self.G.trainable_variables)
            gradients_of_F = G_tape.gradient(F_loss,
                                             self.F.trainable_variables)

        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, F_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step(self, trainX, trainY, maskX, maskY, wgp_shape, step):
        """
        Run one joint update of G, F, Dy and Dx from a single tape; ``train``
        uses it for "Vanilla"/"LSGAN".

        The discriminator losses have no gradient penalty. The generator
        losses add 10.0 times the cycle-consistency loss, as in
        ``train_step_G``.

        NOTE(review): unlike the WGAN steps, there is no mixed-precision loss
        scaling here even when ``self.mixed_precision`` is true.

        Returns:
            ``(G_loss, Dy_loss, F_loss, Dx_loss)``.
        """
        # persistent=True: four gradients are taken from this one tape.
        with tf.GradientTape(persistent=True) as cycle_type:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)

            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)

            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X

            # NOTE(review): both branches set cycle_l = 10.0, so this condition has no effect.
            if step >= 0:  #pixel-gradient and reconstruction losses are not used yet
                cycle_l = 10.0
            else:
                cycle_l = 10.0
            Dy_loss = self.gan_loss.DiscriminatorLoss(Dy_real_out, Dy_fake_out)
            Dx_loss = self.gan_loss.DiscriminatorLoss(Dx_real_out, Dx_fake_out)
            G_loss = self.gan_loss.GeneratorLoss(
                Dy_fake_out) + cycle_l * (cycle_consistent)
            F_loss = self.gan_loss.GeneratorLoss(
                Dx_fake_out) + cycle_l * (cycle_consistent)

        gradients_of_Dy = cycle_type.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
        gradients_of_Dx = cycle_type.gradient(Dx_loss,
                                              self.Dx.trainable_variables)
        gradients_of_G = cycle_type.gradient(G_loss,
                                             self.G.trainable_variables)
        gradients_of_F = cycle_type.gradient(F_loss,
                                             self.F.trainable_variables)
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, Dy_loss, F_loss, Dx_loss

    def train(self, epoches):
        """
        Train for ``epoches`` passes over ``train_set``.

        First restores ``self.manager.latest_checkpoint`` (if there is one).
        For each batch, the method:
        - increments the checkpoint step;
        - for WGAN-style losses, runs one ``train_step_D`` and then one
          ``train_step_G``;
        - for "Vanilla"/"LSGAN", runs the joint ``train_step``.

        Every 100 steps it also:
        - saves a checkpoint;
        - saves each network's weights under ``tmp_path + '/weights_saved/'``;
        - writes summaries via ``wirte_summary``;
        - prints the elapsed time.

        Raises:
            ValueError: if ``self.loss_name`` is not a known loss.
        """
        self.ckpt.restore(self.manager.latest_checkpoint)
        for _ in range(epoches):
            start = time.time()
            for trainX, trainY, maskX, maskY in self.train_set:
                self.ckpt.step.assign_add(1)
                step = int(self.ckpt.step)
                if self.loss_name in ["WGAN", "WGAN-GP"]:
                    # Number of D updates per batch (currently 1).
                    for __ in range(1):
                        Dy_loss, Dx_loss = self.train_step_D(
                            trainX, trainY, maskX, maskY,
                            tf.constant([trainY.shape[0], 1, 1, 1],
                                        shape=[4],
                                        dtype=tf.int32),
                            tf.constant(step, shape=[1], dtype=tf.uint32))
                    # Number of G updates per batch (currently 1).
                    for __ in range(1):
                        G_loss, F_loss = self.train_step_G(
                            trainX, trainY, maskX, maskY,
                            tf.constant([trainY.shape[0], 1, 1, 1],
                                        shape=[4],
                                        dtype=tf.int32),
                            tf.constant(step, shape=[1], dtype=tf.uint32))
                elif self.loss_name in ["Vanilla", "LSGAN"]:
                    G_loss, Dy_loss, F_loss, Dx_loss = self.train_step(
                        trainX, trainY, maskX, maskY,
                        tf.constant([trainY.shape[0], 1, 1, 1],
                                    shape=[4],
                                    dtype=tf.int32),
                        tf.constant(step, shape=[1], dtype=tf.uint32))
                else:
                    raise ValueError("Inner Error")

                # Periodic checkpoint, weight export, summaries and timing.
                if step % 100 == 0:
                    save_path = self.manager.save()
                    print("Saved checkpoint for step {}: {}".format(
                        step, save_path))

                    self.G.save_weights(self.tmp_path +
                                        '/weights_saved/G.ckpt')
                    self.Dy.save_weights(self.tmp_path +
                                         '/weights_saved/Dy.ckpt')
                    self.F.save_weights(self.tmp_path +
                                        '/weights_saved/F.ckpt')
                    self.Dx.save_weights(self.tmp_path +
                                         '/weights_saved/Dx.ckpt')

                    self.wirte_summary(step=step,
                                       seed=self.seed,
                                       G=self.G,
                                       F=self.F,
                                       G_loss=G_loss,
                                       Dy_loss=Dy_loss,
                                       F_loss=F_loss,
                                       Dx_loss=Dx_loss,
                                       out_path=self.out_path)

                    print('Time to next 100 step {} is {} sec'.format(
                        step,
                        time.time() - start))
                    start = time.time()

    def get_seed(self):
        """
        Store the first batch of ``test_set`` in ``self.seed``.

        That batch is the fixed input for the visual summaries. As a debug
        aid, the method prints some shapes/dtypes and shows the first image
        of each of the four tensors with matplotlib (``plt.show()``).
        """
        seed_get = iter(self.test_set)
        testX, testY, maskX, maskY = next(seed_get)
        print(testX.shape, testY.dtype, maskX.dtype, maskY.shape)
        plt.imshow(testX[0, :, :, 0], cmap='gray')
        plt.show()
        plt.imshow(testY[0, :, :, 0], cmap='gray')
        plt.show()
        plt.imshow(maskX[0, :, :, 0], cmap='gray')
        plt.show()
        plt.imshow(maskY[0, :, :, 0], cmap='gray')
        plt.show()
        self.seed = testX, testY, maskX, maskY

    def wirte_summary(self, step, seed, G, F, G_loss, Dy_loss, F_loss, Dx_loss,
                      out_path):
        """
        Log losses, seed-batch PSNR/SSIM and a comparison image at ``step``.

        ``train`` calls this method under this (misspelled) name.

        The masks are swapped relative to training: fake Y uses maskX and
        fake X uses maskY. A 2x2 figure (real X, fake Y, fake X, real Y) is
        saved to ``out_path + '/image_at_<step>.png'`` and read back for
        ``tf.summary.image``. The metrics are reset after writing.

        NOTE(review): PSNR/SSIM use max value 1.0, so images are presumably
        in [0, 1]; confirm. The reshape to (1, 500, 500, 4) assumes the
        5x5-inch figure is saved at 100 dpi as RGBA.
        """
        testX, testY, maskX, maskY = seed
        GeneratedY = G(testX)
        GeneratedY = tf.multiply(GeneratedY, maskX)
        GeneratedX = F(testY)
        GeneratedX = tf.multiply(GeneratedX,
                                 maskY)  #at test time the masks are swapped, since only the source modality and its own mask are known
        plt.figure(figsize=(5, 5))  #the figure must be large enough to hold the pixels
        plt.subplot(2, 2, 1)
        plt.title('real X')
        plt.imshow(testX[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 2)
        plt.title('fake Y')
        plt.imshow(GeneratedY[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 3)
        plt.title('fake X')
        plt.imshow(GeneratedX[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 4)
        plt.title('real Y')
        plt.imshow(testY[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.savefig(out_path + '/image_at_{}.png'.format(step))
        plt.close()
        # Read the saved figure back so it can be logged as an image summary.
        img = Image.open(out_path + '/image_at_{}.png'.format(step))
        img = tf.reshape(np.array(img), shape=(1, 500, 500, 4))

        with self.train_summary_writer.as_default():
            ##########################
            self.m_psnr_X2Y(tf.image.psnr(GeneratedY, testY, 1.0, name=None))
            self.m_psnr_Y2X(tf.image.psnr(GeneratedX, testX, 1.0, name=None))
            self.m_ssim_X2Y(
                tf.image.ssim(GeneratedY,
                              testY,
                              1,
                              filter_size=11,
                              filter_sigma=1.5,
                              k1=0.01,
                              k2=0.03))
            self.m_ssim_Y2X(
                tf.image.ssim(GeneratedX,
                              testX,
                              1,
                              filter_size=11,
                              filter_sigma=1.5,
                              k1=0.01,
                              k2=0.03))
            tf.summary.scalar('G_loss', G_loss, step=step)
            tf.summary.scalar('Dy_loss', Dy_loss, step=step)
            tf.summary.scalar('F_loss', F_loss, step=step)
            tf.summary.scalar('Dx_loss', Dx_loss, step=step)
            tf.summary.scalar('test_psnr_y',
                              self.m_psnr_X2Y.result(),
                              step=step)
            tf.summary.scalar('test_psnr_x',
                              self.m_psnr_Y2X.result(),
                              step=step)
            tf.summary.scalar('test_ssim_y',
                              self.m_ssim_X2Y.result(),
                              step=step)
            tf.summary.scalar('test_ssim_x',
                              self.m_ssim_Y2X.result(),
                              step=step)
            tf.summary.image("img", img, step=step)

        ##########################
        # Reset so each summary reports only this call's values.
        self.m_psnr_X2Y.reset_states()
        self.m_psnr_Y2X.reset_states()
        self.m_ssim_X2Y.reset_states()
        self.m_ssim_Y2X.reset_states()

    def test(self):
        """
        Restore the latest checkpoint and evaluate on ``test_set``.

        Every 4 consecutive test batches are treated as four overlapping
        128x128 tiles of a 240x240 canvas, in this order:
        1. rows 48-175, cols 22-149;
        2. rows 48-175, cols 90-217;
        3. rows 64-191, cols 22-149;
        4. rows 64-191, cols 90-217.
        Only element 0 of each batch is used, so that image must be 128x128
        for the slice additions to match.

        Tiles are summed onto four "black boards" (fake Y, fake X, real Y,
        real X). After the 4th tile, overlapped rows 64-175 and columns 90-149
        are halved, so overlapping pixels hold averages. For each completed
        group the method saves and logs a 2x2 comparison image plus the
        PSNR/SSIM of the assembled canvases, then resets the boards and
        metrics. A trailing group of fewer than 4 batches is never logged.

        NOTE(review): the reshape to (1, 1000, 1000, 4) assumes the
        10x10-inch figure is saved at 100 dpi as RGBA.
        """
        self.ckpt.restore(self.manager.latest_checkpoint)
        step = 0
        black_board_X = np.zeros(shape=[240, 240], dtype=np.float32)
        black_board_Y = np.zeros(shape=[240, 240], dtype=np.float32)
        black_board_rX = np.zeros(shape=[240, 240], dtype=np.float32)
        black_board_rY = np.zeros(shape=[240, 240], dtype=np.float32)
        for i, (testX, testY, maskX, maskY) in enumerate(self.test_set):

            GeneratedY = self.G(testX)
            GeneratedY = tf.multiply(GeneratedY, maskX)
            GeneratedX = self.F(testY)
            GeneratedX = tf.multiply(GeneratedX,
                                     maskY)  #at test time the masks are swapped, since only the source modality and its own mask are known
            # Tile 1 of 4: rows 48-175, cols 22-149.
            if (i + 1) % 4 == 1:
                black_board_Y[48:175 + 1,
                              22:149 + 1] += GeneratedY.numpy()[0, :, :, 0]
                black_board_X[48:175 + 1,
                              22:149 + 1] += GeneratedX.numpy()[0, :, :, 0]
                black_board_rY[48:175 + 1,
                               22:149 + 1] += testY.numpy()[0, :, :, 0]
                black_board_rX[48:175 + 1,
                               22:149 + 1] += testX.numpy()[0, :, :, 0]
            # Tile 2 of 4: rows 48-175, cols 90-217.
            elif (i + 1) % 4 == 2:
                black_board_Y[48:175 + 1,
                              90:217 + 1] += GeneratedY.numpy()[0, :, :, 0]
                black_board_X[48:175 + 1,
                              90:217 + 1] += GeneratedX.numpy()[0, :, :, 0]
                black_board_rY[48:175 + 1,
                               90:217 + 1] += testY.numpy()[0, :, :, 0]
                black_board_rX[48:175 + 1,
                               90:217 + 1] += testX.numpy()[0, :, :, 0]
            # Tile 3 of 4: rows 64-191, cols 22-149.
            elif (i + 1) % 4 == 3:
                black_board_Y[64:191 + 1,
                              22:149 + 1] += GeneratedY.numpy()[0, :, :, 0]
                black_board_X[64:191 + 1,
                              22:149 + 1] += GeneratedX.numpy()[0, :, :, 0]
                black_board_rY[64:191 + 1,
                               22:149 + 1] += testY.numpy()[0, :, :, 0]
                black_board_rX[64:191 + 1,
                               22:149 + 1] += testX.numpy()[0, :, :, 0]
            # Tile 4 of 4: rows 64-191, cols 90-217; then normalise overlaps.
            elif (i + 1) % 4 == 0:
                black_board_Y[64:191 + 1,
                              90:217 + 1] += GeneratedY.numpy()[0, :, :, 0]
                black_board_X[64:191 + 1,
                              90:217 + 1] += GeneratedX.numpy()[0, :, :, 0]
                black_board_rY[64:191 + 1,
                               90:217 + 1] += testY.numpy()[0, :, :, 0]
                black_board_rX[64:191 + 1,
                               90:217 + 1] += testX.numpy()[0, :, :, 0]

                #norm: halve rows covered by both tile rows and columns covered
                # by both tile columns; the central overlap is halved twice (/4).
                black_board_Y[64:175 +
                              1, :] = black_board_Y[64:175 + 1, :] / 2.0
                black_board_Y[:,
                              90:149 + 1] = black_board_Y[:, 90:149 + 1] / 2.0
                black_board_X[64:175 +
                              1, :] = black_board_X[64:175 + 1, :] / 2.0
                black_board_X[:,
                              90:149 + 1] = black_board_X[:, 90:149 + 1] / 2.0
                black_board_rY[64:175 +
                               1, :] = black_board_rY[64:175 + 1, :] / 2.0
                black_board_rY[:, 90:149 +
                               1] = black_board_rY[:, 90:149 + 1] / 2.0
                black_board_rX[64:175 +
                               1, :] = black_board_rX[64:175 + 1, :] / 2.0
                black_board_rX[:, 90:149 +
                               1] = black_board_rX[:, 90:149 + 1] / 2.0

            else:
                raise ValueError("inner error")
            out_path = self.out_path
            # A full group of 4 tiles is assembled: save, log and reset.
            if (i + 1) % 4 == 0:
                step += 1
                plt.figure(figsize=(10, 10))  #the figure must be large enough to hold the pixels
                plt.subplot(2, 2, 1)
                plt.title('real X')
                plt.imshow(black_board_rX, cmap='gray')
                plt.axis('off')
                plt.subplot(2, 2, 2)
                plt.title('fake Y')
                plt.imshow(black_board_Y, cmap='gray')
                plt.axis('off')
                plt.subplot(2, 2, 3)
                plt.title('fake X')
                plt.imshow(black_board_X, cmap='gray')
                plt.axis('off')
                plt.subplot(2, 2, 4)
                plt.title('real Y')
                plt.imshow(black_board_rY, cmap='gray')
                plt.axis('off')
                plt.savefig(out_path + '/test/image_at_{}.png'.format(step))
                plt.close()
                img = Image.open(out_path +
                                 '/test/image_at_{}.png'.format(step))
                img = tf.reshape(np.array(img), shape=(1, 1000, 1000, 4))
                with self.train_summary_writer.as_default():
                    ##########################
                    # Convert the boards to [1, 240, 240, 1] float32 tensors
                    # for the image metrics; they are reset to NumPy below.
                    black_board_Y = tf.reshape(tf.constant(black_board_Y,
                                                           dtype=tf.float32),
                                               shape=[1, 240, 240, 1])
                    black_board_X = tf.reshape(tf.constant(black_board_X,
                                                           dtype=tf.float32),
                                               shape=[1, 240, 240, 1])
                    black_board_rY = tf.reshape(tf.constant(black_board_rY,
                                                            dtype=tf.float32),
                                                shape=[1, 240, 240, 1])
                    black_board_rX = tf.reshape(tf.constant(black_board_rX,
                                                            dtype=tf.float32),
                                                shape=[1, 240, 240, 1])
                    self.m_psnr_X2Y(
                        tf.image.psnr(black_board_Y,
                                      black_board_rY,
                                      1.0,
                                      name=None))
                    self.m_psnr_Y2X(
                        tf.image.psnr(black_board_X,
                                      black_board_rX,
                                      1.0,
                                      name=None))
                    self.m_ssim_X2Y(
                        tf.image.ssim(black_board_Y,
                                      black_board_rY,
                                      1,
                                      filter_size=11,
                                      filter_sigma=1.5,
                                      k1=0.01,
                                      k2=0.03))
                    self.m_ssim_Y2X(
                        tf.image.ssim(black_board_X,
                                      black_board_rX,
                                      1,
                                      filter_size=11,
                                      filter_sigma=1.5,
                                      k1=0.01,
                                      k2=0.03))
                    tf.summary.scalar('test_psnr_y',
                                      self.m_psnr_X2Y.result(),
                                      step=step)
                    tf.summary.scalar('test_psnr_x',
                                      self.m_psnr_Y2X.result(),
                                      step=step)
                    tf.summary.scalar('test_ssim_y',
                                      self.m_ssim_X2Y.result(),
                                      step=step)
                    tf.summary.scalar('test_ssim_x',
                                      self.m_ssim_Y2X.result(),
                                      step=step)
                    tf.summary.image("img", img, step=step)

                ##########################
                # Reset metrics and boards for the next group of 4 tiles.
                self.m_psnr_X2Y.reset_states()
                self.m_psnr_Y2X.reset_states()
                self.m_ssim_X2Y.reset_states()
                self.m_ssim_Y2X.reset_states()
                black_board_X = np.zeros(shape=[240, 240], dtype=np.float32)
                black_board_Y = np.zeros(shape=[240, 240], dtype=np.float32)
                black_board_rX = np.zeros(shape=[240, 240], dtype=np.float32)
                black_board_rY = np.zeros(shape=[240, 240], dtype=np.float32)
Пример #4
0
class DCGAN(tf.keras.Model):
    """
    Single generator / single discriminator GAN trainer.

    The model only handles the operations performed once a training set and a
    test (validation) set have been supplied.
    """
    def __init__(self,
                 train_set,
                 test_set,
                 loss_name="Vanilla",
                 mixed_precision=False,
                 learning_rate=2e-4,
                 tmp_path=None,
                 out_path=None):
        """Build G and D plus the loss, optimizers, summary writer and checkpoint.

        Args:
            train_set: iterable of (trainX, trainY) batches consumed by `train`.
            test_set: iterable whose elements are fed straight to G; the first
                100 are cached as fixed seeds by `get_seed`.
            loss_name: "Vanilla", "LSGAN", "WGAN", "WGAN-GP", "WGAN-SN" or
                "WGAN-GP-SN". The "-SN" variants enable spectral normalization
                in D; the suffix is stripped before the name reaches GanLoss.
            mixed_precision: wrap both optimizers for mixed precision.
            learning_rate: passed on to `optimizers_config`, which currently
                ignores it (see the note there).
            tmp_path: root directory for logs, checkpoints and saved weights.
            out_path: directory receiving the sample image grids.

        Raises:
            ValueError: if `loss_name` is not one of the supported names.
        """
        super(DCGAN, self).__init__()
        # Datasets and output locations.
        self.train_set = train_set
        self.test_set = test_set
        self.tmp_path = tmp_path
        self.out_path = out_path
        # Networks. D only ends in a sigmoid for the non-Wasserstein losses.
        self.G = networks.Generator(name="G")
        if loss_name in ["WGAN-SN", "WGAN-GP-SN"]:
            self.D = networks.Discriminator(name="If_is_real",
                                            use_sigmoid=False,
                                            sn=True)
            self.loss_name = loss_name[:-3]
        elif loss_name in ["WGAN", "WGAN-GP"]:
            self.D = networks.Discriminator(name="If_is_real",
                                            use_sigmoid=False,
                                            sn=False)
            self.loss_name = loss_name
        elif loss_name in ["Vanilla", "LSGAN"]:
            self.D = networks.Discriminator(name="If_is_real",
                                            use_sigmoid=True,
                                            sn=False)
            self.loss_name = loss_name
        else:
            raise ValueError("Do not support the loss " + loss_name)

        self.model_list = [self.G, self.D]
        # Loss, optimizers, logging and checkpointing.
        self.gan_loss = GanLoss(self.loss_name)
        self.optimizers_list = self.optimizers_config(
            mixed_precision=mixed_precision, learning_rate=learning_rate)
        self.mixed_precision = mixed_precision
        self.matrics_list = self.matrics_config()
        self.checkpoint_config()
        self.get_seed()

    def build(self, input_shape_G, input_shape_D):
        """
        Build G and D for the given input shapes.

        Each input_shape is passed as a slice (a copy), because the lower
        layers treat it as their output shape and modify it in place.
        """
        self.G.build(input_shape=input_shape_G[:])  # G: X -> Y
        self.D.build(input_shape=input_shape_D[:])  # D: is it a real Y?
        self.built = True

    def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
        """Create the Adam optimizers for G and D.

        Returns:
            [G_optimizer, D_optimizer], optionally wrapped for mixed precision.
        """
        # NOTE(review): `learning_rate` is accepted but not used; both
        # optimizers use a fixed lr of 1e-4 with beta_1=0.0, beta_2=0.9.
        # Confirm whether callers expect the argument to take effect.
        self.G_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        self.D_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        if mixed_precision:
            self.G_optimizer = self.G_optimizer.get_mixed_precision()
            self.D_optimizer = self.D_optimizer.get_mixed_precision()
        return [self.G_optimizer, self.D_optimizer]

    def matrics_config(self):
        """Create the TensorBoard writer under `<tmp_path>/logs/<timestamp>`.

        DCGAN tracks no extra metrics, so an empty list is returned.
        """
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_logdir = self.tmp_path + "/logs/" + current_time
        self.train_summary_writer = tf.summary.create_file_writer(train_logdir)
        return []

    def checkpoint_config(self):
        """Track step, optimizers, models and dataset; keep the 3 newest checkpoints."""
        self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                        optimizer=self.optimizers_list,
                                        model=self.model_list,
                                        dataset=self.train_set)
        self.manager = tf.train.CheckpointManager(self.ckpt,
                                                  self.tmp_path + '/tf_ckpts',
                                                  max_to_keep=3)

    def pix_gradient(self, x):
        """Return the one-step finite-difference image gradients of `x`.

        `x` is reshaped to [1, 64, 64, 1], so it must hold exactly 64*64
        values. The result follows `tf.image.image_gradients` and is ordered
        (dy, dx): vertical gradient first, horizontal second.
        """
        # Gradients are taken per batch and channel; for a single 2D
        # single-channel image the reshape only adds those axes.
        x = tf.reshape(x, shape=[1, 64, 64, 1])
        dy, dx = tf.image.image_gradients(x)
        return dy, dx

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_D(self, trainX, trainY, y_shape, step):
        """One critic update with a gradient penalty (used for WGAN / WGAN-GP).

        Args:
            trainX: generator input batch.
            trainY: real target batch.
            y_shape: shape of the per-sample interpolation weights; `train`
                passes [batch, 1, 1, 1].
            step: current global step (unused here).

        Returns:
            D_loss including the mean gradient-penalty term.
        """
        with tf.GradientTape(persistent=True) as D_tape:
            GeneratedY = self.G(trainX)
            D_real_out = self.D(trainY)
            D_fake_out = self.D(GeneratedY)

            # Random interpolation between real and generated samples.
            e = tf.random.uniform(shape=y_shape, minval=0.0, maxval=1.0)
            mid_Y = e * trainY + (1 - e) * GeneratedY
            with tf.GradientTape() as GP:
                GP.watch(mid_Y)
                inner_loss = self.D(mid_Y)
            penalty = GP.gradient(inner_loss, mid_Y)
            # Per-sample penalty 10 * (||grad||_2 - 1)^2, following the
            # algorithm's original formulation.
            penalty_norm = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, [y_shape[0], -1]), axis=-1, ord=2)
                - 1)
            D_loss = self.gan_loss.DiscriminatorLoss(
                D_real_out, D_fake_out) + tf.reduce_mean(penalty_norm)

            if self.mixed_precision:
                scaled_D_loss = self.D_optimizer.get_scaled_loss(D_loss)
        if self.mixed_precision:
            scaled_gradients_of_D = D_tape.gradient(scaled_D_loss,
                                                    self.D.trainable_variables)
            gradients_of_D = self.D_optimizer.get_unscaled_gradients(
                scaled_gradients_of_D)
        else:
            gradients_of_D = D_tape.gradient(D_loss,
                                             self.D.trainable_variables)

        self.D_optimizer.apply_gradients(
            zip(gradients_of_D, self.D.trainable_variables))

        return D_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_G(self, trainX, trainY, y_shape, step):
        """One generator update against the current D.

        `trainY`, `y_shape` and `step` are part of the shared signature but
        are not used here. Returns the unscaled generator loss.
        """
        with tf.GradientTape(persistent=True) as G_tape:
            GeneratedY = self.G(trainX)
            D_fake_out = self.D(GeneratedY)

            G_loss = self.gan_loss.GeneratorLoss(D_fake_out)

            if self.mixed_precision:
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
        if self.mixed_precision:
            scaled_gradients_of_G = G_tape.gradient(scaled_G_loss,
                                                    self.G.trainable_variables)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
        else:
            gradients_of_G = G_tape.gradient(G_loss,
                                             self.G.trainable_variables)

        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        return G_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step(self, trainX, trainY, y_shape, step):
        """Joint D and G update for the Vanilla / LSGAN losses.

        Returns:
            (D_loss, G_loss), both unscaled.
        """
        # persistent=True: both D and G gradients are taken from this tape.
        with tf.GradientTape(persistent=True) as gan_type:
            GeneratedY = self.G(trainX)
            D_real_out = self.D(trainY)
            D_fake_out = self.D(GeneratedY)

            D_loss = self.gan_loss.DiscriminatorLoss(D_real_out, D_fake_out)
            G_loss = self.gan_loss.GeneratorLoss(D_fake_out)

            if self.mixed_precision:
                scaled_D_loss = self.D_optimizer.get_scaled_loss(D_loss)
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)

        if self.mixed_precision:
            scaled_gradients_of_D = gan_type.gradient(
                scaled_D_loss, self.D.trainable_variables)
            scaled_gradients_of_G = gan_type.gradient(
                scaled_G_loss, self.G.trainable_variables)
            gradients_of_D = self.D_optimizer.get_unscaled_gradients(
                scaled_gradients_of_D)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
        else:
            gradients_of_D = gan_type.gradient(D_loss,
                                               self.D.trainable_variables)
            gradients_of_G = gan_type.gradient(G_loss,
                                               self.G.trainable_variables)

        self.D_optimizer.apply_gradients(
            zip(gradients_of_D, self.D.trainable_variables))
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        return D_loss, G_loss

    def train(self, epoches):
        """Train for `epoches` passes over `train_set`, resuming from the latest checkpoint.

        WGAN losses run 5 critic updates per generator update. Every 100
        steps a checkpoint, the G/D weights and a summary are saved.

        Raises:
            ValueError: if `self.loss_name` has no training branch.
        """
        self.ckpt.restore(self.manager.latest_checkpoint)
        for _ in range(epoches):
            start = time.time()
            for trainX, trainY in self.train_set:
                self.ckpt.step.assign_add(1)
                step = int(self.ckpt.step)
                # Built once per batch and shared by every step call below.
                y_shape = tf.constant([trainY.shape[0], 1, 1, 1],
                                      shape=[4],
                                      dtype=tf.int32)
                step_tensor = tf.constant(step, shape=[1], dtype=tf.uint32)
                if self.loss_name in ["WGAN", "WGAN-GP"]:
                    for __ in range(5):
                        D_loss = self.train_step_D(trainX, trainY, y_shape,
                                                   step_tensor)
                    G_loss = self.train_step_G(trainX, trainY, y_shape,
                                               step_tensor)
                elif self.loss_name in ["Vanilla", "LSGAN"]:
                    D_loss, G_loss = self.train_step(trainX, trainY, y_shape,
                                                     step_tensor)
                else:
                    raise ValueError("Inner Error")

                if step % 100 == 0:
                    save_path = self.manager.save()
                    print("Saved checkpoint for step {}: {}".format(
                        step, save_path))
                    self.G.save_weights(self.tmp_path +
                                        '/weights_saved/G.ckpt')
                    self.D.save_weights(self.tmp_path +
                                        '/weights_saved/D.ckpt')
                    self.wirte_summary(step=step,
                                       seed=self.seed,
                                       G=self.G,
                                       G_loss=G_loss,
                                       D_loss=D_loss,
                                       out_path=self.out_path)
                    print('Time to next 100 step {} is {} sec'.format(
                        step,
                        time.time() - start))
                    start = time.time()

    def test(self, take_nums):
        """Write `take_nums` 10x10 grids of G outputs to `<out_path>/test`.

        Each grid consumes 100 consecutive elements of `test_set`.
        """
        import os
        out_path = self.out_path + "/test"
        # exist_ok avoids the exists-then-create race.
        os.makedirs(out_path, exist_ok=True)
        self.ckpt.restore(self.manager.latest_checkpoint)
        seed_get = iter(self.test_set)
        for take in range(take_nums):
            plt.figure(figsize=(10, 10))  # a large figure so each sample keeps enough pixels
            for i in range(100):
                single_seed = next(seed_get)
                GeneratedY = self.G(single_seed, training=False)
                plt.subplot(10, 10, (i + 1))
                plt.imshow(GeneratedY[0, :, :, 0], cmap='gray')
                plt.axis('off')
            plt.savefig(out_path + '/image_at_{}.png'.format(take))
            plt.close()

    def get_seed(self):
        """Cache the first 100 elements of `test_set` as fixed sampling seeds."""
        seed_get = iter(self.test_set)
        self.seed = [next(seed_get) for _ in range(100)]

    def wirte_summary(self, step, seed, G, G_loss, D_loss, out_path):
        """Save a 10x10 grid of G(seed) samples and log it with the losses.

        The grid is written to `<out_path>/image_at_<step>.png`, read back and
        logged to TensorBoard together with G_loss and D_loss.
        """
        plt.figure(figsize=(10, 10))  # a large figure so each sample keeps enough pixels
        for i, single_seed in enumerate(seed):
            GeneratedY = G(single_seed, training=False)
            plt.subplot(10, 10, (i + 1))
            plt.imshow(GeneratedY[0, :, :, 0], cmap='gray')
            plt.axis('off')
        image_path = out_path + '/image_at_{}.png'.format(step)
        plt.savefig(image_path)
        plt.close()
        # Read the PNG back inside a context manager so the file handle is
        # released; np.array forces PIL to load the pixels before closing.
        with Image.open(image_path) as png:
            pixels = np.array(png)
        # NOTE(review): assumes a 10x10 inch figure at 100 dpi saved as RGBA,
        # i.e. a 1000x1000x4 PNG; confirm the matplotlib dpi setting.
        img = tf.reshape(pixels, shape=(1, 1000, 1000, 4))

        with self.train_summary_writer.as_default():
            tf.summary.scalar('G_loss', G_loss, step=step)
            tf.summary.scalar('D_loss', D_loss, step=step)
            tf.summary.image("img", img, step=step)
Пример #5
0
class CycleGAN(tf.keras.Model):
    """
    模型只负责给定训练集和测试(验证)集后的操作
    """
    def __init__(self,
                 train_set,
                 test_set,
                 loss_name="WGAN-GP",
                 mixed_precision=False,
                 learning_rate=2e-4,
                 tmp_path=None,
                 out_path=None):
        """Build both generators, both discriminators and the training state.

        Args:
            train_set: iterable of (trainX, trainY, maskX, maskY) batches.
            test_set: iterable whose first element becomes `self.seed`.
            loss_name: "WGAN", "WGAN-GP", "Vanilla" or "LSGAN", optionally
                with a "-SN" suffix that turns on spectral normalization in
                the discriminators (the suffix is dropped for GanLoss).
            mixed_precision: wrap the optimizers for mixed precision.
            learning_rate: forwarded to `optimizers_config`.
            tmp_path: root for logs, checkpoints and saved weights.
            out_path: directory for sample images.

        Raises:
            ValueError: for an unsupported `loss_name`.
        """
        super().__init__()
        # Datasets and output locations.
        self.train_set = train_set
        self.test_set = test_set
        self.tmp_path = tmp_path
        self.out_path = out_path
        # Generators: G maps X -> Y, F maps Y -> X.
        self.G = networks.Generator(name="G_X2Y")
        self.F = networks.Generator(name="G_Y2X")
        # (accepted loss names, discriminator ends in sigmoid, spectral norm)
        discriminator_variants = (
            (["WGAN-SN", "WGAN-GP-SN"], False, True),
            (["WGAN", "WGAN-GP"], False, False),
            (["Vanilla-SN", "LSGAN-SN"], True, True),
            (["Vanilla", "LSGAN"], True, False),
        )
        for accepted_names, use_sigmoid, spectral in discriminator_variants:
            if loss_name in accepted_names:
                break
        else:
            raise ValueError("Do not support the loss " + loss_name)
        self.Dy = networks.Discriminator(name="If_is_real_Y",
                                         use_sigmoid=use_sigmoid,
                                         sn=spectral)
        self.Dx = networks.Discriminator(name="If_is_real_X",
                                         use_sigmoid=use_sigmoid,
                                         sn=spectral)
        # Spectral-norm variants reuse the base loss, so drop the "-SN" tail.
        self.loss_name = loss_name[:-3] if spectral else loss_name
        self.vgg = vgg16()
        self.model_list = [self.G, self.F, self.Dy, self.Dx]
        # Loss, optimizers, metrics and checkpointing.
        self.gan_loss = GanLoss(self.loss_name)
        self.optimizers_list = self.optimizers_config(
            mixed_precision=mixed_precision, learning_rate=learning_rate)
        self.mixed_precision = mixed_precision
        self.matrics_list = self.matrics_config()
        self.checkpoint_config()
        self.get_seed()

    def build(self, X_shape, Y_shape):
        """
        Build G (X->Y), Dy (judges Y), F (Y->X) and Dx (judges X).

        Every network receives a sliced copy of its shape, because the lower
        layers treat it as their output shape and modify it in place.
        """
        pairs = ((self.G, X_shape), (self.Dy, Y_shape), (self.F, Y_shape),
                 (self.Dx, X_shape))
        for network, shape in pairs:
            network.build(input_shape=shape[:])
        self.built = True

    def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
        """Create one Adam optimizer per network (G, Dy, F, Dx).

        Generators use lr 1e-4 and discriminators lr 4e-4, all with
        beta_1=0.0 and beta_2=0.9.

        Returns:
            [G_optimizer, Dy_optimizer, F_optimizer, Dx_optimizer], each
            wrapped for mixed precision when `mixed_precision` is true.
        """
        # NOTE(review): `learning_rate` is accepted but not used; the rates
        # below are fixed. Confirm whether callers expect it to take effect.
        rates = (1e-4, 4e-4, 1e-4, 4e-4)  # G, Dy, F, Dx
        optimizers = [
            Adam(learning_rate=rate, beta_1=0.0, beta_2=0.9) for rate in rates
        ]
        if mixed_precision:
            optimizers = [opt.get_mixed_precision() for opt in optimizers]
        (self.G_optimizer, self.Dy_optimizer, self.F_optimizer,
         self.Dx_optimizer) = optimizers
        return [
            self.G_optimizer, self.Dy_optimizer, self.F_optimizer,
            self.Dx_optimizer
        ]

    def matrics_config(self):
        """Create the TensorBoard writer and four running-mean metrics.

        The writer logs to `<tmp_path>/logs/<YYYYmmdd-HHMMSS>`.

        Returns:
            [psnr X->Y, psnr Y->X, ssim X->Y, ssim Y->X] metric objects.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.train_summary_writer = tf.summary.create_file_writer(
            self.tmp_path + "/logs/" + timestamp)
        self.m_psnr_X2Y, self.m_psnr_Y2X, self.m_ssim_X2Y, self.m_ssim_Y2X = (
            tf.keras.metrics.Mean(metric_name, dtype=tf.float32)
            for metric_name in ('psnr_y', 'psnr_x', 'ssim_y', 'ssim_x'))
        return [
            self.m_psnr_X2Y, self.m_psnr_Y2X, self.m_ssim_X2Y, self.m_ssim_Y2X
        ]

    def checkpoint_config(self):
        """Create `self.ckpt` and a manager that keeps the 3 newest checkpoints.

        The checkpoint tracks a step counter starting at 1, the optimizer
        list, the model list and the training dataset. Files are written
        under `<tmp_path>/tf_ckpts`.
        """
        checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                         optimizer=self.optimizers_list,
                                         model=self.model_list,
                                         dataset=self.train_set)
        self.ckpt = checkpoint
        self.manager = tf.train.CheckpointManager(
            checkpoint,
            directory=self.tmp_path + '/tf_ckpts',
            max_to_keep=3)

    def pix_gradient(self, x):
        """Return the one-step finite-difference image gradients of `x`.

        `x` is reshaped to [1, 64, 64, 1], so it must hold exactly 64*64
        values. The pair follows `tf.image.image_gradients` and is ordered
        (dy, dx): vertical gradient first, horizontal second.
        """
        # Gradients are taken per batch and channel; for a single 2D
        # single-channel image the reshape only adds those axes.
        image = tf.reshape(x, shape=[1, 64, 64, 1])
        dy, dx = tf.image.image_gradients(image)
        return dy, dx

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32)
    ])
    def train_step_D(self, trainX, trainY, maskX, maskY):
        """Update both discriminators (Dy and Dx) once, with a gradient penalty.

        Args:
            trainX: real batch from domain X (input to G).
            trainY: real batch from domain Y (input to F).
            maskX: multiplied element-wise into F(trainY) before Dx scores it.
            maskY: multiplied element-wise into G(trainX) before Dy scores it.

        Returns:
            (Dy_loss, Dx_loss): unscaled losses, each including the mean
            gradient-penalty term.
        """
        # persistent=True: two gradients (Dy and Dx) are taken from this tape.
        with tf.GradientTape(persistent=True) as D_tape:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)

            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)

            # Gradient penalty for Dy: score random real/fake interpolations
            # and penalise 10 * (||grad||_2 - 1)^2 per sample.
            # NOTE(review): self.wgp_shape is a plain Python list set in
            # train(); tf.function reads it while tracing, so the batch size
            # seen at trace time is baked in. Confirm batches are fixed-size.
            e = tf.random.uniform(shape=self.wgp_shape, minval=0.0, maxval=1.0)
            mid_Y = e * trainY + (1 - e) * GeneratedY
            with tf.GradientTape() as gradient_penaltyY:
                gradient_penaltyY.watch(mid_Y)
                inner_loss = self.Dy(mid_Y)
            penalty = gradient_penaltyY.gradient(inner_loss, mid_Y)
            penalty_normY = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, shape=[self.wgp_shape[0], -1]),
                        ord=2,
                        axis=-1) - 1)

            # Same penalty for Dx, with fresh interpolation weights.
            e = tf.random.uniform(shape=self.wgp_shape, minval=0.0, maxval=1.0)
            mid_X = e * trainX + (1 - e) * GeneratedX
            with tf.GradientTape() as gradient_penaltyX:
                gradient_penaltyX.watch(mid_X)
                inner_loss = self.Dx(mid_X)
            penalty = gradient_penaltyX.gradient(inner_loss, mid_X)
            penalty_normX = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, shape=[self.wgp_shape[0], -1]),
                        ord=2,
                        axis=-1) - 1)

            Dy_loss = self.gan_loss.DiscriminatorLoss(
                Dy_real_out, Dy_fake_out) + tf.reduce_mean(penalty_normY)
            Dx_loss = self.gan_loss.DiscriminatorLoss(
                Dx_real_out, Dx_fake_out) + tf.reduce_mean(penalty_normX)

            # Loss scaling must be recorded on the tape for mixed precision.
            if self.mixed_precision:
                scaled_Dy_loss = self.Dy_optimizer.get_scaled_loss(Dy_loss)
                scaled_Dx_loss = self.Dx_optimizer.get_scaled_loss(Dx_loss)

        if self.mixed_precision:
            scaled_gradients_of_Dy = D_tape.gradient(
                scaled_Dy_loss, self.Dy.trainable_variables)
            scaled_gradients_of_Dx = D_tape.gradient(
                scaled_Dx_loss, self.Dx.trainable_variables)
            gradients_of_Dy = self.Dy_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dy)
            gradients_of_Dx = self.Dx_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dx)
        else:
            gradients_of_Dy = D_tape.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
            gradients_of_Dx = D_tape.gradient(Dx_loss,
                                              self.Dx.trainable_variables)

        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        return Dy_loss, Dx_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32)
    ])
    def train_step_G(self, trainX, trainY, maskX, maskY):
        """Update both generators (G: X->Y and F: Y->X) once.

        Each generator loss is its adversarial loss plus 10x the shared L1
        cycle-consistency loss plus a weighted L1 loss between self.vgg
        features of the generated and real images in its output domain.

        Args:
            trainX, trainY: real batches from domains X and Y.
            maskX, maskY: multiplied into F(trainY) and G(trainX) respectively.

        Returns:
            (G_loss, F_loss), both unscaled.
        """
        # persistent=True: two gradients (G and F) are taken from this tape.
        with tf.GradientTape(persistent=True) as G_tape:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            # Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)

            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            # Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)

            # Cycle consistency: X -> Y -> X and Y -> X -> Y should round-trip.
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X

            # Feature-space ("perceptual") loss over the first 7 outputs of
            # self.vgg, each weighted 0.14. NOTE(review): assumes self.vgg
            # returns an indexable sequence of at least 7 feature tensors.
            fake_Y_perceptual = self.vgg(GeneratedY)
            real_Y_perceptual = self.vgg(trainY)
            fake_X_perceptual = self.vgg(GeneratedX)
            real_X_perceptual = self.vgg(trainX)
            reconstruction_loss_X2Y = 0
            reconstruction_loss_Y2X = 0
            for i in range(7):
                reconstruction_loss_X2Y += 0.14 * tf.reduce_mean(
                    tf.abs(fake_Y_perceptual[i] - real_Y_perceptual[i]))
                reconstruction_loss_Y2X += 0.14 * tf.reduce_mean(
                    tf.abs(fake_X_perceptual[i] - real_X_perceptual[i]))

            # reconstruction_loss_X2Y = tf.reduce_mean(tf.abs(GeneratedY-trainY))
            # reconstruction_loss_Y2X = tf.reduce_mean(tf.abs(GeneratedX-trainX))

            G_loss = self.gan_loss.GeneratorLoss(
                Dy_fake_out
            ) + 10.0 * cycle_consistent + reconstruction_loss_X2Y
            F_loss = self.gan_loss.GeneratorLoss(
                Dx_fake_out
            ) + 10.0 * cycle_consistent + reconstruction_loss_Y2X

            # Loss scaling must be recorded on the tape for mixed precision.
            if self.mixed_precision:
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
                scaled_F_loss = self.F_optimizer.get_scaled_loss(F_loss)
        if self.mixed_precision:
            scaled_gradients_of_G = G_tape.gradient(scaled_G_loss,
                                                    self.G.trainable_variables)
            scaled_gradients_of_F = G_tape.gradient(scaled_F_loss,
                                                    self.F.trainable_variables)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
            gradients_of_F = self.F_optimizer.get_unscaled_gradients(
                scaled_gradients_of_F)

        else:
            gradients_of_G = G_tape.gradient(G_loss,
                                             self.G.trainable_variables)
            gradients_of_F = G_tape.gradient(F_loss,
                                             self.F.trainable_variables)

        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, F_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32)
    ])
    def train_step(self, trainX, trainY, maskX, maskY):
        """Jointly update Dy, Dx, G and F once (Vanilla / LSGAN losses).

        G maps X -> Y and F maps Y -> X; generated images are multiplied by
        maskY / maskX before they are scored. Each generator loss is its
        adversarial loss plus 10x the shared L1 cycle-consistency loss plus a
        weighted L1 loss on self.vgg features. When `self.mixed_precision` is
        set, losses are scaled via the optimizers' `get_scaled_loss` and the
        gradients unscaled via `get_unscaled_gradients` before being applied.

        Returns:
            (G_loss, Dy_loss, F_loss, Dx_loss), all unscaled.
        """
        # persistent=True: four gradients are taken from this tape.
        with tf.GradientTape(persistent=True) as cycle_type:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)

            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)

            # Cycle consistency: X -> Y -> X and Y -> X -> Y should round-trip.
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X

            # Feature-space loss over the first 7 self.vgg outputs, weight 0.14 each.
            fake_Y_perceptual = self.vgg(GeneratedY)
            real_Y_perceptual = self.vgg(trainY)
            fake_X_perceptual = self.vgg(GeneratedX)
            real_X_perceptual = self.vgg(trainX)
            reconstruction_loss_X2Y = 0
            reconstruction_loss_Y2X = 0
            for i in range(7):
                reconstruction_loss_X2Y += 0.14 * tf.reduce_mean(
                    tf.abs(fake_Y_perceptual[i] - real_Y_perceptual[i]))
                reconstruction_loss_Y2X += 0.14 * tf.reduce_mean(
                    tf.abs(fake_X_perceptual[i] - real_X_perceptual[i]))

            Dy_loss = self.gan_loss.DiscriminatorLoss(Dy_real_out, Dy_fake_out)
            Dx_loss = self.gan_loss.DiscriminatorLoss(Dx_real_out, Dx_fake_out)
            G_loss = self.gan_loss.GeneratorLoss(Dy_fake_out) + 10.0 * (
                cycle_consistent) + reconstruction_loss_X2Y
            F_loss = self.gan_loss.GeneratorLoss(Dx_fake_out) + 10.0 * (
                cycle_consistent) + reconstruction_loss_Y2X

            # Loss scaling must be recorded on the tape for mixed precision.
            if self.mixed_precision:
                scaled_Dy_loss = self.Dy_optimizer.get_scaled_loss(Dy_loss)
                scaled_Dx_loss = self.Dx_optimizer.get_scaled_loss(Dx_loss)
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
                scaled_F_loss = self.F_optimizer.get_scaled_loss(F_loss)

        if self.mixed_precision:
            gradients_of_Dy = self.Dy_optimizer.get_unscaled_gradients(
                cycle_type.gradient(scaled_Dy_loss,
                                    self.Dy.trainable_variables))
            gradients_of_Dx = self.Dx_optimizer.get_unscaled_gradients(
                cycle_type.gradient(scaled_Dx_loss,
                                    self.Dx.trainable_variables))
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                cycle_type.gradient(scaled_G_loss,
                                    self.G.trainable_variables))
            gradients_of_F = self.F_optimizer.get_unscaled_gradients(
                cycle_type.gradient(scaled_F_loss,
                                    self.F.trainable_variables))
        else:
            gradients_of_Dy = cycle_type.gradient(Dy_loss,
                                                  self.Dy.trainable_variables)
            gradients_of_Dx = cycle_type.gradient(Dx_loss,
                                                  self.Dx.trainable_variables)
            gradients_of_G = cycle_type.gradient(G_loss,
                                                 self.G.trainable_variables)
            gradients_of_F = cycle_type.gradient(F_loss,
                                                 self.F.trainable_variables)

        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, Dy_loss, F_loss, Dx_loss

    def train(self, epoches, max_step=80200):
        """Train for up to `epoches` passes over `train_set`.

        Resumes from the latest checkpoint. Every 100 steps it saves a
        checkpoint, the weights of G/Dy/F/Dx and a summary image. Training
        stops early once the checkpoint step counter reaches `max_step`
        (80200 by default, matching the previously hard-coded limit).

        Args:
            epoches: maximum number of passes over `train_set`.
            max_step: global step at which training stops.

        Raises:
            ValueError: if `self.loss_name` has no training branch.
        """
        self.ckpt.restore(self.manager.latest_checkpoint)
        for _ in range(epoches):
            start = time.time()
            for trainX, trainY, maskX, maskY in self.train_set:
                self.ckpt.step.assign_add(1)
                step = int(self.ckpt.step)

                # Per-step hyper-parameters and schedules.
                # self.l: decaying weight; not read by the methods visible
                # here, kept for any other users of the attribute.
                self.l = 10.0 * (1.0 / (step * 0.1 + 1.0))
                # Interpolation-weight shape for the gradient penalty in
                # train_step_D; 3D volumes carry one extra axis.
                self.wgp_shape = [trainY.shape[0], 1, 1, 1, 1]

                if self.loss_name in ["WGAN", "WGAN-GP"]:
                    Dy_loss, Dx_loss = self.train_step_D(
                        trainX, trainY, maskX, maskY)
                    G_loss, F_loss = self.train_step_G(
                        trainX, trainY, maskX, maskY)
                elif self.loss_name in ["Vanilla", "LSGAN"]:
                    G_loss, Dy_loss, F_loss, Dx_loss = self.train_step(
                        trainX, trainY, maskX, maskY)
                else:
                    raise ValueError("Inner Error")

                if step % 100 == 0:
                    save_path = self.manager.save()
                    print("Saved checkpoint for step {}: {}".format(
                        step, save_path))

                    self.G.save_weights(self.tmp_path +
                                        '/weights_saved/G.ckpt')
                    self.Dy.save_weights(self.tmp_path +
                                         '/weights_saved/Dy.ckpt')
                    self.F.save_weights(self.tmp_path +
                                        '/weights_saved/F.ckpt')
                    self.Dx.save_weights(self.tmp_path +
                                         '/weights_saved/Dx.ckpt')

                    self.wirte_summary(step=step,
                                       seed=self.seed,
                                       G=self.G,
                                       F=self.F,
                                       G_loss=G_loss,
                                       Dy_loss=Dy_loss,
                                       F_loss=F_loss,
                                       Dx_loss=Dx_loss,
                                       out_path=self.out_path)

                    print('Time to next 100 step {} is {} sec'.format(
                        step,
                        time.time() - start))
                    start = time.time()
                # ">=" so a run resumed past the limit also stops.
                if step >= max_step:
                    return

    def get_seed(self):
        """Store the first test batch as `self.seed` and preview it.

        Prints the shape/dtype of the batch, then shows slice [0, :, :, 1, 0]
        of testX, testY, maskX and maskY in turn with matplotlib.
        NOTE(review): plt.show() may block until each window is closed,
        depending on the matplotlib backend. This runs during __init__.
        """
        first_batch = next(iter(self.test_set))
        testX, testY, maskX, maskY = first_batch
        print(testX.shape, testY.dtype, maskX.dtype, maskY.shape)
        for volume in (testX, testY, maskX, maskY):
            plt.imshow(volume[0, :, :, 1, 0], cmap='gray')
            plt.show()
        self.seed = testX, testY, maskX, maskY

    def wirte_summary(self, step, seed, G, F, G_loss, Dy_loss, F_loss, Dx_loss,
                      out_path):
        """Render a preview of the seed batch and log losses/metrics.

        Args:
            step: Training step, used in the PNG file name and as the summary step.
            seed: ``(testX, testY, maskX, maskY)`` tuple (as built by ``get_seed``).
            G: Generator applied to ``testX`` (X -> Y).
            F: Generator applied to ``testY`` (Y -> X).
            G_loss, Dy_loss, F_loss, Dx_loss: Values logged as scalar summaries.
            out_path: Directory the preview PNG is written to.

        Side effects:
            - Writes ``<out_path>/image_at_<step>.png``.
            - Updates the PSNR/SSIM metric objects and writes scalar and image
              summaries through ``self.train_summary_writer``.
            - Resets the four metric objects afterwards.
        """
        testX, testY, maskX, maskY = seed
        # Each generated image is masked with the mask of its *source* modality.
        GeneratedY = tf.multiply(G(testX), maskX)
        GeneratedX = tf.multiply(F(testY), maskY)
        # Keep only index 1 of the 4th axis (one slice of each 5-D tensor).
        testX = testX[:, :, :, 1, :]
        testY = testY[:, :, :, 1, :]
        GeneratedY = GeneratedY[:, :, :, 1, :]
        GeneratedX = GeneratedX[:, :, :, 1, :]

        panels = [('real X', testX), ('fake Y', GeneratedY),
                  ('fake X', GeneratedX), ('real Y', testY)]
        plt.figure(figsize=(5, 5))  # figure must be large enough to show the pixels
        for position, (title, images) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.title(title)
            plt.imshow(images[0, :, :, 0], cmap='gray')
            plt.axis('off')
        image_path = out_path + '/image_at_{}.png'.format(step)
        plt.savefig(image_path)
        plt.close()
        # Read the PNG back, closing the file handle, and add a batch axis.
        # The size comes from the file instead of assuming 500x500 RGBA
        # (that only holds for the default dpi of 100).
        with Image.open(image_path) as png:
            img = tf.expand_dims(np.array(png), axis=0)

        with self.train_summary_writer.as_default():
            # max_val of 1 presumably means images are scaled to [0, 1] -- confirm.
            self.m_psnr_X2Y(tf.image.psnr(GeneratedY, testY, 1.0, name=None))
            self.m_psnr_Y2X(tf.image.psnr(GeneratedX, testX, 1.0, name=None))
            self.m_ssim_X2Y(
                tf.image.ssim(GeneratedY,
                              testY,
                              1,
                              filter_size=11,
                              filter_sigma=1.5,
                              k1=0.01,
                              k2=0.03))
            self.m_ssim_Y2X(
                tf.image.ssim(GeneratedX,
                              testX,
                              1,
                              filter_size=11,
                              filter_sigma=1.5,
                              k1=0.01,
                              k2=0.03))
            tf.summary.scalar('G_loss', G_loss, step=step)
            tf.summary.scalar('Dy_loss', Dy_loss, step=step)
            tf.summary.scalar('F_loss', F_loss, step=step)
            tf.summary.scalar('Dx_loss', Dx_loss, step=step)
            tf.summary.scalar('test_psnr_y',
                              self.m_psnr_X2Y.result(),
                              step=step)
            tf.summary.scalar('test_psnr_x',
                              self.m_psnr_Y2X.result(),
                              step=step)
            tf.summary.scalar('test_ssim_y',
                              self.m_ssim_X2Y.result(),
                              step=step)
            tf.summary.scalar('test_ssim_x',
                              self.m_ssim_Y2X.result(),
                              step=step)
            tf.summary.image("img", img, step=step)

        # Metrics are per-call, so clear them for the next summary.
        self.m_psnr_X2Y.reset_states()
        self.m_psnr_Y2X.reset_states()
        self.m_ssim_X2Y.reset_states()
        self.m_ssim_Y2X.reset_states()

    def test(self):
        """Evaluate checkpoints 800, 801 and 802 on ``self.test_set``.

        Checkpoint paths are built from the prefix of
        ``self.manager.latest_checkpoint``, i.e. everything up to and including
        its last ``-``.

        For every checkpoint and every test batch this:

        - saves a 3x2 comparison figure to ``<self.out_path>/test/image_at_<n>.png``;
        - logs PSNR / SSIM / image-gradient-difference (``IG_y``/``IG_x``) summaries;
        - writes the per-batch metrics to ``result<index>.csv``.

        For the first batch, the generated slices are also saved as
        ``out_Y.npy`` / ``out_X.npy``. These files are overwritten for each
        checkpoint.

        Raises:
            ValueError: If the checkpoint manager has no saved checkpoint.
        """
        import csv

        latest = self.manager.latest_checkpoint
        if latest is None:
            raise ValueError("No checkpoint found to test")
        # Checkpoints are named "<prefix>-<number>". Dropping a fixed number of
        # characters only worked when <number> had exactly 3 digits.
        ckpt_prefix = latest.rsplit('-', 1)[0] + '-'

        out_path = self.out_path + "/test"
        os.makedirs(out_path, exist_ok=True)
        headers = [
            'Instance', 'test_psnr_y', 'test_psnr_x', 'test_ssim_y',
            'test_ssim_x', 'IG_y', 'IG_x'
        ]
        for index, temp_point in enumerate(["800", "801", "802"]):
            self.ckpt.restore(ckpt_prefix + temp_point)
            result_buf = []
            for i, (testX, testY, maskX, maskY) in enumerate(self.test_set):
                step = i + 1
                # At test time each output is masked with the mask of its
                # *source* modality: only that modality and its mask are known.
                GeneratedY = tf.multiply(self.G(testX), maskX)
                GeneratedX = tf.multiply(self.F(testY), maskY)

                # First sample, index 1 of the 4th axis, channel 0.
                black_board_rX = testX[0, :, :, 1, 0]
                black_board_rY = testY[0, :, :, 1, 0]
                black_board_Y = GeneratedY[0, :, :, 1, 0]
                black_board_X = GeneratedX[0, :, :, 1, 0]
                black_board_absY = tf.abs(black_board_Y - black_board_rY)
                black_board_absX = tf.abs(black_board_X - black_board_rX)

                panels = [
                    ('real X', black_board_rX, 'gray'),
                    ('fake Y', black_board_Y, 'gray'),
                    ('ABS Y', black_board_absY, 'hot'),
                    ('ABS X', black_board_absX, 'hot'),
                    ('fake X', black_board_X, 'gray'),
                    ('real Y', black_board_rY, 'gray'),
                ]
                image_path = out_path + '/image_at_{}.png'.format(step)
                plt.figure(figsize=(10, 10))  # large enough to show the pixels
                for position, (title, board, cmap) in enumerate(panels, start=1):
                    plt.subplot(3, 2, position)
                    plt.title(title)
                    plt.imshow(board, cmap=cmap)
                    plt.axis('off')
                plt.savefig(image_path)
                plt.close()
                # Read the PNG back, closing the file, and add a batch axis.
                # The size comes from the file rather than assuming 1000x1000 RGBA.
                with Image.open(image_path) as png:
                    img = tf.expand_dims(np.array(png), axis=0)
                if i == 0:
                    np.save(out_path + "/out_Y.npy", black_board_Y)
                    np.save(out_path + "/out_X.npy", black_board_X)

                with self.train_summary_writer.as_default():
                    fake_y = self._to_batched_image(black_board_Y)
                    fake_x = self._to_batched_image(black_board_X)
                    real_y = self._to_batched_image(black_board_rY)
                    real_x = self._to_batched_image(black_board_rX)
                    self.m_psnr_X2Y(
                        tf.image.psnr(fake_y, real_y, 1.0, name=None))
                    self.m_psnr_Y2X(
                        tf.image.psnr(fake_x, real_x, 1.0, name=None))
                    self.m_ssim_X2Y(
                        tf.image.ssim(fake_y,
                                      real_y,
                                      1,
                                      filter_size=11,
                                      filter_sigma=1.5,
                                      k1=0.01,
                                      k2=0.03))
                    self.m_ssim_Y2X(
                        tf.image.ssim(fake_x,
                                      real_x,
                                      1,
                                      filter_size=11,
                                      filter_sigma=1.5,
                                      k1=0.01,
                                      k2=0.03))
                    tf.summary.scalar('test_psnr_y',
                                      self.m_psnr_X2Y.result(),
                                      step=step)
                    tf.summary.scalar('test_psnr_x',
                                      self.m_psnr_Y2X.result(),
                                      step=step)
                    tf.summary.scalar('test_ssim_y',
                                      self.m_ssim_X2Y.result(),
                                      step=step)
                    tf.summary.scalar('test_ssim_x',
                                      self.m_ssim_Y2X.result(),
                                      step=step)
                    # The comparison figure is logged once per batch.
                    tf.summary.image("img", img, step=step)
                    IG_y = self._image_gradient_difference(fake_y, real_y)
                    tf.summary.scalar('IG_y', IG_y, step=step)
                    IG_x = self._image_gradient_difference(fake_x, real_x)
                    tf.summary.scalar('IG_x', IG_x, step=step)
                    result_buf.append([
                        step,
                        self.m_psnr_X2Y.result().numpy(),
                        self.m_psnr_Y2X.result().numpy(),
                        self.m_ssim_X2Y.result().numpy(),
                        self.m_ssim_Y2X.result().numpy(),
                        IG_y.numpy(),
                        IG_x.numpy()
                    ])
                # Metrics are per-batch, so clear them before the next one.
                self.m_psnr_X2Y.reset_states()
                self.m_psnr_Y2X.reset_states()
                self.m_ssim_X2Y.reset_states()
                self.m_ssim_Y2X.reset_states()

            # newline='' is required by the csv module to avoid blank rows on Windows.
            with open(out_path + '/result' + str(index) + '.csv', 'w',
                      newline='') as f:
                f_csv = csv.writer(f)
                f_csv.writerow(headers)
                f_csv.writerows(result_buf)

    @staticmethod
    def _to_batched_image(image2d):
        """Cast a 2-D image to float32 and add batch/channel axes -> [1, H, W, 1]."""
        return tf.cast(image2d, tf.float32)[tf.newaxis, :, :, tf.newaxis]

    @staticmethod
    def _image_gradient_difference(fake, real):
        """Return mean|dy(fake)-dy(real)| + mean|dx(fake)-dx(real)| for NHWC images."""
        dx1, dy1 = tf.image.image_gradients(fake)
        dx2, dy2 = tf.image.image_gradients(real)
        dx_mean = tf.reduce_mean(tf.math.abs(dx1 - dx2))
        dy_mean = tf.reduce_mean(tf.math.abs(dy1 - dy2))
        return dy_mean + dx_mean
class CycleGAN(tf.keras.Model):
    """CycleGAN between two unpaired image domains X and Y.

    The model is only responsible for what happens once a training set and a
    test (validation) set are given. It builds:

    - the generators ``G`` (X -> Y) and ``F`` (Y -> X);
    - the discriminators ``Dy`` and ``Dx``.

    It then trains them, checkpoints them, and logs/visualises results.

    Args:
        train_set: Iterable of ``(trainX, trainY)`` batches.
        test_set: Iterable of ``(testX, testY)`` batches. Its first element is
            kept as the fixed preview seed.
        loss_name: One of ``"WGAN"``, ``"WGAN-GP"``, ``"WGAN-SN"``,
            ``"WGAN-GP-SN"``, ``"Vanilla"`` or ``"LSGAN"``. The ``-SN``
            variants build the discriminators with ``sn=True`` (presumably
            spectral normalisation) and use the loss named without the suffix.
        mixed_precision: If true, the optimizers are replaced by their
            ``get_mixed_precision()`` wrappers and losses are scaled before
            differentiation.
        learning_rate: Learning rate for all four Adam optimizers.
        tmp_path: Directory for TensorBoard logs, checkpoints and weights.
        out_path: Directory for preview and test images.

    Raises:
        ValueError: If ``loss_name`` is not supported.
    """

    # Weight of the cycle-consistency term in the generator losses.
    CYCLE_LOSS_WEIGHT = 10.0

    def __init__(self,
                 train_set,
                 test_set,
                 loss_name="WGAN-GP",
                 mixed_precision=False,
                 learning_rate=2e-4,
                 tmp_path=None,
                 out_path=None):
        super(CycleGAN, self).__init__()
        # Datasets and related parameters.
        self.train_set = train_set
        self.test_set = test_set
        self.tmp_path = tmp_path
        self.out_path = out_path
        # Networks.
        self.G = networks.Generator(name="G_X2Y")
        self.F = networks.Generator(name="G_Y2X")
        if loss_name in ["WGAN-SN", "WGAN-GP-SN"]:
            use_sigmoid, sn = False, True
            self.loss_name = loss_name[:-3]  # strip the "-SN" suffix
        elif loss_name in ["WGAN", "WGAN-GP"]:
            use_sigmoid, sn = False, False
            self.loss_name = loss_name
        elif loss_name in ["Vanilla", "LSGAN"]:
            use_sigmoid, sn = True, False
            self.loss_name = loss_name
        else:
            raise ValueError("Do not support the loss " + loss_name)
        self.Dy = networks.Discriminator(name="If_is_real_Y",
                                         use_sigmoid=use_sigmoid,
                                         sn=sn)
        self.Dx = networks.Discriminator(name="If_is_real_X",
                                         use_sigmoid=use_sigmoid,
                                         sn=sn)

        self.model_list = [self.G, self.F, self.Dy, self.Dx]
        # Loss, optimizers, logging and checkpointing.
        self.gan_loss = GanLoss(self.loss_name)
        self.optimizers_list = self.optimizers_config(
            mixed_precision=mixed_precision, learning_rate=learning_rate)
        self.mixed_precision = mixed_precision
        self.matrics_list = self.matrics_config()
        self.checkpoint_config()
        self.get_seed()

    def build(self, X_shape, Y_shape):
        """Build all four networks from the X and Y input shapes.

        The shapes are passed as slices (copies). Per the original author's
        note, the layers treat the given ``input_shape`` as their output shape
        internally and would otherwise modify it.
        """
        self.G.build(input_shape=X_shape[:])   # G: X -> Y
        self.Dy.build(input_shape=Y_shape[:])  # Dy: real Y or not
        self.F.build(input_shape=Y_shape[:])   # F: Y -> X
        self.Dx.build(input_shape=X_shape[:])  # Dx: real X or not
        self.built = True

    def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
        """Create one Adam optimizer per network, all using ``learning_rate``.

        With ``mixed_precision``, each optimizer is replaced by its
        ``get_mixed_precision()`` wrapper.

        Returns:
            ``[G_optimizer, Dy_optimizer, F_optimizer, Dx_optimizer]``.
        """
        self.G_optimizer = Adam(learning_rate)
        self.Dy_optimizer = Adam(learning_rate)
        self.F_optimizer = Adam(learning_rate)
        self.Dx_optimizer = Adam(learning_rate)
        if mixed_precision:
            self.G_optimizer = self.G_optimizer.get_mixed_precision()
            self.Dy_optimizer = self.Dy_optimizer.get_mixed_precision()
            self.F_optimizer = self.F_optimizer.get_mixed_precision()
            self.Dx_optimizer = self.Dx_optimizer.get_mixed_precision()
        return [self.G_optimizer, self.Dy_optimizer,
                self.F_optimizer, self.Dx_optimizer]

    def matrics_config(self):
        """Create the summary writer under ``<tmp_path>/logs/<timestamp>``.

        Returns:
            An empty list.
        """
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_logdir = self.tmp_path + "/logs/" + current_time
        self.train_summary_writer = tf.summary.create_file_writer(train_logdir)
        return []

    def checkpoint_config(self):
        """Set up a checkpoint of the step counter, optimizers, models and train set.

        The checkpoint manager keeps the last 3 checkpoints in
        ``<tmp_path>/tf_ckpts``.
        """
        self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                        optimizer=self.optimizers_list,
                                        model=self.model_list,
                                        dataset=self.train_set)
        self.manager = tf.train.CheckpointManager(self.ckpt,
                                                  self.tmp_path + '/tf_ckpts',
                                                  max_to_keep=3)

    # ----------------------------------------------------------- loss helpers

    def _loss_for_gradients(self, loss, optimizer):
        """Return the loss to differentiate: loss-scaled under mixed precision.

        Must be called inside the tape context that records the loss.
        """
        if self.mixed_precision:
            return optimizer.get_scaled_loss(loss)
        return loss

    def _model_gradients(self, tape, loss, model, optimizer):
        """Gradients of ``loss`` w.r.t. ``model``'s trainable variables.

        Under mixed precision the gradients are unscaled before returning.
        """
        grads = tape.gradient(loss, model.trainable_variables)
        if self.mixed_precision:
            grads = optimizer.get_unscaled_gradients(grads)
        return grads

    def _cycle_consistency_loss(self, trainX, trainY, GeneratedX, GeneratedY):
        """Return mean|F(G(X)) - X| + mean|G(F(Y)) - Y|."""
        loss_X2Y = tf.reduce_mean(tf.abs(self.F(GeneratedY) - trainX))
        loss_Y2X = tf.reduce_mean(tf.abs(self.G(GeneratedX) - trainY))
        return loss_X2Y + loss_Y2X

    @staticmethod
    def _gp_interpolate(real, fake):
        """Return ``e*real + (1-e)*fake`` with one ``e ~ U[0, 1)`` per sample."""
        # Dynamic batch size, so this also works when the batch dim is None.
        noise_shape = [tf.shape(real)[0]] + [1] * (real.shape.rank - 1)
        e = tf.random.uniform(noise_shape, 0.0, 1.0)
        return e * real + (1 - e) * fake

    def _gradient_penalty(self, discriminator, real, fake):
        """One-sided gradient penalty of ``discriminator`` at random interpolates.

        Returns ``10 * (max(||dD/dx_hat||_2, 1) - 1)**2`` per sample. Only
        gradient norms above 1 are penalised; this is the author's variant of
        the WGAN-GP penalty. The norm is taken per sample rather than over the
        whole batch, so the caller's ``reduce_mean`` averages over samples.
        """
        mid = self._gp_interpolate(real, fake)
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(mid)
            inner_loss = discriminator(mid)
        penalty = gp_tape.gradient(inner_loss, mid)
        flat = tf.reshape(penalty, [tf.shape(penalty)[0], -1])
        norms = tf.norm(flat, ord='euclidean', axis=1)
        return 10.0 * tf.math.square(tf.maximum(norms, 1.0) - 1)

    # ----------------------------------------------------------- train steps

    @tf.function(input_signature=[tf.TensorSpec(shape=global_input_X_shape,dtype=tf.float32),tf.TensorSpec(shape=global_input_Y_shape,dtype=tf.float32),tf.TensorSpec(shape=[1],dtype=tf.uint32)])
    def train_step_D(self, trainX, trainY, step):
        """One discriminator update (used for the WGAN / WGAN-GP losses).

        Both discriminator losses include ``_gradient_penalty``.
        NOTE(review): ``train`` also routes plain "WGAN" here, so the penalty
        is applied for it too -- confirm that is intended.
        ``step`` is not used; it is kept for the signature.

        Returns:
            ``(Dy_loss, Dx_loss)``, unscaled.
        """
        with tf.GradientTape(persistent=True) as D_tape:
            GeneratedY = self.G(trainX)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)

            GeneratedX = self.F(trainY)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)

            # Computed inside D_tape so the penalty is differentiated too.
            penalty_normY = self._gradient_penalty(self.Dy, trainY, GeneratedY)
            penalty_normX = self._gradient_penalty(self.Dx, trainX, GeneratedX)

            Dy_loss = (self.gan_loss.DiscriminatorLoss(Dy_real_out, Dy_fake_out)
                       + tf.reduce_mean(penalty_normY))
            Dx_loss = (self.gan_loss.DiscriminatorLoss(Dx_real_out, Dx_fake_out)
                       + tf.reduce_mean(penalty_normX))
            Dy_target = self._loss_for_gradients(Dy_loss, self.Dy_optimizer)
            Dx_target = self._loss_for_gradients(Dx_loss, self.Dx_optimizer)

        gradients_of_Dy = self._model_gradients(D_tape, Dy_target, self.Dy,
                                                self.Dy_optimizer)
        gradients_of_Dx = self._model_gradients(D_tape, Dx_target, self.Dx,
                                                self.Dx_optimizer)
        del D_tape  # release the persistent tape's resources
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        return Dy_loss, Dx_loss

    @tf.function(input_signature=[tf.TensorSpec(shape=global_input_X_shape,dtype=tf.float32),tf.TensorSpec(shape=global_input_Y_shape,dtype=tf.float32),tf.TensorSpec(shape=[1],dtype=tf.uint32)])
    def train_step_G(self, trainX, trainY, step):
        """One generator update (used for the WGAN / WGAN-GP losses).

        Each generator loss is its adversarial loss plus
        ``CYCLE_LOSS_WEIGHT`` times the shared cycle-consistency loss.
        ``step`` is not used; it is kept for the signature.

        Returns:
            ``(G_loss, F_loss)``, unscaled.
        """
        with tf.GradientTape(persistent=True) as G_tape:
            GeneratedY = self.G(trainX)
            Dy_fake_out = self.Dy(GeneratedY)

            GeneratedX = self.F(trainY)
            Dx_fake_out = self.Dx(GeneratedX)

            cycle_consistent = self._cycle_consistency_loss(
                trainX, trainY, GeneratedX, GeneratedY)
            cycle_l = self.CYCLE_LOSS_WEIGHT
            G_loss = self.gan_loss.GeneratorLoss(Dy_fake_out) + cycle_l * cycle_consistent
            F_loss = self.gan_loss.GeneratorLoss(Dx_fake_out) + cycle_l * cycle_consistent
            G_target = self._loss_for_gradients(G_loss, self.G_optimizer)
            F_target = self._loss_for_gradients(F_loss, self.F_optimizer)

        gradients_of_G = self._model_gradients(G_tape, G_target, self.G,
                                               self.G_optimizer)
        gradients_of_F = self._model_gradients(G_tape, F_target, self.F,
                                               self.F_optimizer)
        del G_tape  # release the persistent tape's resources
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, F_loss

    @tf.function(input_signature=[tf.TensorSpec(shape=global_input_X_shape,dtype=tf.float32),tf.TensorSpec(shape=global_input_Y_shape,dtype=tf.float32),tf.TensorSpec(shape=[1],dtype=tf.uint32)])
    def train_step(self, trainX, trainY, step):
        """Joint update of all four networks (used for the Vanilla / LSGAN losses).

        All gradients are computed from one forward pass before any update is
        applied. Loss scaling is applied under mixed precision, as in
        ``train_step_D`` / ``train_step_G``. ``step`` is not used.

        Returns:
            ``(G_loss, Dy_loss, F_loss, Dx_loss)``, unscaled.
        """
        with tf.GradientTape(persistent=True) as cycle_tape:
            GeneratedY = self.G(trainX)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)

            GeneratedX = self.F(trainY)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)

            cycle_consistent = self._cycle_consistency_loss(
                trainX, trainY, GeneratedX, GeneratedY)
            cycle_l = self.CYCLE_LOSS_WEIGHT
            Dy_loss = self.gan_loss.DiscriminatorLoss(Dy_real_out, Dy_fake_out)
            Dx_loss = self.gan_loss.DiscriminatorLoss(Dx_real_out, Dx_fake_out)
            G_loss = self.gan_loss.GeneratorLoss(Dy_fake_out) + cycle_l * cycle_consistent
            F_loss = self.gan_loss.GeneratorLoss(Dx_fake_out) + cycle_l * cycle_consistent
            Dy_target = self._loss_for_gradients(Dy_loss, self.Dy_optimizer)
            Dx_target = self._loss_for_gradients(Dx_loss, self.Dx_optimizer)
            G_target = self._loss_for_gradients(G_loss, self.G_optimizer)
            F_target = self._loss_for_gradients(F_loss, self.F_optimizer)

        gradients_of_Dy = self._model_gradients(cycle_tape, Dy_target, self.Dy,
                                                self.Dy_optimizer)
        gradients_of_Dx = self._model_gradients(cycle_tape, Dx_target, self.Dx,
                                                self.Dx_optimizer)
        gradients_of_G = self._model_gradients(cycle_tape, G_target, self.G,
                                               self.G_optimizer)
        gradients_of_F = self._model_gradients(cycle_tape, F_target, self.F,
                                               self.F_optimizer)
        del cycle_tape  # release the persistent tape's resources
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, Dy_loss, F_loss, Dx_loss

    # ------------------------------------------------------- train / evaluate

    def train(self, epoches):
        """Train for ``epoches`` passes over ``self.train_set``.

        Training restores the latest checkpoint first. Every 100 steps it:

        - saves a checkpoint;
        - writes each network's weights to ``<tmp_path>/weights_saved/<name>.ckpt``;
        - logs a preview with ``wirte_summary``.

        Raises:
            ValueError: If ``self.loss_name`` has no matching training step.
        """
        self.ckpt.restore(self.manager.latest_checkpoint)
        for _ in range(epoches):
            start = time.time()
            for trainX, trainY in self.train_set:
                self.ckpt.step.assign_add(1)
                step = int(self.ckpt.step)
                step_tensor = tf.constant(step, shape=[1], dtype=tf.uint32)
                if self.loss_name in ["WGAN", "WGAN-GP"]:
                    Dy_loss, Dx_loss = self.train_step_D(trainX, trainY, step_tensor)
                    G_loss, F_loss = self.train_step_G(trainX, trainY, step_tensor)
                elif self.loss_name in ["Vanilla", "LSGAN"]:
                    G_loss, Dy_loss, F_loss, Dx_loss = self.train_step(
                        trainX, trainY, step_tensor)
                else:
                    raise ValueError("Inner Error")
                if step % 100 == 0:
                    save_path = self.manager.save()
                    print("Saved checkpoint for step {}: {}".format(step, save_path))

                    for name, model in (("G", self.G), ("Dy", self.Dy),
                                        ("F", self.F), ("Dx", self.Dx)):
                        model.save_weights(self.tmp_path + '/weights_saved/' +
                                           name + '.ckpt')

                    self.wirte_summary(step=step,
                                       seed=self.seed,
                                       G=self.G,
                                       F=self.F,
                                       G_loss=G_loss,
                                       Dy_loss=Dy_loss,
                                       F_loss=F_loss,
                                       Dx_loss=Dx_loss,
                                       out_path=self.out_path)

                    print('Time to next 100 step {} is {} sec'.format(
                        step, time.time() - start))
                    start = time.time()

    def test(self, take_nums):
        """Restore the latest checkpoint and save a 2x2 preview per test batch.

        Figures are written to ``<out_path>/test/<i>.png``. The directory is
        created if missing.

        Args:
            take_nums: Currently unused; kept for API compatibility.
                NOTE(review): presumably meant to limit the number of test
                batches -- confirm before wiring it up.
        """
        self.ckpt.restore(self.manager.latest_checkpoint)
        test_dir = self.out_path + '/test'
        os.makedirs(test_dir, exist_ok=True)
        for i, (testX, testY) in enumerate(self.test_set):
            GeneratedY = self.G(testX)
            GeneratedX = self.F(testY)
            self._save_preview_figure(testX, GeneratedY, GeneratedX, testY,
                                      test_dir + '/{}.png'.format(i))

    def get_seed(self):
        """Keep the first test batch as the preview seed and display it.

        Prints the shape of X and the dtype of Y, shows the first image of
        each, and stores the batch on ``self.seed``.
        """
        seed_get = iter(self.test_set)
        seed = next(seed_get)
        print(seed[0].shape, seed[1].dtype)
        plt.imshow(seed[0][0, :, :, :])
        plt.show()
        plt.imshow(seed[1][0, :, :, :])
        plt.show()
        self.seed = seed

    @staticmethod
    def _save_preview_figure(testX, GeneratedY, GeneratedX, testY, path):
        """Save a 2x2 figure (real X, fake Y / fake X, real Y) of the first sample."""
        panels = [('real X', testX), ('fake Y', GeneratedY),
                  ('fake X', GeneratedX), ('real Y', testY)]
        plt.figure(figsize=(5, 5))  # figure must be large enough to show the pixels
        for position, (title, images) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.title(title)
            plt.imshow(images[0, :, :, :])
            plt.axis('off')
        plt.savefig(path)
        plt.close()

    def wirte_summary(self, step, seed, G, F, G_loss, Dy_loss, F_loss, Dx_loss, out_path):
        """Save a preview of ``seed`` and log the four losses and the preview.

        Writes ``<out_path>/image_at_<step>.png``, then writes scalar and image
        summaries through ``self.train_summary_writer``.
        """
        testX, testY = seed
        GeneratedY = G(testX)
        GeneratedX = F(testY)
        image_path = out_path + '/image_at_{}.png'.format(step)
        self._save_preview_figure(testX, GeneratedY, GeneratedX, testY, image_path)
        # Read the PNG back, closing the file, and add a batch axis. The size
        # comes from the file rather than assuming 500x500 RGBA.
        with Image.open(image_path) as png:
            img = tf.expand_dims(np.array(png), axis=0)

        with self.train_summary_writer.as_default():
            tf.summary.scalar('G_loss', G_loss, step=step)
            tf.summary.scalar('Dy_loss', Dy_loss, step=step)
            tf.summary.scalar('F_loss', F_loss, step=step)
            tf.summary.scalar('Dx_loss', Dx_loss, step=step)
            tf.summary.image("img", img, step=step)