def construct_model(self, crop_size, load_size):
    """Create the datasets, networks, losses and optimizers and store them on ``self``.

    Args:
        crop_size: Forwarded to ``data.make_zip_dataset`` and used for both
            spatial dimensions of every network's ``input_shape``.
        load_size: Forwarded to ``data.make_zip_dataset``.

    Attributes set on ``self``:
        A_B_dataset, len_dataset: The pair returned by
            ``data.make_zip_dataset`` for ``self.A_img_paths`` and
            ``self.B_img_paths`` (``training=True, repeat=False``).
        A2B_pool, B2A_pool: ``data.ItemPool`` objects built from
            ``self.pool_size``.
        test_iter: Iterator over the dataset built from the ``testA`` and
            ``testB`` folders (``training=False, repeat=True``).
        G_A2B, G_B2A: ``module.ResnetGenerator`` instances.
        D_A, D_B: ``module.ConvDiscriminator`` instances.
        d_loss_fn, g_loss_fn: Pair returned by
            ``gan.get_adversarial_losses_fn(self.adversarial_loss_mode)``.
        cycle_loss_fn, identity_loss_fn: ``tf.losses.MeanAbsoluteError``
            objects.
        G_lr_scheduler, D_lr_scheduler: ``module.LinearDecay`` objects.
        G_optimizer, D_optimizer: Adam optimizers using those schedulers.
    """
    channels = self.color_depth
    grayscale = channels == 1
    image_shape = (crop_size, crop_size, channels)

    def zip_dataset(a_paths, b_paths, *, training, repeat):
        # Train and test datasets share batch size, sizes and colour mode.
        return data.make_zip_dataset(
            a_paths, b_paths, self.batch_size, load_size, crop_size,
            training=training, repeat=repeat, is_gray_scale=grayscale)

    def test_image_paths(subdir):
        folder = py.join(self.datasets_dir, self.dataset, subdir)
        return py.glob(folder, '*.{}'.format(self.image_ext))

    def resnet_generator():
        return module.ResnetGenerator(
            input_shape=image_shape, output_channels=channels)

    def conv_discriminator():
        return module.ConvDiscriminator(input_shape=image_shape)

    def linear_decay():
        # Reads self.len_dataset, so it must run after the training
        # dataset has been created below.
        return module.LinearDecay(
            self.lr,
            self.epochs * self.len_dataset,
            self.epoch_decay * self.len_dataset)

    def adam(schedule):
        return keras.optimizers.Adam(
            learning_rate=schedule, beta_1=self.beta_1)

    # --- training data and item pools ---
    self.A_B_dataset, self.len_dataset = zip_dataset(
        self.A_img_paths, self.B_img_paths, training=True, repeat=False)
    self.A2B_pool = data.ItemPool(self.pool_size)
    self.B2A_pool = data.ItemPool(self.pool_size)

    # --- test data: only an iterator over the dataset is kept ---
    test_a_paths = test_image_paths('testA')
    test_b_paths = test_image_paths('testB')
    test_dataset, _ = zip_dataset(
        test_a_paths, test_b_paths, training=False, repeat=True)
    self.test_iter = iter(test_dataset)

    # --- networks (construction order kept: generators, then discriminators) ---
    self.G_A2B = resnet_generator()
    self.G_B2A = resnet_generator()
    self.D_A = conv_discriminator()
    self.D_B = conv_discriminator()

    # --- losses ---
    self.d_loss_fn, self.g_loss_fn = gan.get_adversarial_losses_fn(
        self.adversarial_loss_mode)
    self.cycle_loss_fn = tf.losses.MeanAbsoluteError()
    self.identity_loss_fn = tf.losses.MeanAbsoluteError()

    # --- learning-rate schedules and optimizers ---
    self.G_lr_scheduler = linear_decay()
    self.D_lr_scheduler = linear_decay()
    self.G_optimizer = adam(self.G_lr_scheduler)
    self.D_optimizer = adam(self.D_lr_scheduler)
# Script-level setup: builds the discriminators D_A and D_B (input shape
# (args.crop_size, args.crop_size, 3)), obtains the adversarial loss pair from
# gan.get_adversarial_losses_fn, creates MeanAbsoluteError objects for the
# cycle and identity losses, and creates one module.LinearDecay scheduler and
# one Adam optimizer each for G and D.
# NOTE(review): this line looks whitespace-collapsed. As written, everything
# after the first '#' banner (including the `@tf.function` / `train_G` header)
# is a comment, and the body of train_G is not visible here. Confirm against
# the original multi-line source before relying on it.
D_A = module.ConvDiscriminator(input_shape=(args.crop_size, args.crop_size, 3), dim=args.dim, n_downsamplings=D_downsamplings, norm=args.norm) D_B = module.ConvDiscriminator(input_shape=(args.crop_size, args.crop_size, 3), dim=args.dim, n_downsamplings=D_downsamplings, norm=args.norm) d_loss_fn, g_loss_fn = gan.get_adversarial_losses_fn( args.adversarial_loss_mode) cycle_loss_fn = tf.losses.MeanAbsoluteError() identity_loss_fn = tf.losses.MeanAbsoluteError() G_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset, args.epoch_decay * len_dataset) D_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset, args.epoch_decay * len_dataset) G_optimizer = keras.optimizers.Adam(learning_rate=G_lr_scheduler, beta_1=args.beta_1) D_optimizer = keras.optimizers.Adam(learning_rate=D_lr_scheduler, beta_1=args.beta_1) # ============================================================================== # = train step = # ============================================================================== @tf.function def train_G(A, B): with tf.GradientTape() as t: