def main(args):
    global_config.train = True
    global_config.autotune = True
    global_config.cudnn_deterministic = False

    # Build the generator, its averaged (EMA) copy, and the discriminator, then move them to the target device.
    print("Initializing models...")
    categories = len(args.dataset)
    generator = Generator(args.size, args.depth, args.levels, *args.channels, categories)
    discriminator = Discriminator(args.levels, args.channels[1], args.channels[0], categories, args.depth, args.group)
    averaged_generator = generator.copy("copy")
    generator.to_device(args.device)
    discriminator.to_device(args.device)
    averaged_generator.to_device(args.device)

    # Optimizers, output directory, and a dump of the training arguments.
    optimizers = AdamSet(args.alpha, args.betas[0], args.betas[1])
    optimizers.setup(generator, discriminator)
    mkdirs(args.dest)
    dump_json(args, build_filepath(args.dest, "arguments", "json", args.force))

    # Dataset (multi-category if more than one dataset directory was given), optionally preloaded into memory.
    if categories > 1:
        dataset = MulticategoryImageDataset(args.dataset, generator.resolution)
    else:
        dataset = ImageDataset(args.dataset, generator.resolution)
    if args.preload:
        with chainer_like_tqdm("dataset", len(dataset)) as bar:
            dataset.preload(lambda: bar.update())
    iterator = SerialIterator(dataset, args.batch, repeat=True, shuffle=True)

    # Updater: gradient accumulation, style mixing, R1 and path-length regularization, adaptive augmentation.
    updater = CustomUpdater(generator, averaged_generator, discriminator, iterator, optimizers, categories > 1, args.ema, args.lsgan)
    updater.enable_gradient_accumulation(4)
    updater.enable_style_mixing(args.mix)
    updater.enable_r1_regularization(args.gamma, args.r1)
    updater.enable_path_length_regularization(args.decay, args.weight, args.pl)
    pipeline = AugmentationPipeline()
    pipeline.to_device(args.device)
    updater.enable_adaptive_augumentation(pipeline)
    if args.snapshot is not None:
        updater.load_states(args.snapshot)

    # Trainer: periodic state saves, sample image generation, reports, and a progress bar.
    trainer = CustomTrainer(updater, args.epoch, args.dest, args.force)
    trainer.hook_state_save(1000)
    trainer.hook_image_generation(1000, 32)
    trainer.enable_reports(500)
    trainer.enable_progress_bar(1)
    trainer.run()

    # Save the final averaged generator and a full training snapshot.
    averaged_generator.save(build_filepath(args.dest, "generator", "hdf5", args.force))
    updater.save_states(build_filepath(args.dest, "snapshot", "hdf5", args.force))
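# --- Hedged illustration (not part of the original source) -------------------
# A minimal sketch of how `main` above could be driven from the command line.
# Every flag name and default below is an assumption made for illustration;
# the project's real argument parser is not included in this section, and
# `main` reads many more attributes (size, depth, levels, channels, betas,
# mix, gamma, r1, decay, weight, pl, ema, lsgan, snapshot, ...) than this
# partial sketch provides.
def _parse_args_example():
    import argparse
    parser = argparse.ArgumentParser(description="Train the GAN (illustrative only)")
    parser.add_argument("dataset", nargs="+", help="one image directory per category")
    parser.add_argument("-d", "--dest", default="results", help="output directory")
    parser.add_argument("-b", "--batch", type=int, default=16, help="batch size")
    parser.add_argument("-e", "--epoch", type=int, default=1, help="number of training epochs")
    parser.add_argument("-v", "--device", type=int, default=-1, help="GPU id, or -1 for CPU")
    parser.add_argument("-f", "--force", action="store_true", help="allow overwriting existing files")
    parser.add_argument("-p", "--preload", action="store_true", help="preload the dataset into memory")
    return parser.parse_args()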
def save_images(trainer):
    for i, n in range_batch(trainer.number, trainer.batch_size):
        z = trainer.updater.averaged_generator.generate_latents(n)
        c = trainer.updater.averaged_generator.generate_conditions(n) if trainer.conditional else None
        _, y = trainer.updater.averaged_generator(z, c)
        z.to_cpu()
        y.to_cpu()
        for j in range(n):
            filename = f"{trainer.iteration}-{i + j + 1}"
            np.save(build_filepath(trainer.images_out, filename, "npy", trainer.overwrite), z.array[j])
            save_image(y.array[j], build_filepath(trainer.images_out, filename, "png", trainer.overwrite))
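# --- Hedged illustration (not part of the original source) -------------------
# `range_batch` is used above and in the generation loop below but is not
# defined in this section. Judging from its call sites
# (`for i, n in range_batch(total, batch_size)` with filenames derived from
# `i + j + 1`), it presumably yields (offset, count) pairs covering `total`
# items in chunks of at most `batch_size`. The helper below writes out that
# assumption; it is not the project's actual implementation.
def _range_batch_sketch(total, batch_size):
    for offset in range(0, total, batch_size):
        # The final chunk may be smaller than batch_size.
        yield offset, min(batch_size, total - offset)

# Example: list(_range_batch_sketch(10, 4)) == [(0, 4), (4, 4), (8, 2)]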
def enable_reports(self, interval):
    filename = os.path.basename(build_filepath(self.out, "report", "log", self.overwrite))
    log_report = LogReport(trigger=(interval, "iteration"), filename=filename)
    entries = ["iteration", "loss (G)", "loss (D)", "penalty (G)", "penalty (D)", "overfitting"]
    print_report = PrintReport(entries, log_report)
    entries = ["loss (G)", "loss (D)", "penalty (G)", "penalty (D)", "overfitting"]
    filename = os.path.basename(build_filepath(self.out, "curves", "png", self.overwrite))
    plot_report = PlotReport(entries, "iteration", trigger=(interval, "iteration"), filename=filename)
    self.extend(log_report)
    self.extend(print_report)
    self.extend(plot_report)
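# --- Hedged illustration (not part of the original source) -------------------
# LogReport/PrintReport/PlotReport above only display entries such as
# "loss (G)" if something reports them each iteration. The snippet below shows
# Chainer's standard reporting mechanism with dummy values; the keys match the
# entry lists above, but how and where the project's updater actually reports
# them is an assumption.
def _reporting_sketch():
    import chainer
    reporter = chainer.Reporter()
    observation = {}
    with reporter.scope(observation):
        chainer.report({"loss (G)": 0.7, "loss (D)": 0.6, "overfitting": 0.1})
    return observation  # collected by LogReport from the trainer's reporter in a real run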
# Fragment from the image-generation script: the opening of this `with`
# statement did not survive, but the keyword arguments (total, bar_format,
# miniters, ascii, ncols) indicate a tqdm progress bar; `bf`, `generator`,
# and `args` are defined earlier in the original file.
with tqdm(total=args.number, bar_format=bf, miniters=1, ascii=".#", ncols=70) as bar:
    for i, n in range_batch(args.number, args.batch):
        # mixing = mix > random()
        z = generator.generate_latents(n)
        c = generator.generate_conditions(n)
        # mix_z = generator.generate_latent(n) if mixing else None
        ws, y = generator(z, c)
        z.to_cpu()
        y.to_cpu()
        for j in range(n):
            filename = f"{i + j + 1}"
            np.save(build_filepath(args.dest, filename, "npy", args.force), z.array[j])
            save_image(y.array[j], build_filepath(args.dest, filename, "png", args.force))
        bar.update()

# Legacy generation loop retained from the original source, where it was
# already disabled and truncated mid-statement; kept as comments for reference.
# Generate images
# c = 0
# mean_w = None if args.psi is None else generator.calculate_mean_w()
# while c < args.number:
#     n = min(args.number - c, args.batch)
#     z = generator.generate_latent(n, center=center, sd=args.sd)
#     y = generator(z, args.stage, alpha=args.alpha, psi=args.psi, mean_w=mean_w)
#     z.to_cpu()
#     y.to_cpu()
#     for i in range(n):
# Fragment from the graph-visualization script: `generator`, `discriminator`,
# `z`, and `args` are defined earlier in the original file.
# mix = gen.generate_latent(args.batch)
ws, i = generator(z)
y = discriminator(i)

# Node styles for the generator and discriminator computational graphs.
gen_varstyle = {"fillcolor": "#5edbf1", "shape": "record", "style": "filled"}
gen_funstyle = {"fillcolor": "#ffa9e0", "shape": "record", "style": "filled"}
dis_varstyle = {"fillcolor": "#7a9fe6", "shape": "record", "style": "filled"}
dis_funstyle = {"fillcolor": "#fea21d", "shape": "record", "style": "filled"}

# Dump the generator graph, cut the history between G and D, then dump the discriminator graph.
gen_graph = build_computational_graph([i], variable_style=gen_varstyle, function_style=gen_funstyle).dump()
i.unchain_backward()
dis_graph = build_computational_graph([y], variable_style=dis_varstyle, function_style=dis_funstyle).dump()
# print(f"D: {self.count_params()}")
# print(self.count_params())

# Render both graphs to PDF via pydot.
mkdirs(args.dest)
gen_path = build_filepath(args.dest, "generator", "pdf", args.force)
graph_from_dot_data(gen_graph)[0].write_pdf(gen_path)
print(f"Saved: {gen_path}")
dis_path = build_filepath(args.dest, "discriminator", "pdf", args.force)
graph_from_dot_data(dis_graph)[0].write_pdf(dis_path)
print(f"Saved: {dis_path}")
def save_generator(trainer):
    filepath = build_filepath(trainer.states_out, f"generator-{trainer.iteration}", "hdf5", trainer.overwrite)
    trainer.updater.averaged_generator.save(filepath)
def save_snapshot(trainer):
    filepath = build_filepath(trainer.states_out, f"snapshot-{trainer.iteration}", "hdf5", trainer.overwrite)
    trainer.updater.save_states(filepath)
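# --- Hedged illustration (not part of the original source) -------------------
# save_generator and save_snapshot above both take the trainer itself, which
# suggests they run as periodic hooks (cf. hook_state_save in main).
# CustomTrainer's real dispatch mechanism is not shown in this section, so the
# driver below is only an assumed sketch of how such hooks could be fired.
def _run_save_hooks_sketch(trainer, interval=1000):
    # Fire both savers whenever the iteration count reaches a multiple of `interval`.
    if trainer.iteration > 0 and trainer.iteration % interval == 0:
        save_generator(trainer)
        save_snapshot(trainer)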