コード例 #1
0
def main(args):
    """Build the models, run the training loop and save the results.

    ``args`` is the parsed command-line namespace. The function constructs a
    generator, a discriminator and a copy of the generator that is handed to
    the updater as the averaged generator, wires them into a
    ``CustomUpdater``/``CustomTrainer`` pair, runs the trainer, and finally
    saves the averaged generator and the updater states under ``args.dest``.
    """
    # Framework-wide flags used for the whole run.
    global_config.train = True
    global_config.autotune = True
    global_config.cudnn_deterministic = False

    print("Initializing models...")
    categories = len(args.dataset)
    # More than one dataset entry switches on the multi-category code paths.
    conditional = categories > 1

    generator = Generator(
        args.size, args.depth, args.levels, *args.channels, categories)
    discriminator = Discriminator(
        args.levels,
        args.channels[1],
        args.channels[0],
        categories,
        args.depth,
        args.group,
    )
    averaged_generator = generator.copy("copy")
    for model in (generator, discriminator, averaged_generator):
        model.to_device(args.device)

    beta1 = args.betas[0]
    beta2 = args.betas[1]
    optimizers = AdamSet(args.alpha, beta1, beta2)
    optimizers.setup(generator, discriminator)

    # Create the output directory and record the arguments of this run.
    mkdirs(args.dest)
    arguments_path = build_filepath(args.dest, "arguments", "json", args.force)
    dump_json(args, arguments_path)

    dataset_class = MulticategoryImageDataset if conditional else ImageDataset
    dataset = dataset_class(args.dataset, generator.resolution)
    if args.preload:
        with chainer_like_tqdm("dataset", len(dataset)) as bar:
            # The progress bar is advanced through the preload callback.
            dataset.preload(lambda: bar.update())

    iterator = SerialIterator(dataset, args.batch, repeat=True, shuffle=True)
    updater = CustomUpdater(
        generator,
        averaged_generator,
        discriminator,
        iterator,
        optimizers,
        conditional,
        args.ema,
        args.lsgan,
    )
    updater.enable_gradient_accumulation(4)
    updater.enable_style_mixing(args.mix)
    updater.enable_r1_regularization(args.gamma, args.r1)
    updater.enable_path_length_regularization(args.decay, args.weight, args.pl)

    pipeline = AugmentationPipeline()
    pipeline.to_device(args.device)
    # The method name is spelled exactly as the updater exposes it.
    updater.enable_adaptive_augumentation(pipeline)

    # Optionally restore updater states from an earlier snapshot.
    if args.snapshot is not None:
        updater.load_states(args.snapshot)

    trainer = CustomTrainer(updater, args.epoch, args.dest, args.force)
    trainer.hook_state_save(1000)
    trainer.hook_image_generation(1000, 32)
    trainer.enable_reports(500)
    trainer.enable_progress_bar(1)
    trainer.run()

    # Persist the final averaged generator, then the updater states.
    generator_path = build_filepath(args.dest, "generator", "hdf5", args.force)
    averaged_generator.save(generator_path)
    snapshot_path = build_filepath(args.dest, "snapshot", "hdf5", args.force)
    updater.save_states(snapshot_path)
コード例 #2
0
 def save_images(trainer):
     """Generate sample images with the averaged generator and save them.

     For every batch yielded by ``range_batch(trainer.number,
     trainer.batch_size)`` the averaged generator produces images from newly
     generated latents (and conditions when ``trainer.conditional`` is
     truthy). Each latent is stored as ``.npy`` and each image as ``.png``,
     both named ``"<iteration>-<index>"`` with a 1-based index.
     """
     generator = trainer.updater.averaged_generator
     for offset, count in range_batch(trainer.number, trainer.batch_size):
         latents = generator.generate_latents(count)
         if trainer.conditional:
             conditions = generator.generate_conditions(count)
         else:
             conditions = None
         # Only the second value returned by the generator is used here.
         _, images = generator(latents, conditions)
         latents.to_cpu()
         images.to_cpu()
         for position in range(count):
             stem = f"{trainer.iteration}-{offset + position + 1}"
             latent_path = build_filepath(trainer.images_out, stem, "npy",
                                          trainer.overwrite)
             np.save(latent_path, latents.array[position])
             image_path = build_filepath(trainer.images_out, stem, "png",
                                         trainer.overwrite)
             save_image(images.array[position], image_path)
コード例 #3
0
 def enable_reports(self, interval):
     """Register the log, console and plot report extensions.

     All three extensions are attached with ``self.extend``. The log and
     plot reports fire every ``interval`` iterations; their file names are
     the base names of the paths that ``build_filepath`` returns for
     ``self.out``.
     """
     trigger = (interval, "iteration")
     metrics = [
         "loss (G)", "loss (D)", "penalty (G)", "penalty (D)", "overfitting"
     ]

     log_path = build_filepath(self.out, "report", "log", self.overwrite)
     log_report = LogReport(trigger=trigger,
                            filename=os.path.basename(log_path))

     # The console report shows the iteration count ahead of the metrics.
     print_report = PrintReport(["iteration"] + metrics, log_report)

     curves_path = build_filepath(self.out, "curves", "png", self.overwrite)
     plot_report = PlotReport(metrics,
                              "iteration",
                              trigger=trigger,
                              filename=os.path.basename(curves_path))

     for extension in (log_report, print_report, plot_report):
         self.extend(extension)
コード例 #4
0
          total=args.number,
          bar_format=bf,
          miniters=1,
          ascii=".#",
          ncols=70) as bar:
    for i, n in range_batch(args.number, args.batch):
        #mixing = mix > random()
        z = generator.generate_latents(n)
        c = generator.generate_conditions(n)
        #mix_z = generator.generate_latent(n) if mixing else None
        ws, y = generator(z, c)
        z.to_cpu()
        y.to_cpu()
        for j in range(n):
            filename = f"{i + j + 1}"
            np.save(build_filepath(args.dest, filename, "npy", args.force),
                    z.array[j])
            save_image(y.array[j],
                       build_filepath(args.dest, filename, "png", args.force))
            bar.update()
'''
# Generate images
c = 0
mean_w = None if args.psi is None else generator.calculate_mean_w()
while c < args.number:
	n = min(args.number - c, args.batch)
	z = generator.generate_latent(n, center=center, sd=args.sd)
	y = generator(z, args.stage, alpha=args.alpha, psi=args.psi, mean_w=mean_w)
	z.to_cpu()
	y.to_cpu()
	for i in range(n):
コード例 #5
0
# Render the generator and discriminator computational graphs to PDF files.
# `generator`, `discriminator`, `z` and `args` are defined outside this excerpt.
#mix = gen.generate_latent(args.batch)
# The generator returns a pair; only the second value (`i`) is used below.
ws, i = generator(z)

# Feed the generator output to the discriminator.
y = discriminator(i)

# Node styles: separate fill colors for variables and functions, and for the
# generator graph versus the discriminator graph.
gen_varstyle = {"fillcolor": "#5edbf1", "shape": "record", "style": "filled"}
gen_funstyle = {"fillcolor": "#ffa9e0", "shape": "record", "style": "filled"}
dis_varstyle = {"fillcolor": "#7a9fe6", "shape": "record", "style": "filled"}
dis_funstyle = {"fillcolor": "#fea21d", "shape": "record", "style": "filled"}

# Dump the graph that ends at the generator output `i`.
gen_graph = build_computational_graph([i],
                                      variable_style=gen_varstyle,
                                      function_style=gen_funstyle).dump()

# NOTE(review): presumably this detaches `i` from the generator's history so
# the graph built from `y` below covers only the discriminator -- confirm.
i.unchain_backward()

# Dump the graph that ends at the discriminator output `y`.
dis_graph = build_computational_graph([y],
                                      variable_style=dis_varstyle,
                                      function_style=dis_funstyle).dump()

#print(f"D: {self.count_params()}")
#print(self.count_params())

# Create the destination directory, then write each dumped graph as a PDF and
# report the path that was written.
mkdirs(args.dest)
gen_path = build_filepath(args.dest, "generator", "pdf", args.force)
graph_from_dot_data(gen_graph)[0].write_pdf(gen_path)
print(f"Saved: {gen_path}")
dis_path = build_filepath(args.dest, "discriminator", "pdf", args.force)
graph_from_dot_data(dis_graph)[0].write_pdf(dis_path)
print(f"Saved: {dis_path}")
コード例 #6
0
 def save_generator(trainer):
     """Save the averaged generator under an iteration-stamped name.

     The destination comes from ``build_filepath`` using
     ``trainer.states_out``, the stem ``"generator-<iteration>"`` and the
     ``"hdf5"`` extension.
     """
     stem = f"generator-{trainer.iteration}"
     destination = build_filepath(trainer.states_out, stem, "hdf5",
                                  trainer.overwrite)
     trainer.updater.averaged_generator.save(destination)
コード例 #7
0
 def save_snapshot(trainer):
     """Save the updater states under an iteration-stamped name.

     The destination comes from ``build_filepath`` using
     ``trainer.states_out``, the stem ``"snapshot-<iteration>"`` and the
     ``"hdf5"`` extension.
     """
     stem = f"snapshot-{trainer.iteration}"
     destination = build_filepath(trainer.states_out, stem, "hdf5",
                                  trainer.overwrite)
     trainer.updater.save_states(destination)