Code example #1
0
    def construct_model(self, crop_size, load_size):
        """Build the datasets, networks, losses, LR schedules and optimizers.

        Populates `self` with everything the training loop needs; returns
        nothing.
        """
        grayscale = self.color_depth == 1

        # Training dataset of paired (A, B) batches; no repetition, so one
        # pass over it is one epoch.
        self.A_B_dataset, self.len_dataset = data.make_zip_dataset(
            self.A_img_paths, self.B_img_paths, self.batch_size,
            load_size, crop_size,
            training=True, repeat=False, is_gray_scale=grayscale)

        # History pools of previously generated fakes, fed to the
        # discriminators during training.
        self.A2B_pool = data.ItemPool(self.pool_size)
        self.B2A_pool = data.ItemPool(self.pool_size)

        # Test dataset: repeat=True so samples can be drawn at any time.
        pattern = '*.{}'.format(self.image_ext)
        test_A_paths = py.glob(
            py.join(self.datasets_dir, self.dataset, 'testA'), pattern)
        test_B_paths = py.glob(
            py.join(self.datasets_dir, self.dataset, 'testB'), pattern)
        test_dataset, _ = data.make_zip_dataset(
            test_A_paths, test_B_paths, self.batch_size,
            load_size, crop_size,
            training=False, repeat=True, is_gray_scale=grayscale)
        self.test_iter = iter(test_dataset)

        # CycleGAN networks: one generator per translation direction and one
        # discriminator per image domain.
        image_shape = (crop_size, crop_size, self.color_depth)
        self.G_A2B = module.ResnetGenerator(
            input_shape=image_shape, output_channels=self.color_depth)
        self.G_B2A = module.ResnetGenerator(
            input_shape=image_shape, output_channels=self.color_depth)
        self.D_A = module.ConvDiscriminator(input_shape=image_shape)
        self.D_B = module.ConvDiscriminator(input_shape=image_shape)

        # Adversarial loss pair chosen by flag; cycle and identity terms are
        # plain L1 distances.
        self.d_loss_fn, self.g_loss_fn = gan.get_adversarial_losses_fn(
            self.adversarial_loss_mode)
        self.cycle_loss_fn = tf.losses.MeanAbsoluteError()
        self.identity_loss_fn = tf.losses.MeanAbsoluteError()

        # Linearly decaying learning-rate schedules (same settings for G and
        # D) driving one Adam optimizer per network.
        total_steps = self.epochs * self.len_dataset
        decay_steps = self.epoch_decay * self.len_dataset
        self.G_lr_scheduler = module.LinearDecay(
            self.lr, total_steps, decay_steps)
        self.D_lr_scheduler = module.LinearDecay(
            self.lr, total_steps, decay_steps)
        self.G_optimizer = keras.optimizers.Adam(
            learning_rate=self.G_lr_scheduler, beta_1=self.beta_1)
        self.D_optimizer = keras.optimizers.Adam(
            learning_rate=self.D_lr_scheduler, beta_1=self.beta_1)
Code example #2
0
# Setup the normalization function for the discriminator. Batch norm cannot
# be combined with a gradient penalty, so penalized modes switch to a
# per-sample normalization.
if args.gradient_penalty_mode == 'none':
    d_norm = 'batch_norm'
elif args.gradient_penalty_mode in ['dragan', 'wgan-gp']:
    # TODO(Lynn)
    # Layer normalization is more stable than instance normalization here,
    # but instance normalization works in other implementations.
    # Please tell me if you find out the cause.
    d_norm = 'layer_norm'
else:
    # Previously the two branches were independent `if`s and `d_norm` was
    # left unbound for any other mode, surfacing later as a confusing
    # NameError. Fail fast with a clear message instead.
    raise ValueError(
        'unknown gradient_penalty_mode: %s' % args.gradient_penalty_mode)

# Networks (comment by K.C: the following commands set the structure of a G
# model). The generator maps a 1x1xz_dim latent to an image with shape[-1]
# channels; the discriminator mirrors it with n_D_downsamplings stages.
G = module.ConvGenerator(
    input_shape=(1, 1, args.z_dim),
    output_channels=shape[-1],
    n_upsamplings=n_G_upsamplings,
    name='G_%s' % args.dataset)
D = module.ConvDiscriminator(
    input_shape=shape,
    n_downsamplings=n_D_downsamplings,
    norm=d_norm,
    name='D_%s' % args.dataset)

# Save architecture diagrams and print text summaries for inspection.
py.mkdir('%s/summaries' % output_dir)
keras.utils.plot_model(G, '%s/summaries/convGenerator.png' % output_dir,
                       show_shapes=True)
keras.utils.plot_model(D, '%s/summaries/convDiscriminator.png' % output_dir,
                       show_shapes=True)
G.summary()
D.summary()

# Adversarial loss pair chosen by flag.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)

# One Adam optimizer per network, identical hyper-parameters.
adam_kwargs = dict(learning_rate=args.lr, beta_1=args.beta_1)
G_optimizer = keras.optimizers.Adam(**adam_kwargs)
D_optimizer = keras.optimizers.Adam(**adam_kwargs)


# ==============================================================================
Code example #3
0
# Depth configuration: the discriminators use one more downsampling stage
# than the generators, capped at four.
G_downsamplings = args.n_downsamplings
D_downsamplings = min(4, args.n_downsamplings + 1)

img_shape = (args.crop_size, args.crop_size, 3)

# Two ResNet generators, one per translation direction.
gen_kwargs = dict(input_shape=img_shape,
                  dim=args.dim,
                  n_downsamplings=G_downsamplings,
                  n_blocks=args.n_blocks,
                  norm=args.norm)
G_A2B = module.ResnetGenerator(**gen_kwargs)
G_B2A = module.ResnetGenerator(**gen_kwargs)

# Two conv discriminators, one per image domain.
disc_kwargs = dict(input_shape=img_shape,
                   dim=args.dim,
                   n_downsamplings=D_downsamplings,
                   norm=args.norm)
D_A = module.ConvDiscriminator(**disc_kwargs)
D_B = module.ConvDiscriminator(**disc_kwargs)

# Adversarial losses per CLI flag; cycle and identity terms are plain L1.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

# Linear learning-rate decay schedules (same settings for G and D).
total_steps = args.epochs * len_dataset
decay_steps = args.epoch_decay * len_dataset
G_lr_scheduler = module.LinearDecay(args.lr, total_steps, decay_steps)
D_lr_scheduler = module.LinearDecay(args.lr, total_steps, decay_steps)
Code example #4
0
# Test-time (A, B) pairs: training=False (presumably disabling train-time
# processing — confirm in data.make_zip_dataset) and repeat=True so the
# iterator can be sampled indefinitely.
A_B_dataset_test, _ = data.make_zip_dataset(
    A_img_paths_test, B_img_paths_test, args.batch_size,
    args.load_size, args.crop_size, training=False, repeat=True)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================

img_shape = (args.crop_size, args.crop_size, 3)

# Two generators (one per translation direction) and two discriminators
# (one per image domain), all with default module settings.
G_A2B = module.ResnetGenerator(input_shape=img_shape)
G_B2A = module.ResnetGenerator(input_shape=img_shape)

D_A = module.ConvDiscriminator(input_shape=img_shape)
D_B = module.ConvDiscriminator(input_shape=img_shape)

# Adversarial loss pair per CLI flag; L1 for cycle and identity terms.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

# Linearly decayed learning rate; one schedule and one Adam per network.
total_steps = args.epochs * len_dataset
decay_steps = args.epoch_decay * len_dataset
G_lr_scheduler = module.LinearDecay(args.lr, total_steps, decay_steps)
D_lr_scheduler = module.LinearDecay(args.lr, total_steps, decay_steps)
G_optimizer = keras.optimizers.Adam(learning_rate=G_lr_scheduler,
                                    beta_1=args.beta_1)
D_optimizer = keras.optimizers.Adam(learning_rate=D_lr_scheduler,
                                    beta_1=args.beta_1)
Code example #5
0
        # NOTE(review): these two keyword arguments are the tail of a
        # data.make_zip_dataset(...) call that begins above this excerpt.
        repeat=False,
        shuffle=False)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================
# Build every model, loss, schedule and optimizer on the GPU whose index is
# the second entry of `GPU`.
with tf.device('/device:GPU:%d' % GPU[1]):
    # CycleGAN generators (A->B and B->A) over 3-channel crop_size images;
    # base filter count comes from --kernels_num.
    G_A2B = module.ResnetGenerator(input_shape=(args.crop_size, args.crop_size,
                                                3),
                                   dim=args.kernels_num)
    G_B2A = module.ResnetGenerator(input_shape=(args.crop_size, args.crop_size,
                                                3),
                                   dim=args.kernels_num)

    # One conv discriminator per image domain.
    D_A = module.ConvDiscriminator(input_shape=(args.crop_size, args.crop_size,
                                                3),
                                   dim=args.kernels_num)
    D_B = module.ConvDiscriminator(input_shape=(args.crop_size, args.crop_size,
                                                3),
                                   dim=args.kernels_num)
    # Adversarial loss pair chosen by flag; cycle/identity terms are L1.
    d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(
        args.adversarial_loss_mode)
    cycle_loss_fn = tf.losses.MeanAbsoluteError()
    identity_loss_fn = tf.losses.MeanAbsoluteError()
    # Linearly decaying learning-rate schedules, one per optimizer.
    G_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset,
                                        args.epoch_decay * len_dataset)
    D_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset,
                                        args.epoch_decay * len_dataset)
    G_optimizer = keras.optimizers.Adam(learning_rate=G_lr_scheduler,
                                        beta_1=args.beta_1)
    # NOTE(review): this constructor call is truncated in the excerpt and
    # continues past the last visible line.
    D_optimizer = keras.optimizers.Adam(learning_rate=D_lr_scheduler,
Code example #6
0
# ==============================================================================
# =                                   model                                    =
# ==============================================================================

# Discriminator normalization: batch norm is only usable without a gradient
# penalty; penalized runs use the norm selected on the command line.
d_norm = ('batch_norm' if args.gradient_penalty_mode == 'none'
          else args.gradient_penalty_d_norm)

# Networks, moved onto the target device after construction.
G = module.ConvGenerator(args.z_dim, shape[-1],
                         n_upsamplings=n_G_upsamplings)
D = module.ConvDiscriminator(shape[-1],
                             n_downsamplings=n_D_downsamplings,
                             norm=d_norm)
G = G.to(device)
D = D.to(device)
print(G)
print(D)

# Adversarial loss pair chosen by flag.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(
    args.adversarial_loss_mode)

# Optimizers: one Adam per network, identical hyper-parameters.
adam_betas = (args.beta_1, 0.999)
G_optimizer = torch.optim.Adam(G.parameters(), lr=args.lr, betas=adam_betas)
D_optimizer = torch.optim.Adam(D.parameters(), lr=args.lr, betas=adam_betas)
Code example #7
0
# =                                   models                                   =
# ==============================================================================

img_shape = (args.crop_size, args.crop_size, 3)

# Two ResNet generators, one per translation direction; depth, width and
# normalization all come from CLI flags.
gen_kwargs = dict(input_shape=img_shape,
                  n_downsamplings=args.G_n_downsamplings,
                  n_blocks=args.G_n_residual_blocks,
                  dim=args.G_n_filters,
                  norm=args.norm_type)
G_A2B = module.ResnetGenerator(**gen_kwargs)
G_B2A = module.ResnetGenerator(**gen_kwargs)

# Two conv discriminators, one per image domain.
disc_kwargs = dict(input_shape=img_shape,
                   n_downsamplings=args.D_n_downsamplings,
                   dim=args.D_n_filters,
                   norm=args.norm_type)
D_A = module.ConvDiscriminator(**disc_kwargs)
D_B = module.ConvDiscriminator(**disc_kwargs)

# Adversarial loss pair chosen by flag; cycle and identity terms are L1.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(
    args.adversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

# Linear learning-rate decay schedules (same settings for G and D).
total_steps = args.epochs * len_dataset
decay_steps = args.epoch_decay * len_dataset
G_lr_scheduler = module.LinearDecay(args.lr, total_steps, decay_steps)
D_lr_scheduler = module.LinearDecay(args.lr, total_steps, decay_steps)
Code example #8
0
# Steps per epoch are driven by the size of the B image set.
len_dataset = B_length
#print('b list\n',B_list[0][0].squeeze().shape)

# History pool of generated fakes; only the A->B direction is pooled here.
A2B_pool = data.ItemPool(apool_size)
#B2A_pool = data.ItemPool(a.pool_size)

print('session starts \n')
print('B', B_length)

# ==============================================================================
# =                                   models                                   =
# ==============================================================================

# Single generator/discriminator pair over 6-channel inputs (presumably two
# stacked 3-channel images — TODO confirm against the data pipeline).
G = module.ResnetGenerator(input_shape=(acrop_size, acrop_size, 6))

D = module.ConvDiscriminator(input_shape=(acrop_size, acrop_size, 6))
#print('generator')
#print(G.summary())

#print('discriminator')
#print(D.summary())

# Adversarial loss pair selected by `aadversarial_loss_mode`; cycle and
# identity penalties are plain L1 distances.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(aadversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

# Linearly decaying learning-rate schedules, one per optimizer.
G_lr_scheduler = module.LinearDecay(alr, aepochs * len_dataset,
                                    aepoch_decay * len_dataset)
D_lr_scheduler = module.LinearDecay(alr, aepochs * len_dataset,
                                    aepoch_decay * len_dataset)
# NOTE(review): the Adam constructor call below is truncated in this excerpt
# and continues past the last visible line.
G_optimizer = keras.optimizers.Adam(learning_rate=G_lr_scheduler,