def train(self, args, ckpt_nmbr=None):
    """Train the GAN with batches produced by background worker processes.

    Starts five worker processes per dataset (content photos and art). Each
    one runs the dataset's ``initialize_batch_worker`` with a bounded
    multiprocessing queue, the shared augmentor, the batch size and its
    worker index. The method then initializes the graph, tries to restore a
    checkpoint and runs steps ``self.initial_step`` through
    ``self.options.total_steps`` inclusive.

    On each step either the generator or the discriminator is updated. The
    choice comes from an exponential moving average of the discriminator's
    success rate. While that average is at or above
    ``args.discr_success_rate`` the generator is trained; otherwise the
    discriminator is trained.

    Worker processes are terminated and joined when training finishes, and
    also if anything between starting them and the end of the loop raises.

    Args:
        args: Parsed options; only ``args.discr_success_rate`` is read.
        ckpt_nmbr: Optional checkpoint number forwarded to ``self.load`` when
            restoring from ``self.checkpoint_dir`` or
            ``self.checkpoint_long_dir``.
    """
    # Initialize augmentor.
    augmentor = img_augm.Augmentor(
        crop_size=[self.options.image_size, self.options.image_size],
        vertical_flip_prb=0.,
        hsv_augm_prb=1.0,
        hue_augm_shift=0.05,
        saturation_augm_shift=0.05,
        saturation_augm_scale=0.05,
        value_augm_shift=0.05,
        value_augm_scale=0.05,
    )
    content_dataset_places = prepare_dataset.PlacesDataset(
        path_to_dataset=self.options.path_to_content_dataset)
    art_dataset = prepare_dataset.ArtDataset(
        path_to_art_dataset=self.options.path_to_art_dataset)

    # Bounded queues that the worker processes fill with ready batches.
    q_art = multiprocessing.Queue(maxsize=10)
    q_content = multiprocessing.Queue(maxsize=10)
    jobs = []
    try:
        # Initialize queue workers for both datasets.
        for i in range(5):
            p = multiprocessing.Process(
                target=content_dataset_places.initialize_batch_worker,
                args=(q_content, augmentor, self.batch_size, i))
            p.start()
            jobs.append(p)
            p = multiprocessing.Process(
                target=art_dataset.initialize_batch_worker,
                args=(q_art, augmentor, self.batch_size, i))
            p.start()
            jobs.append(p)
        print("Processes are started.")
        time.sleep(3)

        # Now initialize the graph
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        print("Start training.")

        # Try the regular checkpoint directory first, then the long-term one.
        if (self.load(self.checkpoint_dir, ckpt_nmbr)
                or self.load(self.checkpoint_long_dir, ckpt_nmbr)):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # Initial discriminator success rate.
        win_rate = args.discr_success_rate
        discr_success = args.discr_success_rate
        alpha = 0.05  # smoothing factor of the success-rate moving average

        # range() yields total_steps + 1 - initial_step items and tqdm starts
        # counting at initial_step, so the bar's total is total_steps + 1.
        for step in tqdm(range(self.initial_step, self.options.total_steps + 1),
                         initial=self.initial_step,
                         total=self.options.total_steps + 1):
            # Blocking get() waits for a worker batch without spinning;
            # Queue.empty() is not reliable across processes anyway.
            batch_art = q_art.get()
            batch_content = q_content.get()

            # Normalize once per step; the same feed serves the optimizer
            # step and the periodic sample dump below.
            feed_dict = {
                self.input_painting: normalize_arr_of_imgs(batch_art['image']),
                self.input_photo: normalize_arr_of_imgs(batch_content['image']),
                self.lr: self.options.lr,
            }

            if discr_success >= win_rate:
                # Train generator
                _, summary_all, gener_acc_ = self.sess.run(
                    [self.g_optim_step, self.summary_merged_all, self.gener_acc],
                    feed_dict=feed_dict)
                # 1 - gener_acc_ is taken as the discriminator's success on
                # this step (presumably gener_acc is the generator's fooling
                # rate; confirm against the graph definition).
                discr_success = (discr_success * (1. - alpha)
                                 + alpha * (1. - gener_acc_))
            else:
                # Train discriminator.
                _, summary_all, discr_acc_ = self.sess.run(
                    [self.d_optim_step, self.summary_merged_all, self.discr_acc],
                    feed_dict=feed_dict)
                discr_success = (discr_success * (1. - alpha)
                                 + alpha * discr_acc_)

            self.writer.add_summary(summary_all, step * self.batch_size)

            if step % self.options.save_freq == 0 and step > self.initial_step:
                self.save(step)

            # And additionally save all checkpoints each 15000 steps.
            if step % 15000 == 0 and step > self.initial_step:
                self.save(step, is_long=True)

            if step % 500 == 0:
                # The painting panel is the (normalized) input painting
                # fetched back from the graph; the photo panel is the output.
                output_paintings_, output_photos_ = self.sess.run(
                    [self.input_painting, self.output_photo],
                    feed_dict=feed_dict)
                save_batch(
                    input_painting_batch=batch_art['image'],
                    input_photo_batch=batch_content['image'],
                    output_painting_batch=denormalize_arr_of_imgs(output_paintings_),
                    output_photo_batch=denormalize_arr_of_imgs(output_photos_),
                    filepath='%s/step_%d.jpg' % (self.sample_dir, step))

        print("Training is finished. Terminate jobs.")
    finally:
        # Terminate before joining. The workers feed bounded queues that are
        # no longer drained, so joining first can block forever. A producer
        # stuck on a full queue, or one holding unflushed queued items,
        # never exits on its own.
        for p in jobs:
            p.terminate()
        for p in jobs:
            p.join()
    print("Done.")
config['feature_loss_weight'], config['discr_success_rate'] )) trainer = ArtGAN(opts).cuda() initial_step = trainer.resume(checkpoint_directory, opts) if options.resume else 0 # prepare data augmentor = img_augm.Augmentor(crop_size=[opts.image_size, opts.image_size], vertical_flip_prb=0., hsv_augm_prb=1.0, hue_augm_shift=0.05, saturation_augm_shift=0.05, saturation_augm_scale=0.05, value_augm_shift=0.05, value_augm_scale=0.05, ) content_dataset_places = prepare_dataset.PlacesDataset(opts.content_data_path) art_dataset = prepare_dataset.ArtDataset(opts.art_data_path) q_art = multiprocessing.Queue(maxsize=10) q_content = multiprocessing.Queue(maxsize=10) jobs = [] for i in range(4): p = multiprocessing.Process(target=content_dataset_places.initialize_batch_worker, args=(q_content, augmentor, opts.batch_size, i)) p.start() jobs.append(p) p = multiprocessing.Process(target=art_dataset.initialize_batch_worker, args=(q_art, augmentor, opts.batch_size, i)) p.start() jobs.append(p) print("Processes are started.") time.sleep(3)
def train(self, args):
    """Train the GAN, drawing batches synchronously inside the training loop.

    Builds the augmentor and both datasets, then initializes the graph.
    Next it tries to restore a checkpoint, first from
    ``self.checkpoint_dir`` and then from ``self.checkpoint_long_dir``.
    Finally it runs steps ``self.initial_step`` through
    ``self.options.total_steps`` inclusive. Each step fetches one art batch
    and one content batch via ``get_batch(augmentor=..., batch_size=...)``.

    On each step either the generator or the discriminator is updated. The
    choice comes from an exponential moving average of the discriminator's
    success rate. While that average is at or above
    ``args.discr_success_rate`` the generator is trained; otherwise the
    discriminator is trained.

    Args:
        args: Parsed options; only ``args.discr_success_rate`` is read.
    """
    # Initialize augmentor.
    augmentor = img_augm.Augmentor(
        crop_size=[self.options.image_size, self.options.image_size],
        hue_augm_shift=0.05,
        saturation_augm_shift=0.05,
        saturation_augm_scale=0.05,
        value_augm_shift=0.05,
        value_augm_scale=0.05,
    )
    content_dataset_places = prepare_dataset.PlacesDataset(
        path_to_dataset=self.options.path_to_content_dataset)
    art_dataset = prepare_dataset.ArtDataset(
        path_to_art_dataset=self.options.path_to_art_dataset)

    # Now initialize the graph
    init_op = tf.compat.v1.global_variables_initializer()
    self.sess.run(init_op)
    print("Start training.")

    # Try the regular checkpoint directory first, then the long-term one.
    if self.load(self.checkpoint_dir) or self.load(self.checkpoint_long_dir):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    # Initial discriminator success rate.
    win_rate = args.discr_success_rate
    discr_success = args.discr_success_rate
    alpha = 0.05  # smoothing factor of the success-rate moving average

    # range() yields total_steps + 1 - initial_step items and tqdm starts
    # counting at initial_step, so the bar's total is total_steps + 1.
    for step in tqdm(range(self.initial_step, self.options.total_steps + 1),
                     initial=self.initial_step,
                     total=self.options.total_steps + 1):
        batch_art = art_dataset.get_batch(
            augmentor=augmentor, batch_size=self.batch_size)
        batch_content = content_dataset_places.get_batch(
            augmentor=augmentor, batch_size=self.batch_size)

        # Normalize once per step; the same feed serves the optimizer step
        # and the periodic sample dump below.
        feed_dict = {
            self.input_painting: normalize_arr_of_imgs(batch_art['image']),
            self.input_photo: normalize_arr_of_imgs(batch_content['image']),
            self.lr: self.options.lr,
        }

        if discr_success >= win_rate:
            # Train generator
            _, summary_all, gener_acc_ = self.sess.run(
                [self.g_optim_step, self.summary_merged_all, self.gener_acc],
                feed_dict=feed_dict)
            # 1 - gener_acc_ is taken as the discriminator's success on this
            # step (presumably gener_acc is the generator's fooling rate;
            # confirm against the graph definition).
            discr_success = discr_success * (1. - alpha) + alpha * (
                1. - gener_acc_)
        else:
            # Train discriminator.
            _, summary_all, discr_acc_ = self.sess.run(
                [self.d_optim_step, self.summary_merged_all, self.discr_acc],
                feed_dict=feed_dict)
            discr_success = discr_success * (1. - alpha) + alpha * discr_acc_

        self.writer.add_summary(summary_all, step * self.batch_size)

        if step % self.options.save_freq == 0 and step > self.initial_step:
            self.save(step)

        # And additionally save all checkpoints each 15000 steps.
        if step % 15000 == 0 and step > self.initial_step:
            self.save(step, is_long=True)

        if step % 500 == 0:
            # The painting panel is the (normalized) input painting fetched
            # back from the graph; the photo panel is the generator output.
            output_paintings_, output_photos_ = self.sess.run(
                [self.input_painting, self.output_photo],
                feed_dict=feed_dict)
            save_batch(
                input_painting_batch=batch_art['image'],
                input_photo_batch=batch_content['image'],
                output_painting_batch=denormalize_arr_of_imgs(
                    output_paintings_),
                output_photo_batch=denormalize_arr_of_imgs(output_photos_),
                filepath='%s/step_%d.jpg' % (self.sample_dir, step))
    print("Done.")