# Assumed standard imports; the project additionally supplies `networks`
# (Generator / Discriminator), `GanLoss`, a mixed-precision-aware `Adam`
# wrapper, `vgg16`, and the `global_*_shape` input-signature constants.
import csv
import datetime
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image


class CycleGAN(tf.keras.Model):
    """
    The model is only responsible for what happens once the training and
    test (validation) sets have been given.
    """
    def __init__(self,
                 train_set,
                 test_set,
                 loss_name="WGAN-GP",
                 mixed_precision=False,
                 learning_rate=2e-4,
                 tmp_path=None,
                 out_path=None):
        super(CycleGAN, self).__init__()
        # Receive the datasets and related parameters
        self.train_set = train_set
        self.test_set = test_set
        self.tmp_path = tmp_path
        self.out_path = out_path
        # Define the sub-models
        self.G = networks.Generator(name="G_X2Y")
        self.F = networks.Generator(name="G_Y2X")
        if loss_name in ["WGAN-SN", "WGAN-GP-SN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=False,
                                             sn=True)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=False,
                                             sn=True)
            self.loss_name = loss_name[:-3]
        elif loss_name in ["WGAN", "WGAN-GP"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=False,
                                             sn=False)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=False,
                                             sn=False)
            self.loss_name = loss_name
        elif loss_name in ["Vanilla", "LSGAN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=True,
                                             sn=False)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=True,
                                             sn=False)
            self.loss_name = loss_name
        else:
            raise ValueError("Unsupported loss: " + loss_name)
        self.model_list = [self.G, self.F, self.Dy, self.Dx]
        # Define the loss function, optimizers, logging, etc.
        self.gan_loss = GanLoss(self.loss_name)
        self.optimizers_list = self.optimizers_config(
            mixed_precision=mixed_precision, learning_rate=learning_rate)
        self.mixed_precision = mixed_precision
        self.metrics_list = self.metrics_config()
        self.checkpoint_config()
        self.get_seed()

    def build(self, X_shape, Y_shape):
        """
        input_shape must be passed as a slice (copy), because lower layers
        treat it as their output shape and mutate it.
        """
        self.G.build(input_shape=X_shape[:])   # G: X -> Y
        self.Dy.build(input_shape=Y_shape[:])  # Dy: Y or != Y
        self.F.build(input_shape=Y_shape[:])   # F: Y -> X
        self.Dx.build(input_shape=X_shape[:])  # Dx: X or != X
        self.built = True

    def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
        # TTUR: the discriminators run at 4x the generator learning rate.
        # NOTE: the `learning_rate` argument is currently ignored.
        self.G_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        self.Dy_optimizer = Adam(learning_rate=4e-4, beta_1=0.0, beta_2=0.9)
        self.F_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        self.Dx_optimizer = Adam(learning_rate=4e-4, beta_1=0.0, beta_2=0.9)
        if mixed_precision:
            self.G_optimizer = self.G_optimizer.get_mixed_precision()
            self.Dy_optimizer = self.Dy_optimizer.get_mixed_precision()
            self.F_optimizer = self.F_optimizer.get_mixed_precision()
            self.Dx_optimizer = self.Dx_optimizer.get_mixed_precision()
        return [
            self.G_optimizer, self.Dy_optimizer, self.F_optimizer,
            self.Dx_optimizer
        ]

    def metrics_config(self):
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_logdir = self.tmp_path + "/logs/" + current_time
        self.train_summary_writer = tf.summary.create_file_writer(train_logdir)
        self.m_psnr_X2Y = tf.keras.metrics.Mean('psnr_y', dtype=tf.float32)
        self.m_psnr_Y2X = tf.keras.metrics.Mean('psnr_x', dtype=tf.float32)
        self.m_ssim_X2Y = tf.keras.metrics.Mean('ssim_y', dtype=tf.float32)
        self.m_ssim_Y2X = tf.keras.metrics.Mean('ssim_x', dtype=tf.float32)
        return [
            self.m_psnr_X2Y, self.m_psnr_Y2X, self.m_ssim_X2Y, self.m_ssim_Y2X
        ]

    def checkpoint_config(self):
        self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                        optimizer=self.optimizers_list,
                                        model=self.model_list,
                                        dataset=self.train_set)
        self.manager = tf.train.CheckpointManager(self.ckpt,
                                                  self.tmp_path + '/tf_ckpts',
                                                  max_to_keep=3)

    def pix_gradient(self, x):
        # Pixel gradients over batch and channels; for 2D single-channel
        # data the reshape is not strictly necessary.
        x = tf.reshape(x, shape=[1, 64, 64, 1])
        dx, dy = tf.image.image_gradients(x)
        return dx, dy

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_D(self, trainX, trainY, maskX, maskY, wgp_shape, step):
        with tf.GradientTape(persistent=True) as D_tape:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            e = tf.random.uniform(shape=wgp_shape, minval=0.0, maxval=1.0)
            mid_Y = e * trainY + (1 - e) * GeneratedY
            with tf.GradientTape() as gradient_penaltyY:
                gradient_penaltyY.watch(mid_Y)
                inner_loss = self.Dy(mid_Y)
            penalty = gradient_penaltyY.gradient(inner_loss, mid_Y)
            penalty_normY = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, shape=[wgp_shape[0], -1]),
                        ord=2,
                        axis=-1) - 1)
            e = tf.random.uniform(shape=wgp_shape, minval=0.0, maxval=1.0)
            mid_X = e * trainX + (1 - e) * GeneratedX
            with tf.GradientTape() as gradient_penaltyX:
                gradient_penaltyX.watch(mid_X)
                inner_loss = self.Dx(mid_X)
            penalty = gradient_penaltyX.gradient(inner_loss, mid_X)
            penalty_normX = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, shape=[wgp_shape[0], -1]),
                        ord=2,
                        axis=-1) - 1)
            Dy_loss = self.gan_loss.DiscriminatorLoss(
                Dy_real_out, Dy_fake_out) + tf.reduce_mean(penalty_normY)
            Dx_loss = self.gan_loss.DiscriminatorLoss(
                Dx_real_out, Dx_fake_out) + tf.reduce_mean(penalty_normX)
            if self.mixed_precision:
                scaled_Dy_loss = self.Dy_optimizer.get_scaled_loss(Dy_loss)
                scaled_Dx_loss = self.Dx_optimizer.get_scaled_loss(Dx_loss)
        if self.mixed_precision:
            scaled_gradients_of_Dy = D_tape.gradient(
                scaled_Dy_loss, self.Dy.trainable_variables)
            scaled_gradients_of_Dx = D_tape.gradient(
                scaled_Dx_loss, self.Dx.trainable_variables)
            gradients_of_Dy = self.Dy_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dy)
            gradients_of_Dx = self.Dx_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dx)
        else:
            gradients_of_Dy = D_tape.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
            gradients_of_Dx = D_tape.gradient(Dx_loss,
                                              self.Dx.trainable_variables)
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        return Dy_loss, Dx_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_G(self, trainX, trainY, maskX, maskY, wgp_shape, step):
        with tf.GradientTape(persistent=True) as G_tape:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            # Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            # Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X
            # Pixel-gradient and reconstruction losses are not used yet,
            # so both branches currently apply the same weight.
            if step >= 0:
                cycle_l = 10.0
            else:
                cycle_l = 10.0
            G_loss = self.gan_loss.GeneratorLoss(
                Dy_fake_out) + cycle_l * (cycle_consistent)
            F_loss = self.gan_loss.GeneratorLoss(
                Dx_fake_out) + cycle_l * (cycle_consistent)
            if self.mixed_precision:
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
                scaled_F_loss = self.F_optimizer.get_scaled_loss(F_loss)
        if self.mixed_precision:
            scaled_gradients_of_G = G_tape.gradient(
                scaled_G_loss, self.G.trainable_variables)
            scaled_gradients_of_F = G_tape.gradient(
                scaled_F_loss, self.F.trainable_variables)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
            gradients_of_F = self.F_optimizer.get_unscaled_gradients(
                scaled_gradients_of_F)
        else:
            gradients_of_G = G_tape.gradient(G_loss,
                                             self.G.trainable_variables)
            gradients_of_F = G_tape.gradient(F_loss,
                                             self.F.trainable_variables)
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, F_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step(self, trainX, trainY, maskX, maskY, wgp_shape, step):
        with tf.GradientTape(persistent=True) as cycle_type:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X
            # Pixel-gradient and reconstruction losses are not used yet
            if step >= 0:
                cycle_l = 10.0
            else:
                cycle_l = 10.0
            Dy_loss = self.gan_loss.DiscriminatorLoss(Dy_real_out, Dy_fake_out)
            Dx_loss = self.gan_loss.DiscriminatorLoss(Dx_real_out, Dx_fake_out)
            G_loss = self.gan_loss.GeneratorLoss(
                Dy_fake_out) + cycle_l * (cycle_consistent)
            F_loss = self.gan_loss.GeneratorLoss(
                Dx_fake_out) + cycle_l * (cycle_consistent)
        gradients_of_Dy = cycle_type.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
        gradients_of_Dx = cycle_type.gradient(Dx_loss,
                                              self.Dx.trainable_variables)
        gradients_of_G = cycle_type.gradient(G_loss,
                                             self.G.trainable_variables)
        gradients_of_F = cycle_type.gradient(F_loss,
                                             self.F.trainable_variables)
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, Dy_loss, F_loss, Dx_loss

    def train(self, epochs):
        self.ckpt.restore(self.manager.latest_checkpoint)
        for _ in range(epochs):
            start = time.time()
            for trainX, trainY, maskX, maskY in self.train_set:
                self.ckpt.step.assign_add(1)
                step = int(self.ckpt.step)
                if self.loss_name in ["WGAN", "WGAN-GP"]:
                    for __ in range(1):
                        Dy_loss, Dx_loss = self.train_step_D(
                            trainX, trainY, maskX, maskY,
                            tf.constant([trainY.shape[0], 1, 1, 1],
                                        shape=[4],
                                        dtype=tf.int32),
                            tf.constant(step, shape=[1], dtype=tf.uint32))
                    for __ in range(1):
                        G_loss, F_loss = self.train_step_G(
                            trainX, trainY, maskX, maskY,
                            tf.constant([trainY.shape[0], 1, 1, 1],
                                        shape=[4],
                                        dtype=tf.int32),
                            tf.constant(step, shape=[1], dtype=tf.uint32))
                elif self.loss_name in ["Vanilla", "LSGAN"]:
                    G_loss, Dy_loss, F_loss, Dx_loss = self.train_step(
                        trainX, trainY, maskX, maskY,
                        tf.constant([trainY.shape[0], 1, 1, 1],
                                    shape=[4],
                                    dtype=tf.int32),
                        tf.constant(step, shape=[1], dtype=tf.uint32))
                else:
                    raise ValueError("Internal error")
                if step % 100 == 0:
                    save_path = self.manager.save()
                    print("Saved checkpoint for step {}: {}".format(
                        step, save_path))
                    self.G.save_weights(self.tmp_path +
                                        '/weights_saved/G.ckpt')
                    self.Dy.save_weights(self.tmp_path +
                                         '/weights_saved/Dy.ckpt')
                    self.F.save_weights(self.tmp_path +
                                        '/weights_saved/F.ckpt')
                    self.Dx.save_weights(self.tmp_path +
                                         '/weights_saved/Dx.ckpt')
                    self.write_summary(step=step,
                                       seed=self.seed,
                                       G=self.G,
                                       F=self.F,
                                       G_loss=G_loss,
                                       Dy_loss=Dy_loss,
                                       F_loss=F_loss,
                                       Dx_loss=Dx_loss,
                                       out_path=self.out_path)
                    print('Time for 100 steps up to step {}: {} sec'.format(
                        step, time.time() - start))
                    start = time.time()

    def get_seed(self):
        seed_get = iter(self.test_set)
        testX, testY, maskX, maskY = next(seed_get)
        print(testX.shape, testY.dtype, maskX.dtype, maskY.shape)
        plt.imshow(testX[0, :, :, 0], cmap='gray')
        plt.show()
        plt.imshow(testY[0, :, :, 0], cmap='gray')
        plt.show()
        plt.imshow(maskX[0, :, :, 0], cmap='gray')
        plt.show()
        plt.imshow(maskY[0, :, :, 0], cmap='gray')
        plt.show()
        self.seed = testX, testY, maskX, maskY

    def write_summary(self, step, seed, G, F, G_loss, Dy_loss, F_loss, Dx_loss,
                      out_path):
        testX, testY, maskX, maskY = seed
        GeneratedY = G(testX)
        GeneratedY = tf.multiply(GeneratedY, maskX)
        GeneratedX = F(testY)
        # The masks are swapped at test time, since only the source modality
        # and its mask are known.
        GeneratedX = tf.multiply(GeneratedX, maskY)
        plt.figure(figsize=(5, 5))  # a larger figure is needed to hold the pixels
        plt.subplot(2, 2, 1)
        plt.title('real X')
        plt.imshow(testX[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 2)
        plt.title('fake Y')
        plt.imshow(GeneratedY[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 3)
        plt.title('fake X')
        plt.imshow(GeneratedX[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 4)
        plt.title('real Y')
        plt.imshow(testY[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.savefig(out_path + '/image_at_{}.png'.format(step))
        plt.close()
        img = Image.open(out_path + '/image_at_{}.png'.format(step))
        img = tf.reshape(np.array(img), shape=(1, 500, 500, 4))
        with self.train_summary_writer.as_default():
            ##########################
            self.m_psnr_X2Y(tf.image.psnr(GeneratedY, testY, 1.0, name=None))
            self.m_psnr_Y2X(tf.image.psnr(GeneratedX, testX, 1.0, name=None))
            self.m_ssim_X2Y(
                tf.image.ssim(GeneratedY,
                              testY,
                              1,
                              filter_size=11,
                              filter_sigma=1.5,
                              k1=0.01,
                              k2=0.03))
            self.m_ssim_Y2X(
                tf.image.ssim(GeneratedX,
                              testX,
                              1,
                              filter_size=11,
                              filter_sigma=1.5,
                              k1=0.01,
                              k2=0.03))
            tf.summary.scalar('G_loss', G_loss, step=step)
            tf.summary.scalar('Dy_loss', Dy_loss, step=step)
            tf.summary.scalar('F_loss', F_loss, step=step)
            tf.summary.scalar('Dx_loss', Dx_loss, step=step)
            tf.summary.scalar('test_psnr_y', self.m_psnr_X2Y.result(), step=step)
            tf.summary.scalar('test_psnr_x', self.m_psnr_Y2X.result(), step=step)
            tf.summary.scalar('test_ssim_y', self.m_ssim_X2Y.result(), step=step)
            tf.summary.scalar('test_ssim_x', self.m_ssim_Y2X.result(), step=step)
            tf.summary.image("img", img, step=step)
            ##########################
            self.m_psnr_X2Y.reset_states()
            self.m_psnr_Y2X.reset_states()
            self.m_ssim_X2Y.reset_states()
            self.m_ssim_Y2X.reset_states()

    def test(self):
        self.ckpt.restore(self.manager.latest_checkpoint)
        step = 0
        black_board_X = np.zeros(shape=[240, 240], dtype=np.float32)
        black_board_Y = np.zeros(shape=[240, 240], dtype=np.float32)
        black_board_rX = np.zeros(shape=[240, 240], dtype=np.float32)
        black_board_rY = np.zeros(shape=[240, 240], dtype=np.float32)
        for i, (testX, testY, maskX, maskY) in enumerate(self.test_set):
            GeneratedY = self.G(testX)
            GeneratedY = tf.multiply(GeneratedY, maskX)
            GeneratedX = self.F(testY)
            # The masks are swapped at test time, since only the source
            # modality and its mask are known.
            GeneratedX = tf.multiply(GeneratedX, maskY)
            # Paste each 128x128 patch into its quadrant of the 240x240 board
            if (i + 1) % 4 == 1:
                black_board_Y[48:175 + 1, 22:149 + 1] += GeneratedY.numpy()[0, :, :, 0]
                black_board_X[48:175 + 1, 22:149 + 1] += GeneratedX.numpy()[0, :, :, 0]
                black_board_rY[48:175 + 1, 22:149 + 1] += testY.numpy()[0, :, :, 0]
                black_board_rX[48:175 + 1, 22:149 + 1] += testX.numpy()[0, :, :, 0]
            elif (i + 1) % 4 == 2:
                black_board_Y[48:175 + 1, 90:217 + 1] += GeneratedY.numpy()[0, :, :, 0]
                black_board_X[48:175 + 1, 90:217 + 1] += GeneratedX.numpy()[0, :, :, 0]
                black_board_rY[48:175 + 1, 90:217 + 1] += testY.numpy()[0, :, :, 0]
                black_board_rX[48:175 + 1, 90:217 + 1] += testX.numpy()[0, :, :, 0]
            elif (i + 1) % 4 == 3:
                black_board_Y[64:191 + 1, 22:149 + 1] += GeneratedY.numpy()[0, :, :, 0]
                black_board_X[64:191 + 1, 22:149 + 1] += GeneratedX.numpy()[0, :, :, 0]
                black_board_rY[64:191 + 1, 22:149 + 1] += testY.numpy()[0, :, :, 0]
                black_board_rX[64:191 + 1, 22:149 + 1] += testX.numpy()[0, :, :, 0]
            elif (i + 1) % 4 == 0:
                black_board_Y[64:191 + 1, 90:217 + 1] += GeneratedY.numpy()[0, :, :, 0]
                black_board_X[64:191 + 1, 90:217 + 1] += GeneratedX.numpy()[0, :, :, 0]
                black_board_rY[64:191 + 1, 90:217 + 1] += testY.numpy()[0, :, :, 0]
                black_board_rX[64:191 + 1, 90:217 + 1] += testX.numpy()[0, :, :, 0]
                # Normalize the doubly-counted overlap strips
                black_board_Y[64:175 + 1, :] = black_board_Y[64:175 + 1, :] / 2.0
                black_board_Y[:, 90:149 + 1] = black_board_Y[:, 90:149 + 1] / 2.0
                black_board_X[64:175 + 1, :] = black_board_X[64:175 + 1, :] / 2.0
                black_board_X[:, 90:149 + 1] = black_board_X[:, 90:149 + 1] / 2.0
                black_board_rY[64:175 + 1, :] = black_board_rY[64:175 + 1, :] / 2.0
                black_board_rY[:, 90:149 + 1] = black_board_rY[:, 90:149 + 1] / 2.0
                black_board_rX[64:175 + 1, :] = black_board_rX[64:175 + 1, :] / 2.0
                black_board_rX[:, 90:149 + 1] = black_board_rX[:, 90:149 + 1] / 2.0
            else:
                raise ValueError("Internal error")
            out_path = self.out_path
            if (i + 1) % 4 == 0:
                step += 1
                if not os.path.exists(out_path + '/test'):
                    os.makedirs(out_path + '/test')  # make sure the dir exists
                plt.figure(figsize=(10, 10))  # larger figure to hold the pixels
                plt.subplot(2, 2, 1)
                plt.title('real X')
                plt.imshow(black_board_rX, cmap='gray')
                plt.axis('off')
                plt.subplot(2, 2, 2)
                plt.title('fake Y')
                plt.imshow(black_board_Y, cmap='gray')
                plt.axis('off')
                plt.subplot(2, 2, 3)
                plt.title('fake X')
                plt.imshow(black_board_X, cmap='gray')
                plt.axis('off')
                plt.subplot(2, 2, 4)
                plt.title('real Y')
                plt.imshow(black_board_rY, cmap='gray')
                plt.axis('off')
                plt.savefig(out_path + '/test/image_at_{}.png'.format(step))
                plt.close()
                img = Image.open(out_path + '/test/image_at_{}.png'.format(step))
                img = tf.reshape(np.array(img), shape=(1, 1000, 1000, 4))
                with self.train_summary_writer.as_default():
                    ##########################
                    black_board_Y = tf.reshape(tf.constant(black_board_Y,
                                                           dtype=tf.float32),
                                               shape=[1, 240, 240, 1])
                    black_board_X = tf.reshape(tf.constant(black_board_X,
                                                           dtype=tf.float32),
                                               shape=[1, 240, 240, 1])
                    black_board_rY = tf.reshape(tf.constant(black_board_rY,
                                                            dtype=tf.float32),
                                                shape=[1, 240, 240, 1])
                    black_board_rX = tf.reshape(tf.constant(black_board_rX,
                                                            dtype=tf.float32),
                                                shape=[1, 240, 240, 1])
                    self.m_psnr_X2Y(
                        tf.image.psnr(black_board_Y, black_board_rY, 1.0,
                                      name=None))
                    self.m_psnr_Y2X(
                        tf.image.psnr(black_board_X, black_board_rX, 1.0,
                                      name=None))
                    self.m_ssim_X2Y(
                        tf.image.ssim(black_board_Y,
                                      black_board_rY,
                                      1,
                                      filter_size=11,
                                      filter_sigma=1.5,
                                      k1=0.01,
                                      k2=0.03))
                    self.m_ssim_Y2X(
                        tf.image.ssim(black_board_X,
                                      black_board_rX,
                                      1,
                                      filter_size=11,
                                      filter_sigma=1.5,
                                      k1=0.01,
                                      k2=0.03))
                    tf.summary.scalar('test_psnr_y', self.m_psnr_X2Y.result(), step=step)
                    tf.summary.scalar('test_psnr_x', self.m_psnr_Y2X.result(), step=step)
                    tf.summary.scalar('test_ssim_y', self.m_ssim_X2Y.result(), step=step)
                    tf.summary.scalar('test_ssim_x', self.m_ssim_Y2X.result(), step=step)
                    tf.summary.image("img", img, step=step)
                    ##########################
                    self.m_psnr_X2Y.reset_states()
                    self.m_psnr_Y2X.reset_states()
                    self.m_ssim_X2Y.reset_states()
                    self.m_ssim_Y2X.reset_states()
                black_board_X = np.zeros(shape=[240, 240], dtype=np.float32)
                black_board_Y = np.zeros(shape=[240, 240], dtype=np.float32)
                black_board_rX = np.zeros(shape=[240, 240], dtype=np.float32)
                black_board_rY = np.zeros(shape=[240, 240], dtype=np.float32)
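For reference, the gradient-penalty term that `train_step_D` folds into each discriminator loss can be reproduced in isolation. Below is a minimal, self-contained sketch; the tiny dense critic and the toy shapes are assumptions for illustration only, standing in for `networks.Discriminator` and the real patch shapes:

import tensorflow as tf

# Toy critic standing in for networks.Discriminator (assumption).
critic = tf.keras.Sequential(
    [tf.keras.layers.Flatten(), tf.keras.layers.Dense(1)])

real = tf.random.uniform([4, 128, 128, 1])
fake = tf.random.uniform([4, 128, 128, 1])

# Random per-sample interpolation between real and fake, as in train_step_D
e = tf.random.uniform([4, 1, 1, 1], minval=0.0, maxval=1.0)
mid = e * real + (1 - e) * fake
with tf.GradientTape() as tape:
    tape.watch(mid)
    score = critic(mid)
grad = tape.gradient(score, mid)
# 10 * (||grad||_2 - 1)^2, with the norm taken per sample, then averaged
gp = 10.0 * tf.math.square(
    tf.norm(tf.reshape(grad, [4, -1]), ord=2, axis=-1) - 1)
print(float(tf.reduce_mean(gp)))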
class CycleGAN(tf.keras.Model):
    """
    The model is only responsible for what happens once the training and
    test (validation) sets have been given.
    """
    def __init__(self,
                 train_set,
                 test_set,
                 loss_name="WGAN-GP",
                 mixed_precision=False,
                 learning_rate=2e-4,
                 tmp_path=None,
                 out_path=None):
        super(CycleGAN, self).__init__()
        # Receive the datasets and related parameters
        self.train_set = train_set
        self.test_set = test_set
        self.tmp_path = tmp_path
        self.out_path = out_path
        # Define the sub-models
        self.G = networks.Generator(name="G_X2Y")
        self.F = networks.Generator(name="G_Y2X")
        if loss_name in ["WGAN-SN", "WGAN-GP-SN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=False,
                                             sn=True)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=False,
                                             sn=True)
            self.loss_name = loss_name[:-3]
        elif loss_name in ["WGAN", "WGAN-GP"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=False,
                                             sn=False)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=False,
                                             sn=False)
            self.loss_name = loss_name
        elif loss_name in ["Vanilla-SN", "LSGAN-SN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=True,
                                             sn=True)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=True,
                                             sn=True)
            self.loss_name = loss_name[:-3]
        elif loss_name in ["Vanilla", "LSGAN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=True,
                                             sn=False)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=True,
                                             sn=False)
            self.loss_name = loss_name
        else:
            raise ValueError("Unsupported loss: " + loss_name)
        self.vgg = vgg16()
        self.model_list = [self.G, self.F, self.Dy, self.Dx]
        # Define the loss function, optimizers, logging, etc.
        self.gan_loss = GanLoss(self.loss_name)
        self.optimizers_list = self.optimizers_config(
            mixed_precision=mixed_precision, learning_rate=learning_rate)
        self.mixed_precision = mixed_precision
        self.metrics_list = self.metrics_config()
        self.checkpoint_config()
        self.get_seed()

    def build(self, X_shape, Y_shape):
        """
        input_shape must be passed as a slice (copy), because lower layers
        treat it as their output shape and mutate it.
        """
        self.G.build(input_shape=X_shape[:])   # G: X -> Y
        self.Dy.build(input_shape=Y_shape[:])  # Dy: Y or != Y
        self.F.build(input_shape=Y_shape[:])   # F: Y -> X
        self.Dx.build(input_shape=X_shape[:])  # Dx: X or != X
        self.built = True

    def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
        # TTUR: the discriminators run at 4x the generator learning rate.
        # NOTE: the `learning_rate` argument is currently ignored.
        self.G_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        self.Dy_optimizer = Adam(learning_rate=4e-4, beta_1=0.0, beta_2=0.9)
        self.F_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        self.Dx_optimizer = Adam(learning_rate=4e-4, beta_1=0.0, beta_2=0.9)
        # self.G_optimizer = Adam(learning_rate=2e-4)
        # self.Dy_optimizer = Adam(learning_rate=2e-4)
        # self.F_optimizer = Adam(learning_rate=2e-4)
        # self.Dx_optimizer = Adam(learning_rate=2e-4)
        if mixed_precision:
            self.G_optimizer = self.G_optimizer.get_mixed_precision()
            self.Dy_optimizer = self.Dy_optimizer.get_mixed_precision()
            self.F_optimizer = self.F_optimizer.get_mixed_precision()
            self.Dx_optimizer = self.Dx_optimizer.get_mixed_precision()
        return [
            self.G_optimizer, self.Dy_optimizer, self.F_optimizer,
            self.Dx_optimizer
        ]

    def metrics_config(self):
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_logdir = self.tmp_path + "/logs/" + current_time
        self.train_summary_writer = tf.summary.create_file_writer(train_logdir)
        self.m_psnr_X2Y = tf.keras.metrics.Mean('psnr_y', dtype=tf.float32)
        self.m_psnr_Y2X = tf.keras.metrics.Mean('psnr_x', dtype=tf.float32)
        self.m_ssim_X2Y = tf.keras.metrics.Mean('ssim_y', dtype=tf.float32)
        self.m_ssim_Y2X = tf.keras.metrics.Mean('ssim_x', dtype=tf.float32)
        return [
            self.m_psnr_X2Y, self.m_psnr_Y2X, self.m_ssim_X2Y, self.m_ssim_Y2X
        ]

    def checkpoint_config(self):
        self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                        optimizer=self.optimizers_list,
                                        model=self.model_list,
                                        dataset=self.train_set)
        self.manager = tf.train.CheckpointManager(self.ckpt,
                                                  self.tmp_path + '/tf_ckpts',
                                                  max_to_keep=3)

    def pix_gradient(self, x):
        # Pixel gradients over batch and channels; for 2D single-channel
        # data the reshape is not strictly necessary.
        x = tf.reshape(x, shape=[1, 64, 64, 1])
        dx, dy = tf.image.image_gradients(x)
        return dx, dy

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32)
    ])
    def train_step_D(self, trainX, trainY, maskX, maskY):
        # NOTE: self.wgp_shape is a Python value captured when this
        # tf.function is traced, so train() must set it before the first call.
        with tf.GradientTape(persistent=True) as D_tape:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            e = tf.random.uniform(shape=self.wgp_shape, minval=0.0, maxval=1.0)
            mid_Y = e * trainY + (1 - e) * GeneratedY
            with tf.GradientTape() as gradient_penaltyY:
                gradient_penaltyY.watch(mid_Y)
                inner_loss = self.Dy(mid_Y)
            penalty = gradient_penaltyY.gradient(inner_loss, mid_Y)
            penalty_normY = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, shape=[self.wgp_shape[0], -1]),
                        ord=2,
                        axis=-1) - 1)
            e = tf.random.uniform(shape=self.wgp_shape, minval=0.0, maxval=1.0)
            mid_X = e * trainX + (1 - e) * GeneratedX
            with tf.GradientTape() as gradient_penaltyX:
                gradient_penaltyX.watch(mid_X)
                inner_loss = self.Dx(mid_X)
            penalty = gradient_penaltyX.gradient(inner_loss, mid_X)
            penalty_normX = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, shape=[self.wgp_shape[0], -1]),
                        ord=2,
                        axis=-1) - 1)
            Dy_loss = self.gan_loss.DiscriminatorLoss(
                Dy_real_out, Dy_fake_out) + tf.reduce_mean(penalty_normY)
            Dx_loss = self.gan_loss.DiscriminatorLoss(
                Dx_real_out, Dx_fake_out) + tf.reduce_mean(penalty_normX)
            if self.mixed_precision:
                scaled_Dy_loss = self.Dy_optimizer.get_scaled_loss(Dy_loss)
                scaled_Dx_loss = self.Dx_optimizer.get_scaled_loss(Dx_loss)
        if self.mixed_precision:
            scaled_gradients_of_Dy = D_tape.gradient(
                scaled_Dy_loss, self.Dy.trainable_variables)
            scaled_gradients_of_Dx = D_tape.gradient(
                scaled_Dx_loss, self.Dx.trainable_variables)
            gradients_of_Dy = self.Dy_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dy)
            gradients_of_Dx = self.Dx_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dx)
        else:
            gradients_of_Dy = D_tape.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
            gradients_of_Dx = D_tape.gradient(Dx_loss,
                                              self.Dx.trainable_variables)
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        return Dy_loss, Dx_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32)
    ])
    def train_step_G(self, trainX, trainY, maskX, maskY):
        with tf.GradientTape(persistent=True) as G_tape:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            # Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            # Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X
            fake_Y_perceptual = self.vgg(GeneratedY)
            real_Y_perceptual = self.vgg(trainY)
            fake_X_perceptual = self.vgg(GeneratedX)
            real_X_perceptual = self.vgg(trainX)
            reconstruction_loss_X2Y = 0
            reconstruction_loss_Y2X = 0
            # Average the L1 distances of the seven perceptual feature maps
            # (0.14 is approximately 1/7).
            for i in range(7):
                reconstruction_loss_X2Y += 0.14 * tf.reduce_mean(
                    tf.abs(fake_Y_perceptual[i] - real_Y_perceptual[i]))
                reconstruction_loss_Y2X += 0.14 * tf.reduce_mean(
                    tf.abs(fake_X_perceptual[i] - real_X_perceptual[i]))
            # reconstruction_loss_X2Y = tf.reduce_mean(tf.abs(GeneratedY-trainY))
            # reconstruction_loss_Y2X = tf.reduce_mean(tf.abs(GeneratedX-trainX))
            G_loss = self.gan_loss.GeneratorLoss(
                Dy_fake_out
            ) + 10.0 * cycle_consistent + reconstruction_loss_X2Y
            F_loss = self.gan_loss.GeneratorLoss(
                Dx_fake_out
            ) + 10.0 * cycle_consistent + reconstruction_loss_Y2X
            if self.mixed_precision:
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
                scaled_F_loss = self.F_optimizer.get_scaled_loss(F_loss)
        if self.mixed_precision:
            scaled_gradients_of_G = G_tape.gradient(
                scaled_G_loss, self.G.trainable_variables)
            scaled_gradients_of_F = G_tape.gradient(
                scaled_F_loss, self.F.trainable_variables)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
            gradients_of_F = self.F_optimizer.get_unscaled_gradients(
                scaled_gradients_of_F)
        else:
            gradients_of_G = G_tape.gradient(G_loss,
                                             self.G.trainable_variables)
            gradients_of_F = G_tape.gradient(F_loss,
                                             self.F.trainable_variables)
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, F_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_mask_Y_shape, dtype=tf.float32)
    ])
    def train_step(self, trainX, trainY, maskX, maskY):
        with tf.GradientTape(persistent=True) as cycle_type:
            GeneratedY = self.G(trainX)
            GeneratedY = tf.multiply(GeneratedY, maskY)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            GeneratedX = tf.multiply(GeneratedX, maskX)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X
            fake_Y_perceptual = self.vgg(GeneratedY)
            real_Y_perceptual = self.vgg(trainY)
            fake_X_perceptual = self.vgg(GeneratedX)
            real_X_perceptual = self.vgg(trainX)
            reconstruction_loss_X2Y = 0
            reconstruction_loss_Y2X = 0
            for i in range(7):
                reconstruction_loss_X2Y += 0.14 * tf.reduce_mean(
                    tf.abs(fake_Y_perceptual[i] - real_Y_perceptual[i]))
                reconstruction_loss_Y2X += 0.14 * tf.reduce_mean(
                    tf.abs(fake_X_perceptual[i] - real_X_perceptual[i]))
            # reconstruction_loss_X2Y = tf.reduce_mean(tf.abs(GeneratedY-trainY))
            # reconstruction_loss_Y2X = tf.reduce_mean(tf.abs(GeneratedX-trainX))
            Dy_loss = self.gan_loss.DiscriminatorLoss(Dy_real_out, Dy_fake_out)
            Dx_loss = self.gan_loss.DiscriminatorLoss(Dx_real_out, Dx_fake_out)
            G_loss = self.gan_loss.GeneratorLoss(Dy_fake_out) + 10.0 * (
                cycle_consistent) + reconstruction_loss_X2Y
            F_loss = self.gan_loss.GeneratorLoss(Dx_fake_out) + 10.0 * (
                cycle_consistent) + reconstruction_loss_Y2X
        gradients_of_Dy = cycle_type.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
        gradients_of_Dx = cycle_type.gradient(Dx_loss,
                                              self.Dx.trainable_variables)
        gradients_of_G = cycle_type.gradient(G_loss,
                                             self.G.trainable_variables)
        gradients_of_F = cycle_type.gradient(F_loss,
                                             self.F.trainable_variables)
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, Dy_loss, F_loss, Dx_loss

    def train(self, epochs):
        self.ckpt.restore(self.manager.latest_checkpoint)
        my_step = int(self.ckpt.step)
        stop_flag = 0
        for _ in range(epochs):
            start = time.time()
            for trainX, trainY, maskX, maskY in self.train_set:
                my_step += 1
                self.ckpt.step.assign_add(1)
                step = int(self.ckpt.step)
                ### Required hyper-parameters and schedule-dependent values
                ### are defined here
                self.l = 10.0 * (1.0 / (step * 0.1 + 1.0))  # decayed weight (currently unused)
                self.wgp_shape = [trainY.shape[0], 1, 1, 1, 1]  # 3D: one extra dimension
                ###
                if self.loss_name in ["WGAN", "WGAN-GP"]:
                    for __ in range(1):
                        Dy_loss, Dx_loss = self.train_step_D(
                            trainX, trainY, maskX, maskY)
                    for __ in range(1):
                        G_loss, F_loss = self.train_step_G(
                            trainX, trainY, maskX, maskY)
                elif self.loss_name in ["Vanilla", "LSGAN"]:
                    G_loss, Dy_loss, F_loss, Dx_loss = self.train_step(
                        trainX, trainY, maskX, maskY)
                else:
                    raise ValueError("Internal error")
                if step % 100 == 0:
                    save_path = self.manager.save()
                    print("Saved checkpoint for step {}: {}".format(
                        step, save_path))
                    self.G.save_weights(self.tmp_path +
                                        '/weights_saved/G.ckpt')
                    self.Dy.save_weights(self.tmp_path +
                                         '/weights_saved/Dy.ckpt')
                    self.F.save_weights(self.tmp_path +
                                        '/weights_saved/F.ckpt')
                    self.Dx.save_weights(self.tmp_path +
                                         '/weights_saved/Dx.ckpt')
                    self.write_summary(step=step,
                                       seed=self.seed,
                                       G=self.G,
                                       F=self.F,
                                       G_loss=G_loss,
                                       Dy_loss=Dy_loss,
                                       F_loss=F_loss,
                                       Dx_loss=Dx_loss,
                                       out_path=self.out_path)
                    print('Time for 100 steps up to step {}: {} sec'.format(
                        step, time.time() - start))
                    start = time.time()
                if step == 80200:
                    stop_flag = 1
                    break
            if stop_flag == 1:
                break

    def get_seed(self):
        seed_get = iter(self.test_set)
        testX, testY, maskX, maskY = next(seed_get)
        print(testX.shape, testY.dtype, maskX.dtype, maskY.shape)
        # Show the middle slice of each 3D patch
        plt.imshow(testX[0, :, :, 1, 0], cmap='gray')
        plt.show()
        plt.imshow(testY[0, :, :, 1, 0], cmap='gray')
        plt.show()
        plt.imshow(maskX[0, :, :, 1, 0], cmap='gray')
        plt.show()
        plt.imshow(maskY[0, :, :, 1, 0], cmap='gray')
        plt.show()
        self.seed = testX, testY, maskX, maskY

    def write_summary(self, step, seed, G, F, G_loss, Dy_loss, F_loss, Dx_loss,
                      out_path):
        testX, testY, maskX, maskY = seed
        GeneratedY = G(testX)
        GeneratedY = tf.multiply(GeneratedY, maskX)
        GeneratedX = F(testY)
        GeneratedX = tf.multiply(GeneratedX, maskY)
        # Keep the middle slice of each 3D patch for 2D visualization
        testX = testX[:, :, :, 1, :]
        testY = testY[:, :, :, 1, :]
        maskX = maskX[:, :, :, 1, :]
        maskY = maskY[:, :, :, 1, :]
        GeneratedY = GeneratedY[:, :, :, 1, :]
        GeneratedX = GeneratedX[:, :, :, 1, :]
        plt.figure(figsize=(5, 5))  # a larger figure is needed to hold the pixels
        plt.subplot(2, 2, 1)
        plt.title('real X')
        plt.imshow(testX[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 2)
        plt.title('fake Y')
        plt.imshow(GeneratedY[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 3)
        plt.title('fake X')
        plt.imshow(GeneratedX[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.subplot(2, 2, 4)
        plt.title('real Y')
        plt.imshow(testY[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.savefig(out_path + '/image_at_{}.png'.format(step))
        plt.close()
        img = Image.open(out_path + '/image_at_{}.png'.format(step))
        img = tf.reshape(np.array(img), shape=(1, 500, 500, 4))
        with self.train_summary_writer.as_default():
            ##########################
            self.m_psnr_X2Y(tf.image.psnr(GeneratedY, testY, 1.0, name=None))
            self.m_psnr_Y2X(tf.image.psnr(GeneratedX, testX, 1.0, name=None))
            self.m_ssim_X2Y(
                tf.image.ssim(GeneratedY,
                              testY,
                              1,
                              filter_size=11,
                              filter_sigma=1.5,
                              k1=0.01,
                              k2=0.03))
            self.m_ssim_Y2X(
                tf.image.ssim(GeneratedX,
                              testX,
                              1,
                              filter_size=11,
                              filter_sigma=1.5,
                              k1=0.01,
                              k2=0.03))
            tf.summary.scalar('G_loss', G_loss, step=step)
            tf.summary.scalar('Dy_loss', Dy_loss, step=step)
            tf.summary.scalar('F_loss', F_loss, step=step)
            tf.summary.scalar('Dx_loss', Dx_loss, step=step)
            tf.summary.scalar('test_psnr_y', self.m_psnr_X2Y.result(), step=step)
            tf.summary.scalar('test_psnr_x', self.m_psnr_Y2X.result(), step=step)
            tf.summary.scalar('test_ssim_y', self.m_ssim_X2Y.result(), step=step)
            tf.summary.scalar('test_ssim_x', self.m_ssim_Y2X.result(), step=step)
            tf.summary.image("img", img, step=step)
            ##########################
            self.m_psnr_X2Y.reset_states()
            self.m_psnr_Y2X.reset_states()
            self.m_ssim_X2Y.reset_states()
            self.m_ssim_Y2X.reset_states()

    def test(self):
        buf = self.manager.latest_checkpoint
        buf = buf[:-3]  # strip the step suffix so specific checkpoints can be addressed
        for index, temp_point in enumerate(["800", "801", "802"]):
            self.ckpt.restore(buf + temp_point)
            step = 0
            out_path = self.out_path + "/test"
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            result_buf = []
            for i, (testX, testY, maskX, maskY) in enumerate(self.test_set):
                GeneratedY = self.G(testX)
                GeneratedY = tf.multiply(GeneratedY, maskX)
                GeneratedX = self.F(testY)
                # The masks are swapped at test time, since only the source
                # modality and its mask are known.
                GeneratedX = tf.multiply(GeneratedX, maskY)
                # Keep the middle slice of each 3D patch for 2D evaluation
                testX = testX[:, :, :, 1, :]
                testY = testY[:, :, :, 1, :]
                maskX = maskX[:, :, :, 1, :]
                maskY = maskY[:, :, :, 1, :]
                GeneratedY = GeneratedY[:, :, :, 1, :]
                GeneratedX = GeneratedX[:, :, :, 1, :]
                ABS_Y = tf.abs(GeneratedY - testY)
                ABS_X = tf.abs(GeneratedX - testX)
                black_board_rX = testX[0, :, :, 0]
                black_board_Y = GeneratedY[0, :, :, 0]
                black_board_absY = ABS_Y[0, :, :, 0]
                black_board_X = GeneratedX[0, :, :, 0]
                black_board_rY = testY[0, :, :, 0]
                black_board_absX = ABS_X[0, :, :, 0]
                step = i + 1
                plt.figure(figsize=(10, 10))  # larger figure to hold the pixels
                plt.subplot(3, 2, 1)
                plt.title('real X')
                plt.imshow(black_board_rX, cmap='gray')
                plt.axis('off')
                plt.subplot(3, 2, 2)
                plt.title('fake Y')
                plt.imshow(black_board_Y, cmap='gray')
                plt.axis('off')
                plt.subplot(3, 2, 3)
                plt.title('ABS Y')
                plt.imshow(black_board_absY, cmap='hot')
                plt.axis('off')
                plt.subplot(3, 2, 4)
                plt.title('ABS X')
                plt.imshow(black_board_absX, cmap='hot')
                plt.axis('off')
                plt.subplot(3, 2, 5)
                plt.title('fake X')
                plt.imshow(black_board_X, cmap='gray')
                plt.axis('off')
                plt.subplot(3, 2, 6)
                plt.title('real Y')
                plt.imshow(black_board_rY, cmap='gray')
                plt.axis('off')
                plt.savefig(out_path + '/image_at_{}.png'.format(step))
                plt.close()
                img = Image.open(out_path + '/image_at_{}.png'.format(step))
                img = tf.reshape(np.array(img), shape=(1, 1000, 1000, 4))
                if (i + 1) == 1:
                    np.save(out_path + "/out_Y.npy", black_board_Y)
                    np.save(out_path + "/out_X.npy", black_board_X)
                with self.train_summary_writer.as_default():
                    ##########################
                    black_board_Y = tf.reshape(tf.constant(black_board_Y,
                                                           dtype=tf.float32),
                                               shape=[1, 128, 128, 1])
                    black_board_X = tf.reshape(tf.constant(black_board_X,
                                                           dtype=tf.float32),
                                               shape=[1, 128, 128, 1])
                    black_board_rY = tf.reshape(tf.constant(black_board_rY,
                                                            dtype=tf.float32),
                                                shape=[1, 128, 128, 1])
                    black_board_rX = tf.reshape(tf.constant(black_board_rX,
                                                            dtype=tf.float32),
                                                shape=[1, 128, 128, 1])
                    self.m_psnr_X2Y(
                        tf.image.psnr(black_board_Y, black_board_rY, 1.0,
                                      name=None))
                    self.m_psnr_Y2X(
                        tf.image.psnr(black_board_X, black_board_rX, 1.0,
                                      name=None))
                    self.m_ssim_X2Y(
                        tf.image.ssim(black_board_Y,
                                      black_board_rY,
                                      1,
                                      filter_size=11,
                                      filter_sigma=1.5,
                                      k1=0.01,
                                      k2=0.03))
                    self.m_ssim_Y2X(
                        tf.image.ssim(black_board_X,
                                      black_board_rX,
                                      1,
                                      filter_size=11,
                                      filter_sigma=1.5,
                                      k1=0.01,
                                      k2=0.03))
                    tf.summary.scalar('test_psnr_y', self.m_psnr_X2Y.result(), step=step)
                    tf.summary.scalar('test_psnr_x', self.m_psnr_Y2X.result(), step=step)
                    tf.summary.scalar('test_ssim_y', self.m_ssim_X2Y.result(), step=step)
                    tf.summary.scalar('test_ssim_x', self.m_ssim_Y2X.result(), step=step)
                    tf.summary.image("img", img, step=step)
                    # Image-gradient (IG) difference metric
                    dx1, dy1 = tf.image.image_gradients(black_board_Y)
                    dx2, dy2 = tf.image.image_gradients(black_board_rY)
                    dx_mean = tf.reduce_mean(tf.math.abs(dx1 - dx2))
                    dy_mean = tf.reduce_mean(tf.math.abs(dy1 - dy2))
                    IG_y = dy_mean + dx_mean
                    tf.summary.scalar('IG_y', IG_y, step=step)
                    dx1, dy1 = tf.image.image_gradients(black_board_X)
                    dx2, dy2 = tf.image.image_gradients(black_board_rX)
                    dx_mean = tf.reduce_mean(tf.math.abs(dx1 - dx2))
                    dy_mean = tf.reduce_mean(tf.math.abs(dy1 - dy2))
                    IG_x = dy_mean + dx_mean
                    tf.summary.scalar('IG_x', IG_x, step=step)
                    result_buf.append([
                        i + 1,
                        self.m_psnr_X2Y.result().numpy(),
                        self.m_psnr_Y2X.result().numpy(),
                        self.m_ssim_X2Y.result().numpy(),
                        self.m_ssim_Y2X.result().numpy(),
                        IG_y.numpy(),
                        IG_x.numpy()
                    ])
                    ##########################
                    self.m_psnr_X2Y.reset_states()
                    self.m_psnr_Y2X.reset_states()
                    self.m_ssim_X2Y.reset_states()
                    self.m_ssim_Y2X.reset_states()
            headers = [
                'Instance', 'test_psnr_y', 'test_psnr_x', 'test_ssim_y',
                'test_ssim_x', 'IG_y', 'IG_x'
            ]
            rows = result_buf
            with open(out_path + '/result' + str(index) + '.csv', 'w') as f:
                f_csv = csv.writer(f)
                f_csv.writerow(headers)
                f_csv.writerows(rows)
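This variant adds a perceptual reconstruction term through `self.vgg`, which is assumed to be a callable returning a list of at least seven feature maps (the `for i in range(7)` loop averages their L1 distances, weighting each by 0.14, roughly 1/7). The project's `vgg16` is its own module, since the inputs here are single-channel 3D patches; purely as a sketch of the calling convention, a comparable extractor for ordinary 2D RGB images could look like this:

import tensorflow as tf

# Sketch only: not the repository's vgg16. weights=None keeps the sketch
# offline; use weights="imagenet" for actual perceptual features.
base = tf.keras.applications.VGG16(include_top=False, weights=None)
layer_names = [
    "block1_conv2", "block2_conv2", "block3_conv1", "block3_conv3",
    "block4_conv1", "block4_conv3", "block5_conv3",
]
vgg = tf.keras.Model(
    inputs=base.input,
    outputs=[base.get_layer(name).output for name in layer_names])

x = tf.random.uniform([1, 224, 224, 3])
feats = vgg(x)  # a list of 7 feature maps, as consumed by the range(7) loop
print([tuple(f.shape) for f in feats])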
class DCGAN(tf.keras.Model):
    """
    The model is only responsible for what happens once the training and
    test (validation) sets have been given.
    """
    def __init__(self,
                 train_set,
                 test_set,
                 loss_name="Vanilla",
                 mixed_precision=False,
                 learning_rate=2e-4,
                 tmp_path=None,
                 out_path=None):
        super(DCGAN, self).__init__()
        # Receive the datasets and related parameters
        self.train_set = train_set
        self.test_set = test_set
        self.tmp_path = tmp_path
        self.out_path = out_path
        # Define the sub-models
        self.G = networks.Generator(name="G")
        if loss_name in ["WGAN-SN", "WGAN-GP-SN"]:
            self.D = networks.Discriminator(name="If_is_real",
                                            use_sigmoid=False,
                                            sn=True)
            self.loss_name = loss_name[:-3]
        elif loss_name in ["WGAN", "WGAN-GP"]:
            self.D = networks.Discriminator(name="If_is_real",
                                            use_sigmoid=False,
                                            sn=False)
            self.loss_name = loss_name
        elif loss_name in ["Vanilla", "LSGAN"]:
            self.D = networks.Discriminator(name="If_is_real",
                                            use_sigmoid=True,
                                            sn=False)
            self.loss_name = loss_name
        else:
            raise ValueError("Unsupported loss: " + loss_name)
        self.model_list = [self.G, self.D]
        # Define the loss function, optimizers, logging, etc.
        self.gan_loss = GanLoss(self.loss_name)
        self.optimizers_list = self.optimizers_config(
            mixed_precision=mixed_precision, learning_rate=learning_rate)
        self.mixed_precision = mixed_precision
        self.metrics_list = self.metrics_config()
        self.checkpoint_config()
        self.get_seed()

    def build(self, input_shape_G, input_shape_D):
        """
        input_shape must be passed as a slice (copy), because lower layers
        treat it as their output shape and mutate it.
        """
        self.G.build(input_shape=input_shape_G[:])  # G: X -> Y
        self.D.build(input_shape=input_shape_D[:])  # D: Y or not Y
        self.built = True

    def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
        # NOTE: the `learning_rate` argument is currently ignored.
        self.G_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        self.D_optimizer = Adam(learning_rate=1e-4, beta_1=0.0, beta_2=0.9)
        if mixed_precision:
            self.G_optimizer = self.G_optimizer.get_mixed_precision()
            self.D_optimizer = self.D_optimizer.get_mixed_precision()
        return [self.G_optimizer, self.D_optimizer]

    def metrics_config(self):
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_logdir = self.tmp_path + "/logs/" + current_time
        self.train_summary_writer = tf.summary.create_file_writer(train_logdir)
        return []

    def checkpoint_config(self):
        self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                        optimizer=self.optimizers_list,
                                        model=self.model_list,
                                        dataset=self.train_set)
        self.manager = tf.train.CheckpointManager(self.ckpt,
                                                  self.tmp_path + '/tf_ckpts',
                                                  max_to_keep=3)

    def pix_gradient(self, x):
        # Pixel gradients over batch and channels; for 2D single-channel
        # data the reshape is not strictly necessary.
        x = tf.reshape(x, shape=[1, 64, 64, 1])
        dx, dy = tf.image.image_gradients(x)
        return dx, dy

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_D(self, trainX, trainY, y_shape, step):
        with tf.GradientTape(persistent=True) as D_tape:
            GeneratedY = self.G(trainX)
            D_real_out = self.D(trainY)
            D_fake_out = self.D(GeneratedY)
            e = tf.random.uniform(shape=y_shape, minval=0.0, maxval=1.0)
            mid_Y = e * trainY + (1 - e) * GeneratedY
            with tf.GradientTape() as GP:
                GP.watch(mid_Y)
                inner_loss = self.D(mid_Y)
            penalty = GP.gradient(inner_loss, mid_Y)
            # print("penalty", penalty.shape)
            # penalty_norm = 10.0*tf.math.square(tf.maximum(tf.norm(penalty,ord='euclidean'),1.0)-1)
            penalty_norm = 10.0 * tf.math.square(
                tf.norm(tf.reshape(penalty, [y_shape[0], -1]), axis=-1,
                        ord=2) - 1)  # this follows the original algorithm
            # print("penalty_norm", penalty_norm.shape)
            D_loss = self.gan_loss.DiscriminatorLoss(
                D_real_out, D_fake_out) + tf.reduce_mean(penalty_norm)
            if self.mixed_precision:
                scaled_D_loss = self.D_optimizer.get_scaled_loss(D_loss)
        if self.mixed_precision:
            scaled_gradients_of_D = D_tape.gradient(
                scaled_D_loss, self.D.trainable_variables)
            gradients_of_D = self.D_optimizer.get_unscaled_gradients(
                scaled_gradients_of_D)
        else:
            gradients_of_D = D_tape.gradient(D_loss,
                                             self.D.trainable_variables)
        self.D_optimizer.apply_gradients(
            zip(gradients_of_D, self.D.trainable_variables))
        return D_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_G(self, trainX, trainY, y_shape, step):
        with tf.GradientTape(persistent=True) as G_tape:
            GeneratedY = self.G(trainX)
            # Dy_real_out = self.Dy(trainY)
            D_fake_out = self.D(GeneratedY)
            G_loss = self.gan_loss.GeneratorLoss(D_fake_out)
            if self.mixed_precision:
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
        if self.mixed_precision:
            scaled_gradients_of_G = G_tape.gradient(
                scaled_G_loss, self.G.trainable_variables)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
        else:
            gradients_of_G = G_tape.gradient(G_loss,
                                             self.G.trainable_variables)
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        return G_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step(self, trainX, trainY, y_shape, step):
        with tf.GradientTape(persistent=True) as gan_type:
            GeneratedY = self.G(trainX)
            D_real_out = self.D(trainY)
            D_fake_out = self.D(GeneratedY)
            D_loss = self.gan_loss.DiscriminatorLoss(D_real_out, D_fake_out)
            G_loss = self.gan_loss.GeneratorLoss(D_fake_out)
            if self.mixed_precision:
                scaled_D_loss = self.D_optimizer.get_scaled_loss(D_loss)
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
        if self.mixed_precision:
            scaled_gradients_of_D = gan_type.gradient(
                scaled_D_loss, self.D.trainable_variables)
            scaled_gradients_of_G = gan_type.gradient(
                scaled_G_loss, self.G.trainable_variables)
            gradients_of_D = self.D_optimizer.get_unscaled_gradients(
                scaled_gradients_of_D)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
        else:
            gradients_of_D = gan_type.gradient(D_loss,
                                               self.D.trainable_variables)
            gradients_of_G = gan_type.gradient(G_loss,
                                               self.G.trainable_variables)
        self.D_optimizer.apply_gradients(
            zip(gradients_of_D, self.D.trainable_variables))
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        return D_loss, G_loss

    def train(self, epochs):
        self.ckpt.restore(self.manager.latest_checkpoint)
        for _ in range(epochs):
            start = time.time()
            for trainX, trainY in self.train_set:
                self.ckpt.step.assign_add(1)
                step = int(self.ckpt.step)
                if self.loss_name in ["WGAN", "WGAN-GP"]:
                    # n_critic = 5: five critic updates per generator update
                    for __ in range(5):
                        D_loss = self.train_step_D(
                            trainX, trainY,
                            tf.constant([trainY.shape[0], 1, 1, 1],
                                        shape=[4],
                                        dtype=tf.int32),
                            tf.constant(step, shape=[1], dtype=tf.uint32))
                    for __ in range(1):
                        G_loss = self.train_step_G(
                            trainX, trainY,
                            tf.constant([trainY.shape[0], 1, 1, 1],
                                        shape=[4],
                                        dtype=tf.int32),
                            tf.constant(step, shape=[1], dtype=tf.uint32))
                elif self.loss_name in ["Vanilla", "LSGAN"]:
                    D_loss, G_loss = self.train_step(
                        trainX, trainY,
                        tf.constant([trainY.shape[0], 1, 1, 1],
                                    shape=[4],
                                    dtype=tf.int32),
                        tf.constant(step, shape=[1], dtype=tf.uint32))
                else:
                    raise ValueError("Internal error")
                if step % 100 == 0:
                    save_path = self.manager.save()
                    print("Saved checkpoint for step {}: {}".format(
                        step, save_path))
                    self.G.save_weights(self.tmp_path +
                                        '/weights_saved/G.ckpt')
                    self.D.save_weights(self.tmp_path +
                                        '/weights_saved/D.ckpt')
                    self.write_summary(step=step,
                                       seed=self.seed,
                                       G=self.G,
                                       G_loss=G_loss,
                                       D_loss=D_loss,
                                       out_path=self.out_path)
                    print('Time for 100 steps up to step {}: {} sec'.format(
                        step, time.time() - start))
                    start = time.time()

    def test(self, take_nums):
        out_path = self.out_path + "/test"
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        self.ckpt.restore(self.manager.latest_checkpoint)
        seed_get = iter(self.test_set)
        for take in range(take_nums):
            plt.figure(figsize=(10, 10))  # larger figure to hold the pixels
            for i in range(100):
                single_seed = next(seed_get)
                GeneratedY = self.G(single_seed, training=False)
                plt.subplot(10, 10, (i + 1))
                plt.imshow(GeneratedY[0, :, :, 0], cmap='gray')
                plt.axis('off')
            plt.savefig(out_path + '/image_at_{}.png'.format(take))
            plt.close()

    def get_seed(self):
        self.seed = []
        seed_get = iter(self.test_set)
        for _ in range(100):
            seed = next(seed_get)
            self.seed.append(seed)

    def write_summary(self, step, seed, G, G_loss, D_loss, out_path):
        plt.figure(figsize=(10, 10))  # larger figure to hold the pixels
        for i, single_seed in enumerate(seed):
            GeneratedY = G(single_seed, training=False)
            plt.subplot(10, 10, (i + 1))
            plt.imshow(GeneratedY[0, :, :, 0], cmap='gray')
            plt.axis('off')
        plt.savefig(out_path + '/image_at_{}.png'.format(step))
        plt.close()
        img = Image.open(out_path + '/image_at_{}.png'.format(step))
        img = tf.reshape(np.array(img), shape=(1, 1000, 1000, 4))
        with self.train_summary_writer.as_default():
            ##########################
            tf.summary.scalar('G_loss', G_loss, step=step)
            tf.summary.scalar('D_loss', D_loss, step=step)
            tf.summary.image("img", img, step=step)
class CycleGAN(tf.keras.Model):
    """
    The model is only responsible for what happens once the training and
    test (validation) sets have been given.
    """
    def __init__(self,
                 train_set,
                 test_set,
                 loss_name="WGAN-GP",
                 mixed_precision=False,
                 learning_rate=2e-4,
                 tmp_path=None,
                 out_path=None):
        super(CycleGAN, self).__init__()
        # Receive the datasets and related parameters
        self.train_set = train_set
        self.test_set = test_set
        self.tmp_path = tmp_path
        self.out_path = out_path
        # Define the sub-models
        self.G = networks.Generator(name="G_X2Y")
        self.F = networks.Generator(name="G_Y2X")
        if loss_name in ["WGAN-SN", "WGAN-GP-SN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=False,
                                             sn=True)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=False,
                                             sn=True)
            self.loss_name = loss_name[:-3]
        elif loss_name in ["WGAN", "WGAN-GP"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=False,
                                             sn=False)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=False,
                                             sn=False)
            self.loss_name = loss_name
        elif loss_name in ["Vanilla", "LSGAN"]:
            self.Dy = networks.Discriminator(name="If_is_real_Y",
                                             use_sigmoid=True,
                                             sn=False)
            self.Dx = networks.Discriminator(name="If_is_real_X",
                                             use_sigmoid=True,
                                             sn=False)
            self.loss_name = loss_name
        else:
            raise ValueError("Unsupported loss: " + loss_name)
        self.model_list = [self.G, self.F, self.Dy, self.Dx]
        # Define the loss function, optimizers, logging, etc.
        self.gan_loss = GanLoss(self.loss_name)
        self.optimizers_list = self.optimizers_config(
            mixed_precision=mixed_precision, learning_rate=learning_rate)
        self.mixed_precision = mixed_precision
        self.metrics_list = self.metrics_config()
        self.checkpoint_config()
        self.get_seed()

    def build(self, X_shape, Y_shape):
        """
        input_shape must be passed as a slice (copy), because lower layers
        treat it as their output shape and mutate it.
        """
        self.G.build(input_shape=X_shape[:])   # G: X -> Y
        self.Dy.build(input_shape=Y_shape[:])  # Dy: Y or != Y
        self.F.build(input_shape=Y_shape[:])   # F: Y -> X
        self.Dx.build(input_shape=X_shape[:])  # Dx: X or != X
        self.built = True

    def optimizers_config(self, mixed_precision=False, learning_rate=2e-4):
        # NOTE: the `learning_rate` argument is currently ignored.
        self.G_optimizer = Adam(2e-4)
        self.Dy_optimizer = Adam(2e-4)
        self.F_optimizer = Adam(2e-4)
        self.Dx_optimizer = Adam(2e-4)
        if mixed_precision:
            self.G_optimizer = self.G_optimizer.get_mixed_precision()
            self.Dy_optimizer = self.Dy_optimizer.get_mixed_precision()
            self.F_optimizer = self.F_optimizer.get_mixed_precision()
            self.Dx_optimizer = self.Dx_optimizer.get_mixed_precision()
        return [
            self.G_optimizer, self.Dy_optimizer, self.F_optimizer,
            self.Dx_optimizer
        ]

    def metrics_config(self):
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_logdir = self.tmp_path + "/logs/" + current_time
        self.train_summary_writer = tf.summary.create_file_writer(train_logdir)
        return []

    def checkpoint_config(self):
        self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                        optimizer=self.optimizers_list,
                                        model=self.model_list,
                                        dataset=self.train_set)
        self.manager = tf.train.CheckpointManager(self.ckpt,
                                                  self.tmp_path + '/tf_ckpts',
                                                  max_to_keep=3)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_D(self, trainX, trainY, step):
        with tf.GradientTape(persistent=True) as D_tape:
            GeneratedY = self.G(trainX)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            e = tf.random.uniform((trainY.shape[0], 1, 1, 1), 0.0, 1.0)
            mid_Y = e * trainY + (1 - e) * GeneratedY
            with tf.GradientTape() as gradient_penaltyY:
                gradient_penaltyY.watch(mid_Y)
                inner_loss = self.Dy(mid_Y)
            penalty = gradient_penaltyY.gradient(inner_loss, mid_Y)
            # My own variant: only gradients whose norm exceeds 1 should be
            # penalized.
            penalty_normY = 10.0 * tf.math.square(
                tf.maximum(tf.norm(penalty, ord='euclidean'), 1.0) - 1)
            e = tf.random.uniform((trainX.shape[0], 1, 1, 1), 0.0, 1.0)
            mid_X = e * trainX + (1 - e) * GeneratedX
            with tf.GradientTape() as gradient_penaltyX:
                gradient_penaltyX.watch(mid_X)
                inner_loss = self.Dx(mid_X)
            penalty = gradient_penaltyX.gradient(inner_loss, mid_X)
            penalty_normX = 10.0 * tf.math.square(
                tf.maximum(tf.norm(penalty, ord='euclidean'), 1.0) - 1)
            Dy_loss = self.gan_loss.DiscriminatorLoss(
                Dy_real_out, Dy_fake_out) + tf.reduce_mean(penalty_normY)
            Dx_loss = self.gan_loss.DiscriminatorLoss(
                Dx_real_out, Dx_fake_out) + tf.reduce_mean(penalty_normX)
            if self.mixed_precision:
                scaled_Dy_loss = self.Dy_optimizer.get_scaled_loss(Dy_loss)
                scaled_Dx_loss = self.Dx_optimizer.get_scaled_loss(Dx_loss)
        if self.mixed_precision:
            scaled_gradients_of_Dy = D_tape.gradient(
                scaled_Dy_loss, self.Dy.trainable_variables)
            scaled_gradients_of_Dx = D_tape.gradient(
                scaled_Dx_loss, self.Dx.trainable_variables)
            gradients_of_Dy = self.Dy_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dy)
            gradients_of_Dx = self.Dx_optimizer.get_unscaled_gradients(
                scaled_gradients_of_Dx)
        else:
            gradients_of_Dy = D_tape.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
            gradients_of_Dx = D_tape.gradient(Dx_loss,
                                              self.Dx.trainable_variables)
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        return Dy_loss, Dx_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step_G(self, trainX, trainY, step):
        with tf.GradientTape(persistent=True) as G_tape:
            GeneratedY = self.G(trainX)
            # Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            # Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X
            # Pixel-gradient and reconstruction losses are not used yet,
            # so both branches currently apply the same weight.
            if step >= 0:
                cycle_l = 10.0
            else:
                cycle_l = 10.0
            G_loss = self.gan_loss.GeneratorLoss(
                Dy_fake_out) + cycle_l * (cycle_consistent)
            F_loss = self.gan_loss.GeneratorLoss(
                Dx_fake_out) + cycle_l * (cycle_consistent)
            if self.mixed_precision:
                scaled_G_loss = self.G_optimizer.get_scaled_loss(G_loss)
                scaled_F_loss = self.F_optimizer.get_scaled_loss(F_loss)
        if self.mixed_precision:
            scaled_gradients_of_G = G_tape.gradient(
                scaled_G_loss, self.G.trainable_variables)
            scaled_gradients_of_F = G_tape.gradient(
                scaled_F_loss, self.F.trainable_variables)
            gradients_of_G = self.G_optimizer.get_unscaled_gradients(
                scaled_gradients_of_G)
            gradients_of_F = self.F_optimizer.get_unscaled_gradients(
                scaled_gradients_of_F)
        else:
            gradients_of_G = G_tape.gradient(G_loss,
                                             self.G.trainable_variables)
            gradients_of_F = G_tape.gradient(F_loss,
                                             self.F.trainable_variables)
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, F_loss

    @tf.function(input_signature=[
        tf.TensorSpec(shape=global_input_X_shape, dtype=tf.float32),
        tf.TensorSpec(shape=global_input_Y_shape, dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.uint32)
    ])
    def train_step(self, trainX, trainY, step):
        with tf.GradientTape(persistent=True) as cycle_type:
            GeneratedY = self.G(trainX)
            Dy_real_out = self.Dy(trainY)
            Dy_fake_out = self.Dy(GeneratedY)
            GeneratedX = self.F(trainY)
            Dx_real_out = self.Dx(trainX)
            Dx_fake_out = self.Dx(GeneratedX)
            cycle_consistent_loss_X2Y = tf.reduce_mean(
                tf.abs(self.F(GeneratedY) - trainX))
            cycle_consistent_loss_Y2X = tf.reduce_mean(
                tf.abs(self.G(GeneratedX) - trainY))
            cycle_consistent = cycle_consistent_loss_X2Y + cycle_consistent_loss_Y2X
            # Pixel-gradient and reconstruction losses are not used yet
            if step >= 0:
                cycle_l = 10.0
            else:
                cycle_l = 10.0
            Dy_loss = self.gan_loss.DiscriminatorLoss(Dy_real_out, Dy_fake_out)
            Dx_loss = self.gan_loss.DiscriminatorLoss(Dx_real_out, Dx_fake_out)
            G_loss = self.gan_loss.GeneratorLoss(
                Dy_fake_out) + cycle_l * (cycle_consistent)
            F_loss = self.gan_loss.GeneratorLoss(
                Dx_fake_out) + cycle_l * (cycle_consistent)
        gradients_of_Dy = cycle_type.gradient(Dy_loss,
                                              self.Dy.trainable_variables)
        gradients_of_Dx = cycle_type.gradient(Dx_loss,
                                              self.Dx.trainable_variables)
        gradients_of_G = cycle_type.gradient(G_loss,
                                             self.G.trainable_variables)
        gradients_of_F = cycle_type.gradient(F_loss,
                                             self.F.trainable_variables)
        self.Dy_optimizer.apply_gradients(
            zip(gradients_of_Dy, self.Dy.trainable_variables))
        self.Dx_optimizer.apply_gradients(
            zip(gradients_of_Dx, self.Dx.trainable_variables))
        self.G_optimizer.apply_gradients(
            zip(gradients_of_G, self.G.trainable_variables))
        self.F_optimizer.apply_gradients(
            zip(gradients_of_F, self.F.trainable_variables))
        return G_loss, Dy_loss, F_loss, Dx_loss

    def train(self, epochs):
        self.ckpt.restore(self.manager.latest_checkpoint)
        for _ in range(epochs):
            start = time.time()
            for trainX, trainY in self.train_set:
                self.ckpt.step.assign_add(1)
                step = int(self.ckpt.step)
                if self.loss_name in ["WGAN", "WGAN-GP"]:
                    for __ in range(1):
                        Dy_loss, Dx_loss = self.train_step_D(
                            trainX, trainY,
                            tf.constant(step, shape=[1], dtype=tf.uint32))
                    for __ in range(1):
                        G_loss, F_loss = self.train_step_G(
                            trainX, trainY,
                            tf.constant(step, shape=[1], dtype=tf.uint32))
                elif self.loss_name in ["Vanilla", "LSGAN"]:
                    G_loss, Dy_loss, F_loss, Dx_loss = self.train_step(
                        trainX, trainY,
                        tf.constant(step, shape=[1], dtype=tf.uint32))
                else:
                    raise ValueError("Internal error")
                if step % 100 == 0:
                    save_path = self.manager.save()
                    print("Saved checkpoint for step {}: {}".format(
                        step, save_path))
                    self.G.save_weights(self.tmp_path +
                                        '/weights_saved/G.ckpt')
                    self.Dy.save_weights(self.tmp_path +
                                         '/weights_saved/Dy.ckpt')
                    self.F.save_weights(self.tmp_path +
                                        '/weights_saved/F.ckpt')
                    self.Dx.save_weights(self.tmp_path +
                                         '/weights_saved/Dx.ckpt')
                    self.write_summary(step=step,
                                       seed=self.seed,
                                       G=self.G,
                                       F=self.F,
                                       G_loss=G_loss,
                                       Dy_loss=Dy_loss,
                                       F_loss=F_loss,
                                       Dx_loss=Dx_loss,
                                       out_path=self.out_path)
                    print('Time for 100 steps up to step {}: {} sec'.format(
                        step, time.time() - start))
                    start = time.time()

    def test(self, take_nums):
        self.ckpt.restore(self.manager.latest_checkpoint)
        if not os.path.exists(self.out_path + '/test'):
            os.makedirs(self.out_path + '/test')  # make sure the dir exists
        for i, (testX, testY) in enumerate(self.test_set):
            GeneratedY = self.G(testX)
            GeneratedX = self.F(testY)
            plt.figure(figsize=(5, 5))  # a larger figure is needed to hold the pixels
            plt.subplot(2, 2, 1)
            plt.title('real X')
            plt.imshow(testX[0, :, :, :])
            plt.axis('off')
            plt.subplot(2, 2, 2)
            plt.title('fake Y')
            plt.imshow(GeneratedY[0, :, :, :])
            plt.axis('off')
            plt.subplot(2, 2, 3)
            plt.title('fake X')
            plt.imshow(GeneratedX[0, :, :, :])
            plt.axis('off')
            plt.subplot(2, 2, 4)
            plt.title('real Y')
            plt.imshow(testY[0, :, :, :])
            plt.axis('off')
            plt.savefig(self.out_path + '/test/{}.png'.format(i))
            plt.close()
            img = Image.open(self.out_path + '/test/{}.png'.format(i))
            img = tf.reshape(np.array(img), shape=(1, 500, 500, 4))

    def get_seed(self):
        seed_get = iter(self.test_set)
        seed = next(seed_get)
        print(seed[0].shape, seed[1].dtype)
        plt.imshow(seed[0][0, :, :, :])
        plt.show()
        plt.imshow(seed[1][0, :, :, :])
        plt.show()
        self.seed = seed

    def write_summary(self, step, seed, G, F, G_loss, Dy_loss, F_loss, Dx_loss,
                      out_path):
        testX, testY = seed
        GeneratedY = G(testX)
        GeneratedX = F(testY)
        plt.figure(figsize=(5, 5))  # a larger figure is needed to hold the pixels
        plt.subplot(2, 2, 1)
        plt.title('real X')
        plt.imshow(testX[0, :, :, :])
        plt.axis('off')
        plt.subplot(2, 2, 2)
        plt.title('fake Y')
        plt.imshow(GeneratedY[0, :, :, :])
        plt.axis('off')
        plt.subplot(2, 2, 3)
        plt.title('fake X')
        plt.imshow(GeneratedX[0, :, :, :])
        plt.axis('off')
        plt.subplot(2, 2, 4)
        plt.title('real Y')
        plt.imshow(testY[0, :, :, :])
        plt.axis('off')
        plt.savefig(out_path + '/image_at_{}.png'.format(step))
        plt.close()
        img = Image.open(out_path + '/image_at_{}.png'.format(step))
        img = tf.reshape(np.array(img), shape=(1, 500, 500, 4))
        with self.train_summary_writer.as_default():
            tf.summary.scalar('G_loss', G_loss, step=step)
            tf.summary.scalar('Dy_loss', Dy_loss, step=step)
            tf.summary.scalar('F_loss', F_loss, step=step)
            tf.summary.scalar('Dx_loss', Dx_loss, step=step)
            tf.summary.image("img", img, step=step)
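All three training steps in this class share the same cycle-consistency term: each domain is translated to the other and back, and the round-trip L1 error is weighted by `cycle_l = 10.0`. A self-contained sketch of that term, with one-layer conv generators standing in for `networks.Generator` (an assumption for illustration only):

import tensorflow as tf

G = tf.keras.Sequential([tf.keras.layers.Conv2D(3, 3, padding="same")])
F = tf.keras.Sequential([tf.keras.layers.Conv2D(3, 3, padding="same")])

x = tf.random.uniform([2, 64, 64, 3])  # domain X batch
y = tf.random.uniform([2, 64, 64, 3])  # domain Y batch
fake_y = G(x)
fake_x = F(y)
# L1 error of each domain after a full round trip, summed over both cycles
cycle_consistent = (tf.reduce_mean(tf.abs(F(fake_y) - x)) +
                    tf.reduce_mean(tf.abs(G(fake_x) - y)))
cycle_l = 10.0  # the weight used by the training steps above
print(float(cycle_l * cycle_consistent))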