示例#1
0
    def construct_model(self, crop_size, load_size):
        """Create the datasets, networks, losses and optimizers for training.

        Every object built here is stored as an attribute on ``self``;
        the method returns nothing.

        Args:
            crop_size: Passed to ``data.make_zip_dataset`` and also used as
                both spatial dimensions of every network's ``input_shape``.
            load_size: Passed to ``data.make_zip_dataset``.
        """
        grayscale = self.color_depth == 1

        def zip_dataset(a_paths, b_paths, *, training, repeat):
            # Train and test datasets differ only in paths and these flags.
            return data.make_zip_dataset(
                a_paths,
                b_paths,
                self.batch_size,
                load_size,
                crop_size,
                training=training,
                repeat=repeat,
                is_gray_scale=grayscale)

        def test_paths(subdir):
            # Files with extension ``self.image_ext`` in the given
            # sub-directory of ``datasets_dir/dataset``.
            return py.glob(
                py.join(self.datasets_dir, self.dataset, subdir),
                '*.{}'.format(self.image_ext))

        # --- training data --------------------------------------------------
        self.A_B_dataset, len_dataset = zip_dataset(
            self.A_img_paths, self.B_img_paths, training=True, repeat=False)
        self.len_dataset = len_dataset

        # One pool per translation direction, both sized by ``pool_size``.
        # NOTE(review): presumably buffers of previously generated images --
        # confirm against data.ItemPool.
        self.A2B_pool = data.ItemPool(self.pool_size)
        self.B2A_pool = data.ItemPool(self.pool_size)

        # --- test data ------------------------------------------------------
        # Built with repeat=True and exposed only as an iterator.
        test_a = test_paths('testA')
        test_b = test_paths('testB')
        test_dataset, _ = zip_dataset(
            test_a, test_b, training=False, repeat=True)
        self.test_iter = iter(test_dataset)

        # --- networks -------------------------------------------------------
        # All four networks take the same square input with ``color_depth``
        # channels; the generators also emit ``color_depth`` channels.
        image_shape = (crop_size, crop_size, self.color_depth)
        self.G_A2B = module.ResnetGenerator(
            input_shape=image_shape, output_channels=self.color_depth)
        self.G_B2A = module.ResnetGenerator(
            input_shape=image_shape, output_channels=self.color_depth)
        self.D_A = module.ConvDiscriminator(input_shape=image_shape)
        self.D_B = module.ConvDiscriminator(input_shape=image_shape)

        # --- losses ---------------------------------------------------------
        adversarial_losses = gan.get_adversarial_losses_fn(
            self.adversarial_loss_mode)
        self.d_loss_fn, self.g_loss_fn = adversarial_losses
        # Cycle and identity terms each get their own mean-absolute-error
        # loss object.
        self.cycle_loss_fn = tf.losses.MeanAbsoluteError()
        self.identity_loss_fn = tf.losses.MeanAbsoluteError()

        # --- learning-rate schedules and optimizers --------------------------
        def linear_decay():
            # Both arguments after ``lr`` are epoch counts scaled by the
            # dataset length.
            # NOTE(review): presumably total steps and the step at which
            # decay starts -- confirm against module.LinearDecay.
            return module.LinearDecay(
                self.lr,
                self.epochs * self.len_dataset,
                self.epoch_decay * self.len_dataset)

        # Generator and discriminator get separate but identically
        # configured schedule instances.
        self.G_lr_scheduler = linear_decay()
        self.D_lr_scheduler = linear_decay()
        self.G_optimizer = keras.optimizers.Adam(
            learning_rate=self.G_lr_scheduler, beta_1=self.beta_1)
        self.D_optimizer = keras.optimizers.Adam(
            learning_rate=self.D_lr_scheduler, beta_1=self.beta_1)
示例#2
0
# Both discriminators are built from one shared configuration: square
# ``args.crop_size`` inputs with 3 channels.
_disc_config = dict(
    input_shape=(args.crop_size, args.crop_size, 3),
    dim=args.dim,
    n_downsamplings=D_downsamplings,
    norm=args.norm,
)
D_A = module.ConvDiscriminator(**_disc_config)
D_B = module.ConvDiscriminator(**_disc_config)

# Discriminator/generator adversarial losses chosen by
# ``args.adversarial_loss_mode``; the cycle and identity terms each get
# their own mean-absolute-error loss object.
_adversarial_losses = gan.get_adversarial_losses_fn(
    args.adversarial_loss_mode)
d_loss_fn, g_loss_fn = _adversarial_losses
cycle_loss_fn, identity_loss_fn = (
    tf.losses.MeanAbsoluteError(),
    tf.losses.MeanAbsoluteError(),
)

# Generator and discriminator use separate schedule instances built from
# identical arguments: the base rate plus two epoch counts scaled by the
# dataset length.
# NOTE(review): presumably total steps and the step at which decay
# starts -- confirm against module.LinearDecay.
_schedule_args = (
    args.lr,
    args.epochs * len_dataset,
    args.epoch_decay * len_dataset,
)
G_lr_scheduler = module.LinearDecay(*_schedule_args)
D_lr_scheduler = module.LinearDecay(*_schedule_args)
G_optimizer = keras.optimizers.Adam(
    learning_rate=G_lr_scheduler, beta_1=args.beta_1)
D_optimizer = keras.optimizers.Adam(
    learning_rate=D_lr_scheduler, beta_1=args.beta_1)

# ==============================================================================
# =                                 train step                                 =
# ==============================================================================


@tf.function
def train_G(A, B):
    with tf.GradientTape() as t: