Пример #1
0
    def train(self, args, ckpt_nmbr=None):
        """Train generator and discriminator with adaptive step switching.

        Batches are produced by background worker processes (5 per dataset)
        that push into bounded multiprocessing queues. On every step either
        the generator or the discriminator is optimized, chosen by comparing
        an exponential moving average of the discriminator's success
        (``discr_success``) against ``args.discr_success_rate``.

        Args:
            args: Parsed options; only ``args.discr_success_rate`` is read.
            ckpt_nmbr: Optional checkpoint number forwarded to ``self.load``.
        """
        # Initialize augmentor.
        augmentor = img_augm.Augmentor(crop_size=[self.options.image_size, self.options.image_size],
                                       vertical_flip_prb=0.,
                                       hsv_augm_prb=1.0,
                                       hue_augm_shift=0.05,
                                       saturation_augm_shift=0.05, saturation_augm_scale=0.05,
                                       value_augm_shift=0.05, value_augm_scale=0.05, )
        content_dataset_places = prepare_dataset.PlacesDataset(path_to_dataset=self.options.path_to_content_dataset)
        art_dataset = prepare_dataset.ArtDataset(path_to_art_dataset=self.options.path_to_art_dataset)

        # Initialize queue workers for both datasets.
        q_art = multiprocessing.Queue(maxsize=10)
        q_content = multiprocessing.Queue(maxsize=10)
        jobs = []
        try:
            for i in range(5):
                p = multiprocessing.Process(target=content_dataset_places.initialize_batch_worker,
                                            args=(q_content, augmentor, self.batch_size, i))
                p.start()
                jobs.append(p)

                p = multiprocessing.Process(target=art_dataset.initialize_batch_worker,
                                            args=(q_art, augmentor, self.batch_size, i))
                p.start()
                jobs.append(p)
            print("Processes are started.")
            # Give the workers a head start on filling the queues.
            time.sleep(3)

            # Now initialize the graph
            init_op = tf.global_variables_initializer()
            self.sess.run(init_op)
            print("Start training.")

            # Try the regular checkpoint dir first, then the long-term one.
            if self.load(self.checkpoint_dir, ckpt_nmbr):
                print(" [*] Load SUCCESS")
            elif self.load(self.checkpoint_long_dir, ckpt_nmbr):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

            # Initial discriminator success rate.
            win_rate = args.discr_success_rate
            discr_success = args.discr_success_rate
            # Smoothing factor of the moving average of discriminator success.
            alpha = 0.05

            # range() yields total_steps + 1 - initial_step items; with
            # initial=initial_step the bar therefore ends at total_steps + 1.
            for step in tqdm(range(self.initial_step, self.options.total_steps + 1),
                             initial=self.initial_step,
                             total=self.options.total_steps + 1):
                # Queue.get() blocks until a batch is available, so no
                # busy-wait is needed (Queue.empty() is not reliable anyway).
                batch_art = q_art.get()
                batch_content = q_content.get()

                # Normalize each batch once and reuse the feed for every run.
                feed_dict = {
                    self.input_painting: normalize_arr_of_imgs(batch_art['image']),
                    self.input_photo: normalize_arr_of_imgs(batch_content['image']),
                    self.lr: self.options.lr,
                }

                if discr_success >= win_rate:
                    # Train generator
                    _, summary_all, gener_acc_ = self.sess.run(
                        [self.g_optim_step, self.summary_merged_all, self.gener_acc],
                        feed_dict=feed_dict)
                    # A successful generator step counts as a discriminator failure.
                    discr_success = discr_success * (1. - alpha) + alpha * (1. - gener_acc_)
                else:
                    # Train discriminator.
                    _, summary_all, discr_acc_ = self.sess.run(
                        [self.d_optim_step, self.summary_merged_all, self.discr_acc],
                        feed_dict=feed_dict)
                    discr_success = discr_success * (1. - alpha) + alpha * discr_acc_
                self.writer.add_summary(summary_all, step * self.batch_size)

                if step % self.options.save_freq == 0 and step > self.initial_step:
                    self.save(step)

                # And additionally save all checkpoints each 15000 steps.
                if step % 15000 == 0 and step > self.initial_step:
                    self.save(step, is_long=True)

                if step % 500 == 0:
                    # The "painting" output is the fed input itself; only the
                    # photo goes through the network here.
                    output_paintings_, output_photos_ = self.sess.run(
                        [self.input_painting, self.output_photo],
                        feed_dict=feed_dict)

                    save_batch(input_painting_batch=batch_art['image'],
                               input_photo_batch=batch_content['image'],
                               output_painting_batch=denormalize_arr_of_imgs(output_paintings_),
                               output_photo_batch=denormalize_arr_of_imgs(output_photos_),
                               filepath='%s/step_%d.jpg' % (self.sample_dir, step))
            print("Training is finished. Terminate jobs.")
        finally:
            # NOTE(review): initialize_batch_worker presumably loops until
            # killed (not visible here). A worker blocked on put() into a full
            # bounded queue cannot exit on its own, so join() before
            # terminate() could hang. Terminate first, then join to reap; the
            # finally also reaps workers if training raises.
            for p in jobs:
                p.terminate()
                p.join()

        print("Done.")
                            config['feature_loss_weight'],
                            config['discr_success_rate']
                            ))
    trainer = ArtGAN(opts).cuda()
    initial_step = trainer.resume(checkpoint_directory, opts) if options.resume else 0
    # prepare data
    augmentor = img_augm.Augmentor(crop_size=[opts.image_size, opts.image_size],
                                    vertical_flip_prb=0.,
                                    hsv_augm_prb=1.0,
                                    hue_augm_shift=0.05,
                                    saturation_augm_shift=0.05, 
                                    saturation_augm_scale=0.05,
                                    value_augm_shift=0.05, 
                                    value_augm_scale=0.05, )
    content_dataset_places = prepare_dataset.PlacesDataset(opts.content_data_path)
    art_dataset = prepare_dataset.ArtDataset(opts.art_data_path)
    q_art = multiprocessing.Queue(maxsize=10)
    q_content = multiprocessing.Queue(maxsize=10)
    jobs = []
    for i in range(4):
        p = multiprocessing.Process(target=content_dataset_places.initialize_batch_worker,
                                    args=(q_content, augmentor, opts.batch_size, i))
        p.start()
        jobs.append(p)

        p = multiprocessing.Process(target=art_dataset.initialize_batch_worker,
                                    args=(q_art, augmentor, opts.batch_size, i))
        p.start()
        jobs.append(p)
    print("Processes are started.")
    time.sleep(3)
Пример #3
0
    def train(self, args):
        """Train generator and discriminator with adaptive step switching.

        Synchronous variant: on every step one art batch and one content
        batch are drawn directly via ``get_batch``. Either the generator or
        the discriminator is then optimized, chosen by comparing an
        exponential moving average of the discriminator's success
        (``discr_success``) against ``args.discr_success_rate``.

        Args:
            args: Parsed options; only ``args.discr_success_rate`` is read.
        """
        # Initialize augmentor.
        augmentor = img_augm.Augmentor(
            crop_size=[self.options.image_size, self.options.image_size],
            hue_augm_shift=0.05,
            saturation_augm_shift=0.05,
            saturation_augm_scale=0.05,
            value_augm_shift=0.05,
            value_augm_scale=0.05,
        )

        content_dataset_places = prepare_dataset.PlacesDataset(
            path_to_dataset=self.options.path_to_content_dataset)
        art_dataset = prepare_dataset.ArtDataset(
            path_to_art_dataset=self.options.path_to_art_dataset)

        # Now initialize the graph
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        print("Start training.")
        # Try the regular checkpoint dir first, then the long-term one.
        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        elif self.load(self.checkpoint_long_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # Initial discriminator success rate.
        win_rate = args.discr_success_rate
        discr_success = args.discr_success_rate
        # Smoothing factor of the moving average of discriminator success.
        alpha = 0.05

        # range() yields total_steps + 1 - initial_step items; with
        # initial=initial_step the bar therefore ends at total_steps + 1.
        for step in tqdm(range(self.initial_step,
                               self.options.total_steps + 1),
                         initial=self.initial_step,
                         total=self.options.total_steps + 1):
            batch_art = art_dataset.get_batch(augmentor=augmentor,
                                              batch_size=self.batch_size)
            batch_content = content_dataset_places.get_batch(
                augmentor=augmentor, batch_size=self.batch_size)

            # Normalize each batch once and reuse the feed for every run.
            feed_dict = {
                self.input_painting:
                normalize_arr_of_imgs(batch_art['image']),
                self.input_photo:
                normalize_arr_of_imgs(batch_content['image']),
                self.lr:
                self.options.lr,
            }

            if discr_success >= win_rate:
                # Train generator
                _, summary_all, gener_acc_ = self.sess.run(
                    [
                        self.g_optim_step, self.summary_merged_all,
                        self.gener_acc
                    ],
                    feed_dict=feed_dict)
                # A successful generator step counts as a discriminator failure.
                discr_success = discr_success * (1. - alpha) + alpha * (
                    1. - gener_acc_)
            else:
                # Train discriminator.
                _, summary_all, discr_acc_ = self.sess.run(
                    [
                        self.d_optim_step, self.summary_merged_all,
                        self.discr_acc
                    ],
                    feed_dict=feed_dict)
                discr_success = (discr_success * (1. - alpha)
                                 + alpha * discr_acc_)
            self.writer.add_summary(summary_all, step * self.batch_size)

            if step % self.options.save_freq == 0 and step > self.initial_step:
                self.save(step)

            # And additionally save all checkpoints each 15000 steps.
            if step % 15000 == 0 and step > self.initial_step:
                self.save(step, is_long=True)

            if step % 500 == 0:
                # The "painting" output is the fed input itself; only the
                # photo goes through the network here.
                output_paintings_, output_photos_ = self.sess.run(
                    [self.input_painting, self.output_photo],
                    feed_dict=feed_dict)

                save_batch(
                    input_painting_batch=batch_art['image'],
                    input_photo_batch=batch_content['image'],
                    output_painting_batch=denormalize_arr_of_imgs(
                        output_paintings_),
                    output_photo_batch=denormalize_arr_of_imgs(output_photos_),
                    filepath='%s/step_%d.jpg' % (self.sample_dir, step))

        print("Done.")