예제 #1
0
    def construct_model(self, crop_size, load_size):
        """Create the data pipelines, CycleGAN networks, losses and optimizers.

        Everything is stored on ``self``: ``A_B_dataset``, ``len_dataset``,
        ``A2B_pool``/``B2A_pool``, ``test_iter``, the networks ``G_A2B``,
        ``G_B2A``, ``D_A`` and ``D_B``, the adversarial/cycle/identity loss
        functions, ``G_lr_scheduler``/``D_lr_scheduler`` and the two Adam
        optimizers.

        Args:
            crop_size: side length of the square crop. It is passed to
                ``data.make_zip_dataset`` and used as the height/width of
                every network's input shape.
            load_size: load size forwarded unchanged to
                ``data.make_zip_dataset``.
        """
        gray = self.color_depth == 1

        # Training pipeline; repeat=False, so iterating it yields a single pass.
        self.A_B_dataset, self.len_dataset = data.make_zip_dataset(
            self.A_img_paths, self.B_img_paths, self.batch_size,
            load_size, crop_size,
            training=True, repeat=False, is_gray_scale=gray)

        self.A2B_pool = data.ItemPool(self.pool_size)
        self.B2A_pool = data.ItemPool(self.pool_size)

        def test_paths(split):
            # Image files of the given split directory, filtered by extension.
            return py.glob(py.join(self.datasets_dir, self.dataset, split),
                           '*.{}'.format(self.image_ext))

        # Evaluation pipeline. It is built with repeat=True and exposed as an
        # iterator so samples can be drawn on demand.
        A_B_dataset_test, _ = data.make_zip_dataset(
            test_paths('testA'), test_paths('testB'), self.batch_size,
            load_size, crop_size,
            training=False, repeat=True, is_gray_scale=gray)
        self.test_iter = iter(A_B_dataset_test)

        # Every network consumes crops of the same shape. The generators emit
        # as many channels as the configured colour depth.
        shape = (crop_size, crop_size, self.color_depth)
        self.G_A2B = module.ResnetGenerator(input_shape=shape,
                                            output_channels=self.color_depth)
        self.G_B2A = module.ResnetGenerator(input_shape=shape,
                                            output_channels=self.color_depth)
        self.D_A = module.ConvDiscriminator(input_shape=shape)
        self.D_B = module.ConvDiscriminator(input_shape=shape)

        self.d_loss_fn, self.g_loss_fn = gan.get_adversarial_losses_fn(
            self.adversarial_loss_mode)
        self.cycle_loss_fn = tf.losses.MeanAbsoluteError()
        self.identity_loss_fn = tf.losses.MeanAbsoluteError()

        # G and D use identical linear-decay schedules. Epoch counts are
        # scaled by len_dataset (presumably batches per epoch — confirm).
        total_steps = self.epochs * self.len_dataset
        decay_steps = self.epoch_decay * self.len_dataset
        self.G_lr_scheduler = module.LinearDecay(self.lr, total_steps,
                                                 decay_steps)
        self.D_lr_scheduler = module.LinearDecay(self.lr, total_steps,
                                                 decay_steps)

        self.G_optimizer = keras.optimizers.Adam(
            learning_rate=self.G_lr_scheduler, beta_1=self.beta_1)
        self.D_optimizer = keras.optimizers.Adam(
            learning_rate=self.D_lr_scheduler, beta_1=self.beta_1)
예제 #2
0
    args.batch_size,
    args.load_size,
    args.crop_size,
    training=False,
    repeat=True,
    augmentation_preset=args.augmentation)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================
# The discriminator downsamples one level more than the generator, capped at 4.
G_downsamplings = args.n_downsamplings
D_downsamplings = min(G_downsamplings + 1, 4)

# Two generators with identical hyper-parameters over 3-channel crops.
G_A2B = module.ResnetGenerator(
    input_shape=(args.crop_size, args.crop_size, 3),
    dim=args.dim,
    n_downsamplings=G_downsamplings,
    n_blocks=args.n_blocks,
    norm=args.norm,
)
G_B2A = module.ResnetGenerator(
    input_shape=(args.crop_size, args.crop_size, 3),
    dim=args.dim,
    n_downsamplings=G_downsamplings,
    n_blocks=args.n_blocks,
    norm=args.norm,
)

D_A = module.ConvDiscriminator(
    input_shape=(args.crop_size, args.crop_size, 3),
    dim=args.dim,
    n_downsamplings=D_downsamplings,
    norm=args.norm,
)
D_B = module.ConvDiscriminator(input_shape=(args.crop_size, args.crop_size, 3),
                               dim=args.dim,
                               n_downsamplings=D_downsamplings,
예제 #3
0
                           '*.jpg')
B_img_paths_test = py.glob(
    py.join(args.datasets_dir, args.dataset, 'testB'), '*.jpg')
# Evaluation pipeline (no training-time processing, repeated).
A_B_dataset_test, _ = data.make_zip_dataset(
    A_img_paths_test, B_img_paths_test,
    args.batch_size, args.load_size, args.crop_size,
    training=False, repeat=True)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================

# Networks with default hyper-parameters over 3-channel square crops.
G_A2B = module.ResnetGenerator(
    input_shape=(args.crop_size, args.crop_size, 3))
G_B2A = module.ResnetGenerator(
    input_shape=(args.crop_size, args.crop_size, 3))

D_A = module.ConvDiscriminator(
    input_shape=(args.crop_size, args.crop_size, 3))
D_B = module.ConvDiscriminator(
    input_shape=(args.crop_size, args.crop_size, 3))

d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

# G and D share the same linear-decay schedule; epoch counts are scaled by
# len_dataset.
G_lr_scheduler = module.LinearDecay(
    args.lr,
    args.epochs * len_dataset,
    args.epoch_decay * len_dataset,
)
D_lr_scheduler = module.LinearDecay(
    args.lr,
    args.epochs * len_dataset,
    args.epoch_decay * len_dataset,
)
G_optimizer = keras.optimizers.Adam(learning_rate=G_lr_scheduler,
예제 #4
0
        B_img_paths_test,
        B_test_labels,
        args.batch_size_triplet,
        args.load_size,
        args.crop_size,
        training=False,
        grayscale=args.grayscale,
        repeat=False,
        shuffle=False)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================
with tf.device('/device:GPU:%d' % GPU[1]):
    # All four networks take 3-channel square crops and share the same
    # filter count (args.kernels_num). They are created under the device
    # scope selected by GPU[1].
    G_A2B = module.ResnetGenerator(
        input_shape=(args.crop_size, args.crop_size, 3),
        dim=args.kernels_num,
    )
    G_B2A = module.ResnetGenerator(
        input_shape=(args.crop_size, args.crop_size, 3),
        dim=args.kernels_num,
    )

    D_A = module.ConvDiscriminator(
        input_shape=(args.crop_size, args.crop_size, 3),
        dim=args.kernels_num,
    )
    D_B = module.ConvDiscriminator(
        input_shape=(args.crop_size, args.crop_size, 3),
        dim=args.kernels_num,
    )

    d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(
        args.adversarial_loss_mode)
    cycle_loss_fn = tf.losses.MeanAbsoluteError()
    identity_loss_fn = tf.losses.MeanAbsoluteError()
예제 #5
0
    B_img_paths_test,
    args.batch_size,
    args.load_size,
    args.crop_size,
    args.channels,
    training=False,
    repeat=True,
)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================

# Generators map args.channels -> args.channels with a configurable number of
# residual blocks; discriminators take crops with the same channel count.
G_A2B = module.ResnetGenerator(input_shape=(args.crop_size, args.crop_size,
                                            args.channels),
                               output_channels=args.channels,
                               n_blocks=args.resnet_blocks)
G_B2A = module.ResnetGenerator(input_shape=(args.crop_size, args.crop_size,
                                            args.channels),
                               output_channels=args.channels,
                               n_blocks=args.resnet_blocks)

D_A = module.ConvDiscriminator(
    input_shape=(args.crop_size, args.crop_size, args.channels),
)
D_B = module.ConvDiscriminator(
    input_shape=(args.crop_size, args.crop_size, args.channels),
)

if args.DnCNN is not None:
    DnCNN = keras.models.load_model(args.DnCNN, compile=False)
    A_img_paths_test,
    B_img_paths_test,
    args.batch_size,
    args.load_size,
    args.crop_size,
    training=False,
    repeat=True,
    n_prefetch_batch=args.n_prefetch_batch)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================

# Generator depth/width come from the G_* options and discriminator settings
# from the D_* options; both share the same normalisation type.
G_A2B = module.ResnetGenerator(
    input_shape=(args.crop_size, args.crop_size, 3),
    n_downsamplings=args.G_n_downsamplings,
    n_blocks=args.G_n_residual_blocks,
    dim=args.G_n_filters,
    norm=args.norm_type,
)
G_B2A = module.ResnetGenerator(
    input_shape=(args.crop_size, args.crop_size, 3),
    n_downsamplings=args.G_n_downsamplings,
    n_blocks=args.G_n_residual_blocks,
    dim=args.G_n_filters,
    norm=args.norm_type,
)

D_A = module.ConvDiscriminator(
    input_shape=(args.crop_size, args.crop_size, 3),
    n_downsamplings=args.D_n_downsamplings,
    dim=args.D_n_filters,
    norm=args.norm_type,
)
D_B = module.ConvDiscriminator(input_shape=(args.crop_size, args.crop_size, 3),
                               n_downsamplings=args.D_n_downsamplings,
                               dim=args.D_n_filters,
예제 #7
0
# The epoch length is taken from the size of B_list.
B_length = len(B_list)
len_dataset = B_length

# Only an A->B image pool is created in this variant; there is no B->A pool.
A2B_pool = data.ItemPool(apool_size)

print('session starts \n')
print('B', B_length)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================

# A single generator/discriminator pair on 6-channel inputs.
# NOTE(review): presumably two stacked 3-channel images — confirm with the
# data pipeline.
G = module.ResnetGenerator(input_shape=(acrop_size, acrop_size, 6))
D = module.ConvDiscriminator(input_shape=(acrop_size, acrop_size, 6))

d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(aadversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

G_lr_scheduler = module.LinearDecay(
    alr,
    aepochs * len_dataset,
    aepoch_decay * len_dataset,
)
D_lr_scheduler = module.LinearDecay(alr, aepochs * len_dataset,