def train(args):
    """Train a GAN on MNIST digits.

    Builds the data pipeline, a generator/discriminator pair, and a
    ttools training loop with checkpointing and Visdom logging, then
    trains for `args.epochs` epochs (Ctrl-C to interrupt).
    """
    # Training data: resize the digits to 32x32 and convert to tensors.
    preprocess = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    train_set = MNIST("data/mnist", train=True, download=True,
                      transform=preprocess)

    # Asynchronous dataloader for the training split.
    train_loader = DataLoader(train_set, batch_size=args.bs, num_workers=2)

    # Generator/discriminator pair, both with normal weight init.
    generator = Generator()
    generator.apply(weights_init_normal)
    discriminator = Discriminator()
    discriminator.apply(weights_init_normal)

    # One checkpointer per model so each can be saved/restored on its own.
    chkpt_root = os.path.join(args.out, "checkpoints")
    checkpointer_gen = ttools.Checkpointer(chkpt_root, model=generator,
                                           prefix="gen_")
    checkpointer_discrim = ttools.Checkpointer(chkpt_root, model=discriminator,
                                               prefix="discrim_")

    # Resume from the most recent checkpoint, if there is one.
    checkpointer_gen.load_latest()
    checkpointer_discrim.load_latest()

    # Training interface and loop driver.
    interface = MNISTInterface(generator, discriminator, lr=args.lr)
    trainer = ttools.Trainer(interface)

    # Callbacks: periodic checkpointing, a progress bar, Visdom scalar
    # logging, and Visdom image previews.
    for cb in [
            ttools.callbacks.CheckpointingCallback(checkpointer_gen),
            ttools.callbacks.CheckpointingCallback(checkpointer_discrim),
            ttools.callbacks.ProgressBarCallback(
                keys=["loss_g", "loss_d", "loss"]),
            ttools.callbacks.VisdomLoggingCallback(
                keys=["loss_g", "loss_d", "loss"], port=8080,
                env="mnist_demo"),
            VisdomImageCallback(port=8080, env="mnist_demo"),
    ]:
        trainer.add_callback(cb)

    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(train_loader, num_epochs=args.epochs)
def main(args):
    """Train the primitive-fitting model on the ShapeNet dataset."""
    # Training split: shuffled, multi-worker, fixed-size batches.
    train_data = datasets.ShapenetDataset(args.data, args.canvas_size)
    train_loader = DataLoader(
        train_data, batch_size=args.bs, num_workers=args.num_worker_threads,
        worker_init_fn=_worker_init_fn, shuffle=True, drop_last=True)
    LOG.info(train_data)

    # Validation split: default single-sample batches, no shuffling.
    val_data = datasets.ShapenetDataset(args.data, args.canvas_size, val=True)
    val_loader = DataLoader(val_data)

    # Rounded primitives carry one extra parameter each (11 vs 10).
    params_per_prim = 11 if args.rounded else 10
    model = PrimsModel(output_dim=params_per_prim*args.n_primitives)

    # Resume from the latest checkpoint, if any.
    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model)
    checkpointer.load_latest()

    interface = VectorizerInterface(model, args.lr, args.n_primitives,
                                    args.canvas_size, args.w_surface,
                                    args.w_alignment, args.csg, args.rounded,
                                    cuda=args.cuda)

    log_keys = ['loss', 'surfaceloss', 'alignmentloss']

    # One TensorBoard writer per split, timestamped per run.
    summaries = os.path.join(args.checkpoint_dir, 'summaries')
    writer = SummaryWriter(
        os.path.join(summaries,
                     datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')),
        flush_secs=1)
    val_writer = SummaryWriter(
        os.path.join(summaries,
                     datetime.datetime.now().strftime('val-%m%d%y-%H%M%S')),
        flush_secs=1)

    trainer = ttools.Trainer(interface)
    trainer.add_callback(ttools.callbacks.TensorBoardLoggingCallback(
        keys=log_keys, writer=writer, val_writer=val_writer, frequency=5))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(keys=log_keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=1))
    trainer.train(train_loader, num_epochs=args.num_epochs,
                  val_dataloader=val_loader)
def interpolate(args):
    """Render a latent-space interpolation of a trained vector-MNIST VAE.

    Loads the generator checkpoint matching the CLI flags, draws 10 random
    latent vectors, linearly interpolates between each consecutive pair and
    writes the decoded frames to ``<checkpoint>/interpolation``.
    """
    # The checkpoint directory name encodes the model variant.
    chkpt = VAE_OUTPUT
    if args.conditional:
        chkpt += "_conditional"
    if args.fc:
        chkpt += "_fc"

    meta = ttools.Checkpointer.load_meta(chkpt, prefix="g_")
    if meta is None:
        LOG.info("No metadata in checkpoint (or no checkpoint), aborting.")
        return

    model = VectorMNISTVAE(imsize=128, **meta)
    checkpointer = ttools.Checkpointer(chkpt, model, prefix="g_")
    checkpointer.load_latest()
    model.eval()

    # Sample some latent vectors; one digit class per sample so the
    # conditional model gets a label to render.
    # (Fixed: removed a dead `label = None` that was immediately
    # overwritten by the arange below.)
    bs = 10
    z = th.randn(bs, model.zdim)
    label = th.arange(0, 10)

    animation = []
    nframes = 60  # frames per interpolation segment
    with th.no_grad():
        for idx, _z in enumerate(z):
            if idx == 0:  # first sample has no predecessor to blend from
                continue
            _z0 = z[idx-1].unsqueeze(0).repeat(nframes, 1)
            _z = _z.unsqueeze(0).repeat(nframes, 1)
            if args.conditional:
                _label = label[idx].unsqueeze(0).repeat(nframes)
            else:
                _label = None

            # Interpolation weights in [0, 1].
            alpha = th.linspace(0, 1, nframes).view(nframes, 1)
            batch = alpha*_z + (1.0 - alpha)*_z0
            images, aux = model.decode(batch, label=_label)
            # Map decoder output from [-1, 1] to [0, 1].
            images += 1.0
            images /= 2.0
            animation.append(images)

    anim_dir = os.path.join(chkpt, "interpolation")
    os.makedirs(anim_dir, exist_ok=True)
    animation = th.cat(animation, 0)
    for idx, frame in enumerate(animation):
        frame = frame.squeeze()
        frame = th.clamp(frame, 0, 1).cpu().numpy()
        path = os.path.join(anim_dir, "frame%03d.png" % idx)
        pydiffvg.imwrite(frame, path, gamma=2.2)
    LOG.info("Results saved to %s", anim_dir)
def train(args):
    """Train a MNIST classifier."""
    # Shared preprocessing: upscale the digits to 32x32 tensors.
    preprocess = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    train_set = MNIST("data/mnist", train=True, download=True,
                      transform=preprocess)
    val_set = MNIST("data/mnist", train=False, transform=preprocess)

    # Asynchronous loaders for both splits.
    train_loader = DataLoader(train_set, batch_size=args.bs, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=16, num_workers=1)

    model = MNISTClassifier()

    # Save/restore model parameters under <out>/checkpoints.
    checkpointer = ttools.Checkpointer(os.path.join(args.out, "checkpoints"),
                                       model=model, prefix="classifier_")
    checkpointer.load_latest()  # resume from a previous run, if any

    # Train on the GPU when one is available.
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    interface = MNISTInterface(model, device, lr=args.lr)
    trainer = ttools.Trainer(interface)

    LOG.info("This demo uses a Visdom to display the loss and accuracy, make sure you have a visdom server running! ('make visdom_server')")

    # Periodic checkpointing, a progress bar, and Visdom scalar logging.
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=["loss", "accuracy"], val_keys=["loss", "accuracy"]))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=["loss", "accuracy"], val_keys=["loss", "accuracy"],
        port=8080, env="mnist_demo"))

    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(train_loader, num_epochs=args.epochs,
                  val_dataloader=val_loader)
def main(args):
    """Fit curves to a glyph image with a trained model and export a PDF."""
    # Run on the GPU only when it is both available and requested.
    use_gpu = th.cuda.is_available() and args.cuda
    device = "cuda" if use_gpu else "cpu"

    model = CurvesModel(sum(templates.topology))
    model.to(device)
    model.eval()

    # Restore the latest weights for the requested model.
    checkpointer = ttools.Checkpointer(f'models/{args.model}', model)
    extras, _ = checkpointer.load_latest()
    if extras is None:
        print("Unable to load checkpoint")
    else:
        print(f"Loaded checkpoint (epoch {extras['epoch']})")

    # Grayscale 128x128 input image.
    im = to_tensor(Image.open(args.image).convert('L').resize(
        (128, 128))).to(device)

    # One-hot encoding of the target letter.
    z = th.zeros(len(string.ascii_uppercase)).scatter_(
        0, th.tensor(string.ascii_uppercase.index(args.letter)), 1).to(device)

    print(f"Processing image {args.image} (letter {args.letter})")
    curves = model(im[None], z[None])['curves'][0].detach().cpu()

    # Draw: white background, then the source PNG, then the fitted curves.
    surface = cairo.PDFSurface(args.out, 128, 128)
    ctx = cairo.Context(surface)
    ctx.scale(128, 128)
    ctx.rectangle(0, 0, 1, 1)
    ctx.set_source_rgb(1, 1, 1)
    ctx.fill()

    ctx.save()
    background = cairo.ImageSurface.create_from_png(args.image)
    ctx.scale(1 / 128, 1 / 128)
    ctx.set_source_surface(background)
    ctx.paint()
    ctx.restore()

    draw_curves(curves, templates.n_loops[args.letter], ctx)
    surface.finish()
    print(f"Output saved to {args.out}")
def main(args):
    """Train the font curve-fitting model."""
    train_data = datasets.FontsDataset(args.data, args.chamfer,
                                       args.n_samples_per_curve)
    train_loader = DataLoader(
        train_data, batch_size=args.bs, num_workers=args.num_worker_threads,
        worker_init_fn=_worker_init_fn, shuffle=True, drop_last=True)
    LOG.info(train_data)

    val_data = datasets.FontsDataset(args.data, args.chamfer,
                                     args.n_samples_per_curve, val=True)
    val_loader = DataLoader(val_data)

    model = CurvesModel(n_curves=sum(templates.topology))

    # Resume from the latest checkpoint, if any.
    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model)
    checkpointer.load_latest()

    interface = VectorizerInterface(model, args.simple_templates, args.lr,
                                    args.max_stroke, args.canvas_size,
                                    args.chamfer, args.n_samples_per_curve,
                                    args.w_surface, args.w_template,
                                    args.w_alignment, cuda=args.cuda)

    # The chamfer objective reports a different set of losses.
    if args.chamfer:
        log_keys = ['loss', 'chamferloss', 'templateloss']
    else:
        log_keys = ['loss', 'surfaceloss', 'alignmentloss', 'templateloss']

    # One TensorBoard writer per split, timestamped per run.
    summaries = os.path.join(args.checkpoint_dir, 'summaries')
    writer = SummaryWriter(
        os.path.join(summaries,
                     datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')),
        flush_secs=1)
    val_writer = SummaryWriter(
        os.path.join(summaries,
                     datetime.datetime.now().strftime('val-%m%d%y-%H%M%S')),
        flush_secs=1)

    trainer = ttools.Trainer(interface)
    trainer.add_callback(ttools.callbacks.TensorBoardLoggingCallback(
        keys=log_keys, writer=writer, val_writer=val_writer, frequency=5))
    trainer.add_callback(callbacks.InputImageCallback(
        writer=writer, val_writer=val_writer, frequency=100))
    trainer.add_callback(callbacks.CurvesCallback(
        writer=writer, val_writer=val_writer, frequency=100))
    if not args.chamfer:
        # Rendering previews only apply to the rendering-based loss.
        trainer.add_callback(callbacks.RenderingCallback(
            writer=writer, val_writer=val_writer, frequency=100))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(keys=log_keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=1))
    trainer.train(train_loader, num_epochs=args.num_epochs,
                  val_dataloader=val_loader)
def train(args):
    """Train a SketchRNN model on a QuickDraw dataset.

    Resumes from the latest checkpoint when one exists and aborts if the
    checkpoint's metaparameters disagree with the CLI arguments.
    """
    th.manual_seed(0)
    np.random.seed(0)

    dataset = data.QuickDrawDataset(args.dataset)
    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=4, shuffle=True,
        pin_memory=False)

    # Small fixed validation set: the first 8 samples of the dataset.
    val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
    val_dataloader = DataLoader(
        val_dataset, batch_size=8, num_workers=4, shuffle=False,
        pin_memory=False)

    model_params = {
        "zdim": args.zdim,
        "num_gaussians": args.num_gaussians,
        "encoder_dim": args.encoder_dim,
        "decoder_dim": args.decoder_dim,
    }
    model = SketchRNN(**model_params)
    model.train()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model, lr=args.lr, lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay, kl_weight=args.kl_weight,
                          sampling_temperature=args.sampling_temperature,
                          device=device)

    chkpt = OUTPUT_BASELINE
    env_name = "sketch_rnn"

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(
        chkpt, model, meta=model_params,
        optimizers=interface.optimizers(),
        schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    if meta is not None and meta != model_params:
        LOG.info(
            "Checkpoint's metaparams differ "
            "from CLI, aborting: %s and %s", meta, model_params)
        # BUGFIX: the message promised to abort but execution continued;
        # training with mismatched hyper-parameters against restored
        # weights would be incorrect, so actually bail out here.
        return

    trainer = ttools.Trainer(interface)

    # Add callbacks
    losses = ["loss", "kl_loss", "recons_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=losses, val_keys=None))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(
            keys=losses, val_keys=None, env=env_name, port=args.port))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(
            keys=training_debug, smoothing=0, val_keys=None, env=env_name,
            port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(
            checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))
    trainer.add_callback(
        SketchRNNCallback(env=env_name, win="samples", port=args.port,
                          frequency=args.freq))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
def generate_samples(args):
    """Decode an 8x8 grid of digits from a trained VAE and export PNG/SVG.

    Loads the generator checkpoint matching the CLI flags, decodes a grid
    of latent vectors into raster previews, and merges the per-digit
    vector scenes into a single SVG canvas.
    """
    # The checkpoint directory name encodes the model variant.
    chkpt = VAE_OUTPUT
    if args.conditional:
        chkpt += "_conditional"
    if args.fc:
        chkpt += "_fc"

    meta = ttools.Checkpointer.load_meta(chkpt, prefix="g_")
    if meta is None:
        LOG.info("No metadata in checkpoint (or no checkpoint), aborting.")
        return

    model = VectorMNISTVAE(**meta)
    checkpointer = ttools.Checkpointer(chkpt, model, prefix="g_")
    checkpointer.load_latest()
    model.eval()

    # Sample some latent vectors for an n x n grid of digits.
    n = 8
    bs = n*n
    z = th.randn(bs, model.zdim)

    imsize = 28
    dataset = Dataset(args.data_dir, imsize)
    dataloader = DataLoader(dataset, batch_size=bs, num_workers=1,
                            shuffle=True)

    # Grab a single batch of reference digits (and their labels).
    for batch in dataloader:
        ref, label = batch
        break

    # NOTE(review): hard-coded switch — when True the latents come from
    # encoding the reference batch instead of the random draws above.
    autoencode = True
    if autoencode:
        LOG.info("Sampling with auto-encoder code")
        if not args.conditional:
            label = None
        mu, logvar = model.encode(ref, label)
        z = model.reparameterize(mu, logvar)
    else:
        label = None
        if args.conditional:
            # Random digit classes in [0, 9]; optionally forced to a digit.
            label = th.clamp(th.rand(bs)*10, 0, 9).long()
            if args.digit is not None:
                label[:] = args.digit

    with th.no_grad():
        images, aux = model.decode(z, label=label)
        scenes = aux["scenes"]

    # Map decoder output from [-1, 1] to [0, 1].
    images += 1.0
    images /= 2.0

    h = w = model.imsize
    # Tile the bs images into one (n*h) x (n*w) mosaic.
    images = images.view(n, n, h, w).permute(0, 2, 1, 3)
    images = images.contiguous().view(n*h, n*w)
    images = th.clamp(images, 0, 1).cpu().numpy()
    path = os.path.join(chkpt, "samples.png")
    pydiffvg.imwrite(images, path, gamma=2.2)

    if autoencode:
        # Save the reference digits with the same tiling, for comparison.
        ref += 1.0
        ref /= 2.0
        ref = ref.view(n, n, h, w).permute(0, 2, 1, 3)
        ref = ref.contiguous().view(n*h, n*w)
        ref = th.clamp(ref, 0, 1).cpu().numpy()
        path = os.path.join(chkpt, "ref.png")
        pydiffvg.imwrite(ref, path, gamma=2.2)

    # merge scenes: translate each digit's paths to its grid cell and
    # renumber shape ids so all digits share one canvas.
    all_shapes = []
    all_shape_groups = []
    cur_id = 0
    for idx, s in enumerate(scenes):
        shapes, shape_groups, _ = s
        # width, height = sizes

        # Shift digit on canvas
        center_x = idx % n
        center_y = idx // n

        for shape in shapes:
            shape.points[:, 0] += center_x * model.imsize
            shape.points[:, 1] += center_y * model.imsize
            all_shapes.append(shape)
        for grp in shape_groups:
            grp.shape_ids[:] = cur_id
            cur_id += 1
            all_shape_groups.append(grp)
    LOG.info("Generated %d shapes", len(all_shapes))

    fname = os.path.join(chkpt, "digits.svg")
    pydiffvg.save_svg(fname, n*model.imsize, n*model.imsize,
                      all_shapes, all_shape_groups, use_gamma=False)
    LOG.info("Results saved to %s", chkpt)
def train(args):
    """Train a vector MNIST generator (VAE or plain auto-encoder).

    The checkpoint directory and Visdom environment names encode the model
    variant (vae/ae, conditional, fully-connected decoder).
    """
    th.manual_seed(0)
    np.random.seed(0)

    pydiffvg.set_use_gpu(args.cuda)

    # Initialize datasets
    imsize = 28
    dataset = Dataset(args.data_dir, imsize)
    dataloader = DataLoader(dataset, batch_size=args.bs,
                            num_workers=4, shuffle=True)

    if args.generator in ["vae", "ae"]:
        LOG.info("Vector config:\n samples %d\n"
                 " paths: %d\n segments: %d\n"
                 " zdim: %d\n"
                 " conditional: %d\n"
                 " fc: %d\n",
                 args.samples, args.paths, args.segments, args.zdim,
                 args.conditional, args.fc)

        model_params = dict(samples=args.samples, paths=args.paths,
                            segments=args.segments,
                            conditional=args.conditional,
                            zdim=args.zdim, fc=args.fc)
    if args.generator == "vae":
        model = VectorMNISTVAE(variational=True, **model_params)
        chkpt = VAE_OUTPUT
        name = "mnist_vae"
    elif args.generator == "ae":
        model = VectorMNISTVAE(variational=False, **model_params)
        chkpt = AE_OUTPUT
        name = "mnist_ae"
    else:
        raise ValueError("unknown generator")

    # Variant suffixes on both the env name and the checkpoint directory.
    if args.conditional:
        name += "_conditional"
        chkpt += "_conditional"
    if args.fc:
        name += "_fc"
        chkpt += "_fc"

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(
        chkpt, model, meta=model_params, prefix="g_")
    extras, meta = checkpointer.load_latest()

    if meta is not None and meta != model_params:
        LOG.info("Checkpoint's metaparams differ from CLI, aborting: %s and %s",
                 meta, model_params)
        # BUGFIX: the message promised to abort but execution continued;
        # resuming restored weights with mismatched hyper-parameters would
        # be incorrect, so actually bail out here.
        return

    # Hook interface
    if args.generator in ["vae", "ae"]:
        variational = args.generator == "vae"
        if variational:
            LOG.info("Using a VAE")
        else:
            LOG.info("Using an AE")
        interface = VAEInterface(model, lr=args.lr, cuda=args.cuda,
                                 variational=variational,
                                 w_kld=args.kld_weight)

    trainer = ttools.Trainer(interface)

    # Add callbacks; the set of logged losses depends on the model variant.
    keys = ["loss_g", "loss_d"]
    if args.generator == "vae":
        keys = ["kld", "data_loss", "loss"]
    elif args.generator == "ae":
        keys = ["data_loss", "loss"]
    port = 8097
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=keys, val_keys=keys))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=keys, val_keys=keys, env=name, port=port))
    trainer.add_callback(MNISTCallback(
        env=name, win="samples", port=port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=50))

    # Start training
    trainer.train(dataloader, num_epochs=args.num_epochs)
def main(args):
    """Denoise a rendered scene with a pretrained sample-based denoiser.

    Loads model metadata and weights from ``args.checkpoint``, denoises the
    input scene tile-by-tile (to bound memory), and writes EXR + PNG output.
    """
    start = time.time()

    if not os.path.exists(args.input):
        raise ValueError("input {} does not exist".format(args.input))

    # The dataset loader expects a directory of scenes; symlink the input
    # scene into a scratch directory to satisfy that layout.
    data_root = os.path.abspath(args.input)
    name = os.path.basename(data_root)
    tmpdir = tempfile.mkdtemp()
    # BUGFIX: the scratch directory used to leak whenever anything below
    # raised; the try/finally guarantees it is removed.
    try:
        os.symlink(data_root, os.path.join(tmpdir, name))

        LOG.info("Loading model {}".format(args.checkpoint))
        meta_params = ttools.Checkpointer.load_meta(args.checkpoint)

        LOG.info("Setting up dataloader")
        data_params = meta_params["data_params"]
        if args.spp:
            data_params["spp"] = args.spp  # CLI override of the sample count
        data = sbmc.FullImagesDataset(tmpdir, **data_params)
        dataloader = DataLoader(data, batch_size=1, shuffle=False,
                                num_workers=0)

        LOG.info("Denoising input with {} spp".format(data_params["spp"]))

        kpcn_mode = meta_params["kpcn_mode"]
        if kpcn_mode:
            LOG.info("Using [Bako2017] denoiser.")
            model = sbmc.KPCN(data.num_features)
        else:
            model = sbmc.Multisteps(data.num_features,
                                    data.num_global_features)
        model.train(False)

        device = "cpu"
        cuda = th.cuda.is_available()
        if cuda:
            LOG.info("Using CUDA")
            model.cuda()
            device = "cuda"

        checkpointer = ttools.Checkpointer(args.checkpoint, model, None)
        extras, meta = checkpointer.load_latest()
        LOG.info("Loading latest checkpoint {}".format(
            "failed" if meta is None else "success"))

        elapsed = (time.time() - start) * 1000
        LOG.info("setup time {:.1f} ms".format(elapsed))

        LOG.info("starting the denoiser")
        for scene_id, batch in enumerate(dataloader):
            for k in batch.keys():
                batch[k] = batch[k].to(device)
            scene = os.path.basename(data.scenes[scene_id])
            LOG.info(" scene {}".format(scene))
            # Process in padded tiles to bound peak memory use.
            tile_sz = args.tile_size
            tile_pad = args.tile_pad
            batch_parts = _split_tiles(batch, max_sz=tile_sz, pad=tile_pad)
            out_radiance = th.zeros_like(batch["low_spp"])
            if cuda:
                th.cuda.synchronize()  # accurate wall-clock timing
            start = time.time()
            for part, start_y, end_y, start_x, end_x, pad_ in batch_parts:
                with th.no_grad():
                    out_ = model(part)
                    out_ = _pad(part, out_["radiance"], kpcn_mode)
                    # Strip the tile padding before stitching.
                    out_ = out_[..., pad_[0]:out_.shape[-2] - pad_[1],
                                pad_[2]:out_.shape[-1] - pad_[3]]
                    out_radiance[..., start_y:end_y, start_x:end_x] = out_
            if cuda:
                th.cuda.synchronize()
            elapsed = (time.time() - start) * 1000
            LOG.info(" denoising time {:.1f} ms".format(elapsed))

            # HWC numpy image for the writers.
            out_radiance = out_radiance[0, ...].cpu().numpy().transpose(
                [1, 2, 0])
            outdir = os.path.dirname(args.output)
            os.makedirs(outdir, exist_ok=True)
            pyexr.write(args.output, out_radiance)

            png = args.output.replace(".exr", ".png")
            skio.imsave(png,
                        (np.clip(out_radiance, 0, 1) * 255).astype(np.uint8))
    finally:
        shutil.rmtree(tmpdir)
def train(args):
    """Train raster and (optionally) vector GAN generators on sketches.

    Always trains a raster generator/discriminator pair; unless
    ``args.raster_only`` is set, it also trains a vector generator of the
    architecture chosen by ``args.generator``, each pair with its own
    checkpointers.
    """
    th.manual_seed(0)
    np.random.seed(0)

    color_output = False
    if args.task == "mnist":
        dataset = data.MNISTDataset(args.raster_resolution, train=True)
    elif args.task == "quickdraw":
        dataset = data.QuickDrawImageDataset(
            args.raster_resolution, train=True)
    else:
        raise NotImplementedError()
    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers,
        shuffle=True)

    # No validation set is used for GAN training.
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "num_strokes": args.num_strokes,
        "imsize": args.raster_resolution,
        "stroke_width": args.stroke_width,
        "color_output": color_output,
    }
    # Raster pair.
    gen = models.Generator(**model_params)
    gen.train()
    discrim = models.Discriminator(color_output=color_output)
    discrim.train()

    # Optional vector pair; architecture chosen by --generator.
    if args.raster_only:
        vect_gen = None
        vect_discrim = None
    else:
        if args.generator == "fc":
            vect_gen = models.VectorGenerator(**model_params)
        elif args.generator == "bezier_fc":
            vect_gen = models.BezierVectorGenerator(**model_params)
        elif args.generator in ["rnn"]:
            vect_gen = models.RNNVectorGenerator(**model_params)
        elif args.generator in ["chain_rnn"]:
            vect_gen = models.ChainRNNVectorGenerator(**model_params)
        else:
            raise NotImplementedError()
        vect_gen.train()
        vect_discrim = models.Discriminator(color_output=color_output)
        vect_discrim.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(gen, vect_gen, discrim, vect_discrim,
                          raster_resolution=args.raster_resolution,
                          lr=args.lr, wgan_gp=args.wgan_gp,
                          lr_decay=args.lr_decay, device=device)

    # The Visdom env / checkpoint dir name encodes the configuration.
    env_name = args.task + "_gan"
    if args.raster_only:
        env_name += "_raster"
    else:
        env_name += "_vector"
    env_name += "_" + args.generator
    if args.wgan_gp:
        env_name += "_wgan"

    chkpt = os.path.join(OUTPUT, env_name)

    meta = {
        "model_params": model_params,
        "task": args.task,
        "generator": args.generator,
    }
    # Separate checkpointers per model so each is saved/restored on its own.
    checkpointer = ttools.Checkpointer(
        chkpt, gen, meta=meta,
        optimizers=interface.optimizers,
        schedulers=interface.schedulers,
        prefix="g_")
    checkpointer_d = ttools.Checkpointer(
        chkpt, discrim, prefix="d_")

    # Resume from checkpoint, if any
    extras, _ = checkpointer.load_latest()
    checkpointer_d.load_latest()

    if not args.raster_only:
        checkpointer_vect = ttools.Checkpointer(
            chkpt, vect_gen, meta=meta,
            optimizers=interface.optimizers,
            schedulers=interface.schedulers,
            prefix="vect_g_")
        checkpointer_d_vect = ttools.Checkpointer(
            chkpt, vect_discrim, prefix="vect_d_")
        # NOTE(review): this overwrites `extras` from the raster generator,
        # so the starting epoch below comes from the vector checkpoint when
        # vector training is enabled — presumably intentional; confirm.
        extras, _ = checkpointer_vect.load_latest()
        checkpointer_d_vect.load_latest()

    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    # if meta is not None and meta["model_parameters"] != model_params:
    #     LOG.info("Checkpoint's metaparams differ "
    #              "from CLI, aborting: %s and %s", meta, model_params)

    trainer = ttools.Trainer(interface)

    # Add callbacks
    losses = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect",
              "gp", "gp_vect"]
    training_debug = ["lr"]

    trainer.add_callback(Callback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=losses, val_keys=None))
    trainer.add_callback(ttools.callbacks.MultiPlotCallback(
        keys=losses, val_keys=None, env=env_name, port=args.port,
        server=args.server, base_url=args.base_url,
        win="losses", frequency=args.freq))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        server=args.server, base_url=args.base_url,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer_d, max_files=2, interval=600, max_epochs=10))
    if not args.raster_only:
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_vect, max_files=2, interval=600, max_epochs=10))
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_d_vect, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
def main(args):
    """Train a Monte Carlo denoiser (sample-based [Gharbi2019] or the
    pixel-based [Bako2017] baseline) with optional validation.
    """
    # Fix seed
    np.random.seed(0)
    th.manual_seed(0)

    # Parameterization of the dataset (shared between train/val)
    data_args = dict(spp=args.spp,
                     mode=sbmc.TilesDataset.KPCN_MODE if args.kpcn_mode
                     else sbmc.TilesDataset.SBMC_MODE,
                     load_coords=args.load_coords,
                     load_gbuffer=args.load_gbuffer, load_p=args.load_p,
                     load_ld=args.load_ld, load_bt=args.load_bt)

    if args.randomize_spp:
        if args.bs != 1:
            # BUGFIX: the two implicitly-concatenated literals were missing
            # a space, producing "...valid forbatch_size=1...".
            LOG.error(
                "Training with randomized spp is only valid for "
                "batch_size=1, got %d", args.bs)
            raise RuntimeError("Incorrect batch size")
        data = sbmc.MultiSampleCountDataset(args.data, **data_args)
        LOG.info("Training with randomized sample count in [%d, %d]" % (
            2, args.spp))
    else:
        data = sbmc.TilesDataset(args.data, **data_args)
        LOG.info("Training with a single sample count: %dspp" % args.spp)

    if args.kpcn_mode:
        LOG.info("Model: pixel-based comparison from [Bako2017]")
        model = sbmc.KPCN(data.num_features, ksize=args.ksize)
        model_params = dict(ksize=args.ksize)
    else:
        LOG.info("Model: sample-based [Gharbi2019]")
        model = sbmc.Multisteps(data.num_features, data.num_global_features,
                                ksize=args.ksize, splat=not args.gather,
                                pixel=args.pixel)
        model_params = dict(ksize=args.ksize, gather=args.gather,
                            pixel=args.pixel)

    dataloader = DataLoader(
        data, batch_size=args.bs, num_workers=args.num_worker_threads,
        shuffle=True)

    # Validation set uses a constant spp
    val_dataloader = None
    if args.val_data is not None:
        LOG.info("Validation set with %dspp" % args.spp)
        val_data = sbmc.TilesDataset(args.val_data, **data_args)
        val_dataloader = DataLoader(val_data, batch_size=args.bs,
                                    num_workers=1, shuffle=False)
    else:
        LOG.info("No validation set provided")

    # Metadata saved with checkpoints so evaluation can rebuild the model.
    meta = dict(model_params=model_params, kpcn_mode=args.kpcn_mode,
                data_params=data_args)

    LOG.info("Model configuration: {}".format(model_params))

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)

    interface = sbmc.SampleBasedDenoiserInterface(
        model, lr=args.lr, cuda=args.cuda)

    # Resume from the latest checkpoint, if any.
    extras, meta = checkpointer.load_latest()

    trainer = ttools.Trainer(interface)

    # Hook-up some callbacks to the training loop
    log_keys = ["loss", "rmse"]
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(log_keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(
            log_keys, env=args.env, port=args.port, log=True, frequency=100))
    trainer.add_callback(
        sbmc.DenoisingDisplayCallback(env=args.env, port=args.port,
                                      win="images"))

    # Launch the training
    LOG.info("Training started, 'Ctrl+C' to abort.")
    trainer.train(dataloader, num_epochs=args.num_epochs,
                  val_dataloader=val_dataloader)
def main(args):
    """Evaluate a pretrained demosaicking model on the test subset.

    Loads hyper-parameters and weights from ``args.checkpoint_dir``, then
    reports per-image and average PSNR/MSE. (Docstring fixed: this is an
    evaluation entrypoint, not training.)
    """
    # Load model parameters from checkpoint, if any
    meta = ttools.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        LOG.warning("No checkpoint found at %s, aborting.",
                    args.checkpoint_dir)
        return

    data = demosaicnet.Dataset(args.data, download=False, mode=meta["mode"],
                               subset=demosaicnet.TEST_SUBSET)
    dataloader = DataLoader(
        data, batch_size=1, num_workers=4, pin_memory=True, shuffle=True)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)
    else:
        # BUGFIX: an unrecognized mode previously fell through and crashed
        # later with a NameError on `model`; fail with a clear message.
        raise ValueError("Unknown demosaicking mode: %s" % meta["mode"])

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)
    checkpointer.load_latest()  # Resume from checkpoint, if any.

    # No need for gradients at evaluation time.
    for p in model.parameters():
        p.requires_grad = False

    mse_fn = th.nn.MSELoss()
    psnr_fn = PSNR()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")
    # BUGFIX: inputs were moved to `device` but the model never was, which
    # crashed on mixed CPU/CUDA tensors whenever a GPU was available.
    model.to(device)

    count = 0
    mse = 0.0
    psnr = 0.0
    for idx, batch in enumerate(dataloader):
        mosaic = batch[0].to(device)
        target = batch[1].to(device)
        output = model(mosaic)
        # The model runs unpadded, so crop the target to the output size.
        target = crop_like(target, output)
        output = th.clamp(output, 0, 1)
        psnr_ = psnr_fn(output, target).item()
        mse_ = mse_fn(output, target).item()
        psnr += psnr_
        mse += mse_
        count += 1
        LOG.info("Image %04d, PSNR = %.1f dB, MSE = %.5f", idx, psnr_, mse_)

    # Guard against an empty dataset (previously a ZeroDivisionError).
    if count == 0:
        LOG.warning("No images evaluated.")
        return

    mse /= count
    psnr /= count
    LOG.info("-----------------------------------")
    LOG.info("Average, PSNR = %.1f dB, MSE = %.5f", psnr, mse)
def main(args):
    """Visualize the kernels predicted by a pretrained denoiser.

    Runs the model on the first scene of the dataset, crops a fixed window,
    and saves the low-spp input, target, output, gathered kernels and
    kernel-variance images to ``args.output``.
    """
    log.info("Loading model {}".format(args.checkpoint))
    meta_params = ttools.Checkpointer.load_meta(args.checkpoint)

    # Data-loading flags are recovered from the checkpoint metadata.
    spp = meta_params["spp"]
    use_p = meta_params["use_p"]
    use_ld = meta_params["use_ld"]
    use_bt = meta_params["use_bt"]
    # use_coc = meta_params["use_coc"]

    mode = "sample"
    if "DisneyPreprocessor" == meta_params["preprocessor"]:
        mode = "disney_pixel"
    elif "SampleDisneyPreprocessor" == meta_params["preprocessor"]:
        mode = "disney_sample"

    log.info("Rendering at {} spp".format(spp))
    log.info("Setting up dataloader, p:{} bt:{} ld:{}".format(
        use_p, use_bt, use_ld))
    data = dset.FullImageDataset(args.data, dset.RenderDataset, spp=spp,
                                 use_p=use_p, use_ld=use_ld, use_bt=use_bt)
    preprocessor = pre.get(meta_params["preprocessor"])(data)
    xforms = transforms.Compose([dset.ToTensor(), preprocessor])
    data.transform = xforms
    dataloader = DataLoader(data, batch_size=1, shuffle=False, num_workers=0,
                            pin_memory=True)

    model = models.get(preprocessor, meta_params["model_params"])
    model.cuda()
    model.train(False)

    checkpointer = ttools.Checkpointer(args.checkpoint, model, None)
    extras, meta = checkpointer.load_latest()
    log.info("Loading latest checkpoint {}".format(
        "failed" if meta is None else "success"))

    for scene_id, batch in enumerate(dataloader):
        batch_v = make_variable(batch, cuda=True)
        with th.no_grad():
            klist = []  # the model appends per-step kernel dicts here
            out_ = model(batch_v, kernel_list=klist)
        lowspp = batch["radiance"]
        target = batch["target_image"]
        out = out_["radiance"]

        # Fixed crop window for the visualization.
        cx = 70
        cy = 20
        c = 128

        target = crop_like(target, out)
        lowspp = crop_like(lowspp.squeeze(), out)
        lowspp = lowspp[..., cy:cy+c, cx:cx+c]
        # Lay the individual samples out side by side.
        lowspp = lowspp.permute(1, 2, 0, 3)
        chan, h, w, s = lowspp.shape
        lowspp = lowspp.contiguous().view(chan, h, w*s)

        # Kernels are stored as log-weights; exponentiate relative to the
        # running max for numerical stability.
        maxi = crop_like(klist[-1]["max_w"].unsqueeze(1), out)
        kernels = []
        updated_kernels = []
        for k in klist:
            kernels.append(th.exp(crop_like(k["kernels"], out)-maxi))
            updated_kernels.append(
                th.exp(crop_like(k["updated_kernels"], out)-maxi))

        out = out[..., cy:cy+c, cx:cx+c]
        target = target[..., cy:cy+c, cx:cx+c]
        updated_kernels = [k[..., cy:cy+c, cx:cx+c] for k in updated_kernels]
        kernels = [k[..., cy:cy+c, cx:cx+c] for k in kernels]

        u_kernels_im = viz.kernels2im(kernels)
        kvar = u_kernels_im.std(0)
        n, h, w = u_kernels_im.shape
        u_kernels_im = u_kernels_im.permute(1, 0, 2).contiguous().view(h, w*n)

        fname = os.path.join(args.output, "lowspp.png")
        save(fname, lowspp)
        fname = os.path.join(args.output, "target.png")
        save(fname, target)
        fname = os.path.join(args.output, "output.png")
        save(fname, out)
        fname = os.path.join(args.output, "kernels_gather.png")
        save(fname, u_kernels_im)
        fname = os.path.join(args.output, "kernels_variance.png")
        print(kvar.max())
        save(fname, kvar)
        # BUGFIX: removed a leftover `import ipdb; ipdb.set_trace()`
        # debugger breakpoint, plus unused locals (sum_r/sum_w/max_w/kmean).
        # Only the first scene is visualized.
        break
def train(args):
    """Train a SketchVAE model on fixed-length QuickDraw sequences.

    Resumes from the latest checkpoint when one exists and aborts if the
    checkpoint's metaparameters disagree with the CLI arguments.
    """
    th.manual_seed(0)
    np.random.seed(0)

    dataset = data.FixedLengthQuickDrawDataset(
        args.dataset, max_seq_length=args.sequence_length,
        canvas_size=args.raster_resolution)
    dataloader = DataLoader(dataset, batch_size=args.bs,
                            num_workers=args.workers, shuffle=True)

    # val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
    # val_dataloader = DataLoader(
    #     val_dataset, batch_size=8, num_workers=4, shuffle=False)
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "sequence_length": args.sequence_length,
        "image_size": args.raster_resolution,
        # "encoder_dim": args.encoder_dim,
        # "decoder_dim": args.decoder_dim,
    }
    model = SketchVAE(**model_params)
    model.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model, raster_resolution=args.raster_resolution,
                          lr=args.lr, lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay, kl_weight=args.kl_weight,
                          absolute_coords=args.absolute_coordinates,
                          device=device)

    # The Visdom env / checkpoint dir name encodes the configuration.
    env_name = "sketch_vae"
    if args.custom_name is not None:
        env_name += "_" + args.custom_name
    if args.absolute_coordinates:
        env_name += "_abs_coords"

    chkpt = os.path.join(OUTPUT, env_name)

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(
        chkpt, model, meta=model_params,
        optimizers=interface.optimizers(),
        schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0
    if meta is not None and meta != model_params:
        LOG.info(
            "Checkpoint's metaparams differ "
            "from CLI, aborting: %s and %s", meta, model_params)
        # BUGFIX: the message promised to abort but execution continued;
        # training restored weights with mismatched hyper-parameters would
        # be incorrect, so actually bail out here.
        return

    trainer = ttools.Trainer(interface)

    # Add callbacks
    losses = ["loss", "kl_loss", "vae_im_loss", "sketch_im_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=losses, val_keys=None))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(
            keys=losses, val_keys=None, env=env_name, port=args.port))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(
            keys=training_debug, smoothing=0, val_keys=None, env=env_name,
            port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(
            checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))
    trainer.add_callback(
        SketchVAECallback(env=env_name, win="samples", port=args.port,
                          frequency=args.freq))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
def run(args):
    """Evaluate a trained vector generator: latent interpolations and samples.

    Loads the model described by the checkpoint metadata at ``args.model``,
    then writes interpolation images/SVGs/videos and a grid of random samples
    under the output directory.

    Args:
        args: parsed command-line namespace; reads ``model``, ``output_dir``,
            ``imsize``, ``invert``, ``nsamples``, ``nsteps``, ``nframes``.
    """
    # Fix RNG seeds so sampled latent codes are reproducible across runs.
    th.manual_seed(0)
    np.random.seed(0)

    meta = ttools.Checkpointer.load_meta(args.model, "vect_g_")

    if meta is None:
        LOG.warning("Could not load metadata at %s, aborting.", args.model)
        return

    LOG.info("Loaded model %s with metadata:\n %s", args.model, meta)

    # Default output location is an "eval" folder next to the checkpoint.
    if args.output_dir is None:
        outdir = os.path.join(args.model, "eval")
    else:
        outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    model_params = meta["model_params"]
    if args.imsize is not None:
        LOG.info("Overriding output image size to: %dx%d", args.imsize,
                 args.imsize)
        old_size = model_params["imsize"]
        # Stroke widths are scaled proportionally to the resolution change so
        # the rendered strokes keep the same relative thickness.
        scale = args.imsize * 1.0 / old_size
        model_params["imsize"] = args.imsize
        model_params["stroke_width"] = [
            w * scale for w in model_params["stroke_width"]
        ]
        LOG.info("Overriding width to: %s", model_params["stroke_width"])

    # task = meta["task"]
    # Instantiate the generator architecture recorded in the checkpoint meta.
    generator = meta["generator"]
    if generator == "fc":
        model = models.VectorGenerator(**model_params)
    elif generator == "bezier_fc":
        model = models.BezierVectorGenerator(**model_params)
    elif generator in ["rnn"]:
        model = models.RNNVectorGenerator(**model_params)
    elif generator in ["chain_rnn"]:
        model = models.ChainRNNVectorGenerator(**model_params)
    else:
        raise NotImplementedError()
    model.eval()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
    model.to(device)

    # Restore the trained weights into the freshly built model.
    checkpointer = ttools.Checkpointer(
        args.model, model, meta=meta, prefix="vect_g_")
    checkpointer.load_latest()

    LOG.info("Computing latent space interpolation")
    for i in range(args.nsamples):
        # Two random endpoints in latent space.
        z0 = model.sample_z(1)
        z1 = model.sample_z(1)

        # interpolation: coarse steps for the image strip, fine steps for the
        # video frames.
        alpha = th.linspace(0, 1, args.nsteps).view(args.nsteps, 1).to(device)
        alpha_video = th.linspace(0, 1, args.nframes).view(args.nframes, 1)
        alpha_video = alpha_video.to(device)
        length = [args.nsteps, args.nframes]
        for idx, a in enumerate([alpha, alpha_video]):
            # Linear interpolation between z0 and z1, one row per alpha.
            _z0 = z0.repeat(length[idx], 1).to(device)
            _z1 = z1.repeat(length[idx], 1).to(device)
            batch = _z0 * (1 - a) + _z1 * a
            out = model(batch)

            if idx == 0:
                # image viz: lay the n interpolation steps out side-by-side
                # into a single 1 x c x h x (w*n) strip.
                n, c, h, w = out.shape
                out = out.permute(1, 2, 0, 3)
                out = out.contiguous().view(1, c, h, w * n)
                out = postprocess(out, invert=args.invert)
                imsave(out, os.path.join(outdir, "latent_interp",
                                         "%03d.png" % i))
                # Also export the vector (SVG) version of each step.
                scenes = model.get_vector(batch)
                for scn_idx, scn in enumerate(scenes):
                    save_scene(
                        scn,
                        os.path.join(outdir, "latent_interp_svg",
                                     "%03d" % i, "%03d.svg" % scn_idx))
            else:
                # video viz: dump one PNG per frame, then assemble with ffmpeg.
                anim_root = os.path.join(outdir, "latent_interp_video",
                                         "%03d" % i)
                LOG.info("Rendering animation %d", i)
                for frame_idx, frame in enumerate(out):
                    LOG.info("frame %d", frame_idx)
                    frame = frame.unsqueeze(0)
                    frame = postprocess(frame, invert=args.invert)
                    imsave(
                        frame,
                        os.path.join(anim_root,
                                     "frame%04d.png" % frame_idx))
                # NOTE(review): requires `ffmpeg` on PATH; the return code is
                # not checked, so a failed encode goes unnoticed.
                call([
                    "ffmpeg", "-framerate", "30", "-i",
                    os.path.join(anim_root, "frame%04d.png"),
                    "-vb", "20M",
                    os.path.join(outdir, "latent_interp_video",
                                 "%03d.mp4" % i)
                ])
        LOG.info(" saved %d", i)

    LOG.info("Sampling latent space")
    for i in range(args.nsamples):
        # Render an n x n grid of random samples as one big image.
        n = 8
        bs = n * n
        z = model.sample_z(bs).to(device)
        out = model(z)
        _, c, h, w = out.shape
        out = out.view(n, n, c, h, w).permute(2, 0, 3, 1, 4)
        out = out.contiguous().view(1, c, h * n, w * n)
        # NOTE(review): unlike the interpolation path above, this call does
        # not pass invert=args.invert — possibly intentional, verify.
        out = postprocess(out)
        imsave(out, os.path.join(outdir, "samples_%03d.png" % i))
        LOG.info(" saved %d", i)

    LOG.info("output images saved to %s", outdir)
def main(args):
    """Entrypoint to the training.

    Builds the demosaicking dataset/model pair described either by an
    existing checkpoint's metadata or, failing that, by the command line,
    then runs the ttools training loop with progress, Visdom logging and
    periodic checkpointing callbacks.

    Args:
        args: parsed command-line namespace; reads ``checkpoint_dir``,
            ``data``, ``val_data``, ``depth``, ``width``, ``mode``, ``bs``,
            ``num_worker_threads``, ``lr``, ``cuda``, ``server``, ``env``,
            ``port``, ``num_epochs``.

    Raises:
        ValueError: if the resolved mode is neither Bayer nor X-Trans.
    """
    # Load model parameters from checkpoint, if any; the checkpoint's
    # metadata takes precedence over the command line so resuming cannot
    # silently change the architecture.
    meta = ttools.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        LOG.info("No metadata or checkpoint, "
                 "parsing model parameters from command line.")
        meta = {
            "depth": args.depth,
            "width": args.width,
            "mode": args.mode,
        }

    data = demosaicnet.Dataset(args.data, download=False, mode=meta["mode"],
                               subset=demosaicnet.TRAIN_SUBSET)
    dataloader = DataLoader(
        data, batch_size=args.bs, num_workers=args.num_worker_threads,
        pin_memory=True, shuffle=True)

    val_dataloader = None
    if args.val_data:
        # NOTE(review): this reads from args.data (with the VAL subset), not
        # args.val_data — confirm val_data is meant as an on/off flag only.
        val_data = demosaicnet.Dataset(args.data, download=False,
                                       mode=meta["mode"],
                                       subset=demosaicnet.VAL_SUBSET)
        val_dataloader = DataLoader(val_data, batch_size=args.bs,
                                    num_workers=1, pin_memory=True,
                                    shuffle=False)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)
    else:
        # BUGFIX: an unrecognized mode previously fell through and crashed
        # later with a NameError on `model`; fail fast with a clear message.
        raise ValueError("Unknown demosaicking mode: %s" % meta["mode"])

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)
    interface = DemosaicnetInterface(model, lr=args.lr, cuda=args.cuda)
    checkpointer.load_latest()  # Resume from checkpoint, if any.

    trainer = ttools.Trainer(interface)

    keys = ["loss", "psnr"]
    val_keys = ["psnr"]
    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=keys, val_keys=val_keys))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=keys, val_keys=val_keys,
                                               server=args.server,
                                               env=args.env, port=args.port))
    trainer.add_callback(
        ImageCallback(server=args.server, env=args.env, win="images",
                      port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(checkpointer, max_files=8,
                                               interval=3600, max_epochs=10))

    if args.cuda:
        LOG.info("Training with CUDA enabled")
    else:
        LOG.info("Training on CPU")

    trainer.train(dataloader, num_epochs=args.num_epochs,
                  val_dataloader=val_dataloader)