def construct_model(self, crop_size, load_size):
    """Create the datasets, image pools, networks, losses and optimizers.

    Everything is stored as attributes on ``self``:

    * ``A_B_dataset`` / ``len_dataset`` - the zipped A/B training data and
      the length value returned alongside it by ``data.make_zip_dataset``;
    * ``A2B_pool`` / ``B2A_pool`` - ``data.ItemPool`` objects of size
      ``self.pool_size``;
    * ``test_iter`` - an iterator over the zipped ``testA``/``testB`` data
      (built with ``training=False, repeat=True``);
    * ``G_A2B``, ``G_B2A``, ``D_A``, ``D_B`` - the four networks;
    * ``d_loss_fn``, ``g_loss_fn``, ``cycle_loss_fn``, ``identity_loss_fn``;
    * ``G_lr_scheduler`` / ``D_lr_scheduler`` and ``G_optimizer`` /
      ``D_optimizer``.

    Args:
        crop_size: crop size handed to the data pipeline; also the height
            and width of every network's input shape.
        load_size: load size handed to the data pipeline.
    """
    # Single-channel configurations are loaded as grayscale images.
    gray = self.color_depth == 1
    image_shape = (crop_size, crop_size, self.color_depth)

    # --- training data -----------------------------------------------------
    self.A_B_dataset, self.len_dataset = data.make_zip_dataset(
        self.A_img_paths,
        self.B_img_paths,
        self.batch_size,
        load_size,
        crop_size,
        training=True,
        repeat=False,
        is_gray_scale=gray,
    )
    self.A2B_pool = data.ItemPool(self.pool_size)
    self.B2A_pool = data.ItemPool(self.pool_size)

    # --- test data ---------------------------------------------------------
    def list_test_images(subdir):
        # Paths matching '*.<image_ext>' under <datasets_dir>/<dataset>/<subdir>.
        folder = py.join(self.datasets_dir, self.dataset, subdir)
        return py.glob(folder, '*.{}'.format(self.image_ext))

    test_dataset, _ = data.make_zip_dataset(
        list_test_images('testA'),
        list_test_images('testB'),
        self.batch_size,
        load_size,
        crop_size,
        training=False,
        repeat=True,
        is_gray_scale=gray,
    )
    self.test_iter = iter(test_dataset)

    # --- networks ----------------------------------------------------------
    self.G_A2B = module.ResnetGenerator(input_shape=image_shape,
                                        output_channels=self.color_depth)
    self.G_B2A = module.ResnetGenerator(input_shape=image_shape,
                                        output_channels=self.color_depth)
    self.D_A = module.ConvDiscriminator(input_shape=image_shape)
    self.D_B = module.ConvDiscriminator(input_shape=image_shape)

    # --- losses ------------------------------------------------------------
    self.d_loss_fn, self.g_loss_fn = gan.get_adversarial_losses_fn(
        self.adversarial_loss_mode)
    self.cycle_loss_fn = tf.losses.MeanAbsoluteError()
    self.identity_loss_fn = tf.losses.MeanAbsoluteError()

    # --- learning-rate schedules and optimizers ----------------------------
    # G and D each get their own LinearDecay object with identical settings.
    total_steps = self.epochs * self.len_dataset
    decay_steps = self.epoch_decay * self.len_dataset
    self.G_lr_scheduler = module.LinearDecay(self.lr, total_steps, decay_steps)
    self.D_lr_scheduler = module.LinearDecay(self.lr, total_steps, decay_steps)
    self.G_optimizer = keras.optimizers.Adam(
        learning_rate=self.G_lr_scheduler, beta_1=self.beta_1)
    self.D_optimizer = keras.optimizers.Adam(
        learning_rate=self.D_lr_scheduler, beta_1=self.beta_1)
# Set up the normalization layer type used by the discriminator.
if args.gradient_penalty_mode == 'none':
    d_norm = 'batch_norm'
elif args.gradient_penalty_mode in ('dragan', 'wgan-gp'):
    # Batch normalization cannot be used together with a gradient penalty.
    # TODO(Lynn): Layer normalization is more stable than instance
    # normalization here, but instance normalization works in other
    # implementations. Please tell me if you find out the cause.
    d_norm = 'layer_norm'
else:
    # Fail fast: any other mode would leave `d_norm` undefined and only
    # crash later with a NameError when the discriminator is built.
    raise ValueError(
        'Unsupported gradient_penalty_mode: %r' % args.gradient_penalty_mode)

# Networks (comment by K.C.): the following calls define the structure of
# the generator G and the discriminator D.
G = module.ConvGenerator(input_shape=(1, 1, args.z_dim),
                         output_channels=shape[-1],
                         n_upsamplings=n_G_upsamplings,
                         name='G_%s' % args.dataset)
D = module.ConvDiscriminator(input_shape=shape,
                             n_downsamplings=n_D_downsamplings,
                             norm=d_norm,
                             name='D_%s' % args.dataset)

# Save architecture diagrams under <output_dir>/summaries and print the
# layer summaries of both networks.
py.mkdir('%s/summaries' % output_dir)
keras.utils.plot_model(G, '%s/summaries/convGenerator.png' % output_dir,
                       show_shapes=True)
keras.utils.plot_model(D, '%s/summaries/convDiscriminator.png' % output_dir,
                       show_shapes=True)
G.summary()
D.summary()

# Adversarial loss functions for the discriminator and the generator.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)

G_optimizer = keras.optimizers.Adam(learning_rate=args.lr, beta_1=args.beta_1)
D_optimizer = keras.optimizers.Adam(learning_rate=args.lr, beta_1=args.beta_1)

# ==============================================================================
# Down-sampling depth: generators use the configured value, discriminators
# use one more level, capped at 4.
G_downsamplings = args.n_downsamplings
D_downsamplings = min(args.n_downsamplings + 1, 4)

# Each generator/discriminator pair shares one set of constructor arguments;
# all networks take crop_size x crop_size x 3 inputs.
_img_shape = (args.crop_size, args.crop_size, 3)
_G_kwargs = dict(
    input_shape=_img_shape,
    dim=args.dim,
    n_downsamplings=G_downsamplings,
    n_blocks=args.n_blocks,
    norm=args.norm,
)
_D_kwargs = dict(
    input_shape=_img_shape,
    dim=args.dim,
    n_downsamplings=D_downsamplings,
    norm=args.norm,
)

G_A2B = module.ResnetGenerator(**_G_kwargs)
G_B2A = module.ResnetGenerator(**_G_kwargs)
D_A = module.ConvDiscriminator(**_D_kwargs)
D_B = module.ConvDiscriminator(**_D_kwargs)

# Adversarial losses plus mean-absolute-error losses for the cycle and
# identity terms.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

# Separate LinearDecay objects for G and D, both parameterised by epoch
# counts multiplied by len_dataset.
_total_steps = args.epochs * len_dataset
_decay_steps = args.epoch_decay * len_dataset
G_lr_scheduler = module.LinearDecay(args.lr, _total_steps, _decay_steps)
D_lr_scheduler = module.LinearDecay(args.lr, _total_steps, _decay_steps)
# Zipped A/B test dataset (non-training pipeline, repeated).
A_B_dataset_test, _ = data.make_zip_dataset(
    A_img_paths_test,
    B_img_paths_test,
    args.batch_size,
    args.load_size,
    args.crop_size,
    training=False,
    repeat=True,
)


# ==============================================================================
# =                                   models                                   =
# ==============================================================================

# Every network takes crop_size x crop_size x 3 inputs.
_input_shape = (args.crop_size, args.crop_size, 3)

G_A2B = module.ResnetGenerator(input_shape=_input_shape)
G_B2A = module.ResnetGenerator(input_shape=_input_shape)
D_A = module.ConvDiscriminator(input_shape=_input_shape)
D_B = module.ConvDiscriminator(input_shape=_input_shape)

# Adversarial losses plus mean-absolute-error losses for the cycle and
# identity terms.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

# One LinearDecay object per optimizer, both parameterised by epoch counts
# multiplied by len_dataset.
_total_steps = args.epochs * len_dataset
_decay_steps = args.epoch_decay * len_dataset
G_lr_scheduler = module.LinearDecay(args.lr, _total_steps, _decay_steps)
D_lr_scheduler = module.LinearDecay(args.lr, _total_steps, _decay_steps)

G_optimizer = keras.optimizers.Adam(learning_rate=G_lr_scheduler,
                                    beta_1=args.beta_1)
D_optimizer = keras.optimizers.Adam(learning_rate=D_lr_scheduler,
                                    beta_1=args.beta_1)
repeat=False, shuffle=False) # ============================================================================== # = models = # ============================================================================== with tf.device('/device:GPU:%d' % GPU[1]): G_A2B = module.ResnetGenerator(input_shape=(args.crop_size, args.crop_size, 3), dim=args.kernels_num) G_B2A = module.ResnetGenerator(input_shape=(args.crop_size, args.crop_size, 3), dim=args.kernels_num) D_A = module.ConvDiscriminator(input_shape=(args.crop_size, args.crop_size, 3), dim=args.kernels_num) D_B = module.ConvDiscriminator(input_shape=(args.crop_size, args.crop_size, 3), dim=args.kernels_num) d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn( args.adversarial_loss_mode) cycle_loss_fn = tf.losses.MeanAbsoluteError() identity_loss_fn = tf.losses.MeanAbsoluteError() G_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset, args.epoch_decay * len_dataset) D_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset, args.epoch_decay * len_dataset) G_optimizer = keras.optimizers.Adam(learning_rate=G_lr_scheduler, beta_1=args.beta_1) D_optimizer = keras.optimizers.Adam(learning_rate=D_lr_scheduler,
# ==============================================================================
# =                                   model                                    =
# ==============================================================================

# Discriminator normalization: batch norm when no gradient penalty is used;
# otherwise the user-selected `gradient_penalty_d_norm`, because batch
# normalization cannot be combined with a gradient penalty.
d_norm = ('batch_norm' if args.gradient_penalty_mode == 'none'
          else args.gradient_penalty_d_norm)

# Networks, moved onto `device`, then printed.
G = module.ConvGenerator(args.z_dim,
                         shape[-1],
                         n_upsamplings=n_G_upsamplings).to(device)
D = module.ConvDiscriminator(shape[-1],
                             n_downsamplings=n_D_downsamplings,
                             norm=d_norm).to(device)
print(G)
print(D)

# Adversarial loss functions for the discriminator and the generator.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)

# Optimizers: Adam with configurable beta_1 and beta_2 fixed at 0.999.
_adam_betas = (args.beta_1, 0.999)
G_optimizer = torch.optim.Adam(G.parameters(), lr=args.lr, betas=_adam_betas)
D_optimizer = torch.optim.Adam(D.parameters(), lr=args.lr, betas=_adam_betas)
# ==============================================================================
# =                                   models                                   =
# ==============================================================================

# Generators and discriminators have independent filter/depth settings but
# share the normalization type; all take crop_size x crop_size x 3 inputs.
_crop_shape = (args.crop_size, args.crop_size, 3)
_generator_cfg = dict(
    input_shape=_crop_shape,
    n_downsamplings=args.G_n_downsamplings,
    n_blocks=args.G_n_residual_blocks,
    dim=args.G_n_filters,
    norm=args.norm_type,
)
_discriminator_cfg = dict(
    input_shape=_crop_shape,
    n_downsamplings=args.D_n_downsamplings,
    dim=args.D_n_filters,
    norm=args.norm_type,
)

G_A2B = module.ResnetGenerator(**_generator_cfg)
G_B2A = module.ResnetGenerator(**_generator_cfg)
D_A = module.ConvDiscriminator(**_discriminator_cfg)
D_B = module.ConvDiscriminator(**_discriminator_cfg)

# Adversarial losses plus mean-absolute-error losses for the cycle and
# identity terms.
d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(args.adversarial_loss_mode)
cycle_loss_fn = tf.losses.MeanAbsoluteError()
identity_loss_fn = tf.losses.MeanAbsoluteError()

# Separate LinearDecay objects for G and D with identical settings.
_total_steps = args.epochs * len_dataset
_decay_steps = args.epoch_decay * len_dataset
G_lr_scheduler = module.LinearDecay(args.lr, _total_steps, _decay_steps)
D_lr_scheduler = module.LinearDecay(args.lr, _total_steps, _decay_steps)
len_dataset = B_length #print('b list\n',B_list[0][0].squeeze().shape) A2B_pool = data.ItemPool(apool_size) #B2A_pool = data.ItemPool(a.pool_size) print('session starts \n') print('B', B_length) # ============================================================================== # = models = # ============================================================================== G = module.ResnetGenerator(input_shape=(acrop_size, acrop_size, 6)) D = module.ConvDiscriminator(input_shape=(acrop_size, acrop_size, 6)) #print('generator') #print(G.summary()) #print('discriminator') #print(D.summary()) d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn(aadversarial_loss_mode) cycle_loss_fn = tf.losses.MeanAbsoluteError() identity_loss_fn = tf.losses.MeanAbsoluteError() G_lr_scheduler = module.LinearDecay(alr, aepochs * len_dataset, aepoch_decay * len_dataset) D_lr_scheduler = module.LinearDecay(alr, aepochs * len_dataset, aepoch_decay * len_dataset) G_optimizer = keras.optimizers.Adam(learning_rate=G_lr_scheduler,