Ejemplo n.º 1
0
def train(args):
    """Train a GAN on MNIST digits."""

    # Training data: resize digits to 32x32 and convert to tensors.
    transform = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    dataset = MNIST("data/mnist", train=True, download=True, transform=transform)

    # Asynchronous loader for the training set.
    loader = DataLoader(dataset, batch_size=args.bs, num_workers=2)

    # Build generator/discriminator and initialize their weights.
    gen = Generator()
    discrim = Discriminator()
    for net in (gen, discrim):
        net.apply(weights_init_normal)

    # Checkpointers persist/restore each model's parameters.
    chkpt_dir = os.path.join(args.out, "checkpoints")
    checkpointer_gen = ttools.Checkpointer(chkpt_dir, model=gen, prefix="gen_")
    checkpointer_discrim = ttools.Checkpointer(chkpt_dir, model=discrim, prefix="discrim_")

    # Resume from the most recent checkpoints, if any exist.
    checkpointer_gen.load_latest()
    checkpointer_discrim.load_latest()

    # The interface wraps the forward/backward logic for both models.
    interface = MNISTInterface(gen, discrim, lr=args.lr)

    # The trainer drives the optimization loop via callbacks.
    trainer = ttools.Trainer(interface)

    # Periodic checkpointing for both models.
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_gen))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_discrim))
    # Console progress bar.
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=["loss_g", "loss_d", "loss"]))
    # Loss curves in visdom.
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=["loss_g", "loss_d", "loss"],
        port=8080, env="mnist_demo"))
    # Periodic image previews in visdom.
    trainer.add_callback(VisdomImageCallback(port=8080, env="mnist_demo"))

    # Run the training loop until interrupted.
    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(loader, num_epochs=args.epochs)
Ejemplo n.º 2
0
def main(args):
    """Train the primitives model on ShapeNet data."""
    # Shuffled training loader; incomplete batches are dropped.
    train_set = datasets.ShapenetDataset(args.data, args.canvas_size)
    dataloader = DataLoader(
        train_set, batch_size=args.bs, num_workers=args.num_worker_threads,
        worker_init_fn=_worker_init_fn, shuffle=True, drop_last=True)
    LOG.info(train_set)

    # Validation split, loaded one sample at a time.
    val_set = datasets.ShapenetDataset(args.data, args.canvas_size, val=True)
    val_dataloader = DataLoader(val_set)

    # Rounded primitives carry one extra parameter each.
    out_dim = (11 if args.rounded else 10) * args.n_primitives
    model = PrimsModel(output_dim=out_dim)

    # Resume from the latest checkpoint when available.
    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model)
    checkpointer.load_latest()

    interface = VectorizerInterface(
        model, args.lr, args.n_primitives, args.canvas_size, args.w_surface,
        args.w_alignment, args.csg, args.rounded, cuda=args.cuda)

    keys = ['loss', 'surfaceloss', 'alignmentloss']

    # Time-stamped TensorBoard writers: one for train, one for validation.
    writer = SummaryWriter(
        os.path.join(args.checkpoint_dir, 'summaries',
                     datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')),
        flush_secs=1)
    val_writer = SummaryWriter(
        os.path.join(args.checkpoint_dir, 'summaries',
                     datetime.datetime.now().strftime('val-%m%d%y-%H%M%S')),
        flush_secs=1)

    trainer = ttools.Trainer(interface)
    trainer.add_callback(ttools.callbacks.TensorBoardLoggingCallback(
        keys=keys, writer=writer, val_writer=val_writer, frequency=5))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(keys=keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer, max_files=1))
    trainer.train(dataloader, num_epochs=args.num_epochs, val_dataloader=val_dataloader)
Ejemplo n.º 3
0
def interpolate(args):
    """Render an animation interpolating between random latent samples.

    Loads the latest VAE generator checkpoint, draws one latent vector per
    keyframe, linearly interpolates between consecutive keyframes and writes
    the decoded frames as PNGs under ``<checkpoint>/interpolation``.
    """
    chkpt = VAE_OUTPUT
    if args.conditional:
        chkpt += "_conditional"
    if args.fc:
        chkpt += "_fc"

    meta = ttools.Checkpointer.load_meta(chkpt, prefix="g_")
    if meta is None:
        LOG.info("No metadata in checkpoint (or no checkpoint), aborting.")
        return

    # Rebuild the model with the checkpointed hyper-parameters.
    model = VectorMNISTVAE(imsize=128, **meta)
    checkpointer = ttools.Checkpointer(chkpt, model, prefix="g_")
    checkpointer.load_latest()
    model.eval()

    # Sample one latent vector per keyframe.
    bs = 10
    z = th.randn(bs, model.zdim)

    # One digit class per keyframe (used by conditional models only).
    # BUG FIX: removed a dead `label = None` that was immediately overwritten.
    label = th.arange(0, 10)

    animation = []
    nframes = 60  # frames per interpolation segment
    with th.no_grad():
        for idx, _z in enumerate(z):
            if idx == 0:  # need a previous keyframe to interpolate from
                continue
            _z0 = z[idx-1].unsqueeze(0).repeat(nframes, 1)
            _z = _z.unsqueeze(0).repeat(nframes, 1)
            if args.conditional:
                _label = label[idx].unsqueeze(0).repeat(nframes)
            else:
                _label = None

            # Linear interpolation weights in [0, 1].
            alpha = th.linspace(0, 1, nframes).view(nframes, 1)
            batch = alpha*_z + (1.0 - alpha)*_z0
            images, aux = model.decode(batch, label=_label)
            # Map decoder output from [-1, 1] to [0, 1].
            images += 1.0
            images /= 2.0
            animation.append(images)

    anim_dir = os.path.join(chkpt, "interpolation")
    os.makedirs(anim_dir, exist_ok=True)
    animation = th.cat(animation, 0)
    for idx, frame in enumerate(animation):
        frame = frame.squeeze()
        frame = th.clamp(frame, 0, 1).cpu().numpy()
        path = os.path.join(anim_dir, "frame%03d.png" % idx)
        pydiffvg.imwrite(frame, path, gamma=2.2)

    LOG.info("Results saved to %s", anim_dir)
Ejemplo n.º 4
0
def train(args):
    """Train a MNIST classifier."""

    # Train/val datasets, resized to 32x32 tensors.
    transform = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    train_set = MNIST("data/mnist", train=True, download=True, transform=transform)
    val_set = MNIST("data/mnist", train=False, transform=transform)

    # Asynchronous loaders for both splits.
    loader = DataLoader(train_set, batch_size=args.bs, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=16, num_workers=1)

    # The model to optimize.
    model = MNISTClassifier()

    # Checkpointer persists/restores the model parameters.
    checkpointer = ttools.Checkpointer(
        os.path.join(args.out, "checkpoints"), model=model, prefix="classifier_")

    # Resume from the latest checkpoint, if one exists.
    checkpointer.load_latest()

    # Run on GPU when available.
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    interface = MNISTInterface(model, device, lr=args.lr)

    # The trainer drives the optimization loop via callbacks.
    trainer = ttools.Trainer(interface)

    LOG.info("This demo uses a Visdom to display the loss and accuracy, make sure you have a visdom server running! ('make visdom_server')")
    # Periodic checkpointing.
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer))
    # Console progress bar.
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=["loss", "accuracy"], val_keys=["loss", "accuracy"]))
    # Loss/accuracy curves in visdom.
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=["loss", "accuracy"], val_keys=["loss", "accuracy"],
        port=8080, env="mnist_demo"))

    # Run the training loop until interrupted.
    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(loader, num_epochs=args.epochs, val_dataloader=val_loader)
Ejemplo n.º 5
0
def main(args):
    """Vectorize a letter image into curves and save a PDF overlay."""
    device = "cuda" if th.cuda.is_available() and args.cuda else "cpu"

    model = CurvesModel(sum(templates.topology))
    model.to(device)
    model.eval()

    # Restore the latest weights for the requested model.
    checkpointer = ttools.Checkpointer(f'models/{args.model}', model)
    extras, _ = checkpointer.load_latest()
    if extras is None:
        print("Unable to load checkpoint")
    else:
        print(f"Loaded checkpoint (epoch {extras['epoch']})")

    # Grayscale input image, resized to the network resolution.
    im = to_tensor(Image.open(args.image).convert('L').resize(
        (128, 128))).to(device)
    # One-hot encoding of the target letter.
    letter_idx = string.ascii_uppercase.index(args.letter)
    z = th.zeros(len(string.ascii_uppercase)).scatter_(
        0, th.tensor(letter_idx), 1).to(device)

    print(f"Processing image {args.image} (letter {args.letter})")

    curves = model(im[None], z[None])['curves'][0].detach().cpu()

    # White 128x128 canvas in normalized [0, 1] coordinates.
    surface = cairo.PDFSurface(args.out, 128, 128)
    ctx = cairo.Context(surface)
    ctx.scale(128, 128)
    ctx.rectangle(0, 0, 1, 1)
    ctx.set_source_rgb(1, 1, 1)
    ctx.fill()

    # Paint the input image underneath the predicted curves.
    ctx.save()
    im = cairo.ImageSurface.create_from_png(args.image)
    ctx.scale(1 / 128, 1 / 128)
    ctx.set_source_surface(im)
    ctx.paint()
    ctx.restore()

    draw_curves(curves, templates.n_loops[args.letter], ctx)

    surface.finish()

    print(f"Output saved to {args.out}")
Ejemplo n.º 6
0
def main(args):
    """Train the font curves model."""
    # Shuffled training loader; incomplete batches are dropped.
    train_set = datasets.FontsDataset(args.data, args.chamfer, args.n_samples_per_curve)
    dataloader = DataLoader(
        train_set, batch_size=args.bs, num_workers=args.num_worker_threads,
        worker_init_fn=_worker_init_fn, shuffle=True, drop_last=True)
    LOG.info(train_set)

    # Validation split, one sample per batch.
    val_set = datasets.FontsDataset(args.data, args.chamfer, args.n_samples_per_curve, val=True)
    val_dataloader = DataLoader(val_set)

    model = CurvesModel(n_curves=sum(templates.topology))

    # Resume from the latest checkpoint when available.
    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model)
    checkpointer.load_latest()

    interface = VectorizerInterface(
        model, args.simple_templates, args.lr, args.max_stroke,
        args.canvas_size, args.chamfer, args.n_samples_per_curve,
        args.w_surface, args.w_template, args.w_alignment, cuda=args.cuda)

    # Logged quantities depend on the training objective.
    if args.chamfer:
        keys = ['loss', 'chamferloss', 'templateloss']
    else:
        keys = ['loss', 'surfaceloss', 'alignmentloss', 'templateloss']

    # Time-stamped TensorBoard writers: one for train, one for validation.
    writer = SummaryWriter(
        os.path.join(args.checkpoint_dir, 'summaries',
                     datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')),
        flush_secs=1)
    val_writer = SummaryWriter(
        os.path.join(args.checkpoint_dir, 'summaries',
                     datetime.datetime.now().strftime('val-%m%d%y-%H%M%S')),
        flush_secs=1)

    trainer = ttools.Trainer(interface)
    trainer.add_callback(ttools.callbacks.TensorBoardLoggingCallback(
        keys=keys, writer=writer, val_writer=val_writer, frequency=5))
    trainer.add_callback(callbacks.InputImageCallback(
        writer=writer, val_writer=val_writer, frequency=100))
    trainer.add_callback(callbacks.CurvesCallback(
        writer=writer, val_writer=val_writer, frequency=100))
    # Rendering previews only apply to the surface objective.
    if not args.chamfer:
        trainer.add_callback(callbacks.RenderingCallback(
            writer=writer, val_writer=val_writer, frequency=100))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(keys=keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer, max_files=1))
    trainer.train(dataloader, num_epochs=args.num_epochs, val_dataloader=val_dataloader)
Ejemplo n.º 7
0
def train(args):
    """Train a SketchRNN model on the QuickDraw dataset.

    Resumes from the latest checkpoint when one exists and its
    hyper-parameters match the CLI; otherwise aborts to avoid
    overwriting the checkpoint with a mismatched model.
    """
    # Deterministic runs.
    th.manual_seed(0)
    np.random.seed(0)

    dataset = data.QuickDrawDataset(args.dataset)
    dataloader = DataLoader(dataset,
                            batch_size=args.bs,
                            num_workers=4,
                            shuffle=True,
                            pin_memory=False)

    # Small fixed validation subset: the first 8 samples.
    val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
    val_dataloader = DataLoader(val_dataset,
                                batch_size=8,
                                num_workers=4,
                                shuffle=False,
                                pin_memory=False)

    model_params = {
        "zdim": args.zdim,
        "num_gaussians": args.num_gaussians,
        "encoder_dim": args.encoder_dim,
        "decoder_dim": args.decoder_dim,
    }
    model = SketchRNN(**model_params)
    model.train()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model,
                          lr=args.lr,
                          lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay,
                          kl_weight=args.kl_weight,
                          sampling_temperature=args.sampling_temperature,
                          device=device)

    chkpt = OUTPUT_BASELINE
    env_name = "sketch_rnn"

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(chkpt,
                                       model,
                                       meta=model_params,
                                       optimizers=interface.optimizers(),
                                       schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    if meta is not None and meta != model_params:
        LOG.info(
            "Checkpoint's metaparams differ "
            "from CLI, aborting: %s and %s", meta, model_params)
        # BUG FIX: the message promised an abort but the original fell
        # through and kept training with mismatched hyper-parameters.
        return

    trainer = ttools.Trainer(interface)

    # Add callbacks
    losses = ["loss", "kl_loss", "recons_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=losses, val_keys=None))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=losses,
                                               val_keys=None,
                                               env=env_name,
                                               port=args.port))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=training_debug,
                                               smoothing=0,
                                               val_keys=None,
                                               env=env_name,
                                               port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(checkpointer,
                                               max_files=2,
                                               interval=600,
                                               max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    trainer.add_callback(
        SketchRNNCallback(env=env_name,
                          win="samples",
                          port=args.port,
                          frequency=args.freq))

    # Start training
    trainer.train(dataloader,
                  starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
Ejemplo n.º 8
0
def generate_samples(args):
    """Sample a grid of digits from a trained vector-MNIST VAE.

    Writes three artifacts into the checkpoint directory: a raster grid of
    decoded samples (samples.png), the reference digits used for
    auto-encoding (ref.png), and a merged vector version of all decoded
    digits (digits.svg).
    """
    # The checkpoint directory carries the model-variant suffixes.
    chkpt = VAE_OUTPUT
    if args.conditional:
        chkpt += "_conditional"
    if args.fc:
        chkpt += "_fc"

    meta = ttools.Checkpointer.load_meta(chkpt, prefix="g_")
    if meta is None:
        LOG.info("No metadata in checkpoint (or no checkpoint), aborting.")
        return

    # Rebuild the model with the checkpointed hyper-parameters.
    model = VectorMNISTVAE(**meta)
    checkpointer = ttools.Checkpointer(chkpt, model, prefix="g_")
    checkpointer.load_latest()
    model.eval()

    # Sample some latent vectors (replaced below when auto-encoding).
    n = 8  # grid side: n*n digits total
    bs = n*n
    z = th.randn(bs, model.zdim)

    imsize = 28
    dataset = Dataset(args.data_dir, imsize)
    dataloader = DataLoader(dataset, batch_size=bs,
                            num_workers=1, shuffle=True)

    # Grab a single batch of reference digits and labels.
    for batch in dataloader:
        ref, label = batch
        break

    # Hard-coded switch: encode the reference batch instead of sampling
    # from the prior.
    autoencode = True
    if autoencode:
        LOG.info("Sampling with auto-encoder code")
        if not args.conditional:
            label = None
        mu, logvar = model.encode(ref, label)
        z = model.reparameterize(mu, logvar)
    else:
        label = None
        if args.conditional:
            # Random digit classes in [0, 9], optionally pinned to one digit.
            label = th.clamp(th.rand(bs)*10, 0, 9).long()
            if args.digit is not None:
                label[:] = args.digit

    with th.no_grad():
        images, aux = model.decode(z, label=label)
        scenes = aux["scenes"]
    # Map decoder output from [-1, 1] to [0, 1] — assumes tanh-like
    # output range; TODO confirm against the decoder.
    images += 1.0
    images /= 2.0

    h = w = model.imsize

    # Tile the bs digits into one (n*h, n*w) mosaic.
    images = images.view(n, n, h, w).permute(0, 2, 1, 3)
    images = images.contiguous().view(n*h, n*w)
    images = th.clamp(images, 0, 1).cpu().numpy()
    path = os.path.join(chkpt, "samples.png")
    pydiffvg.imwrite(images, path, gamma=2.2)

    if autoencode:
        # Save the reference digits with the same normalization and layout.
        ref += 1.0
        ref /= 2.0
        ref = ref.view(n, n, h, w).permute(0, 2, 1, 3)
        ref = ref.contiguous().view(n*h, n*w)
        ref = th.clamp(ref, 0, 1).cpu().numpy()
        path = os.path.join(chkpt, "ref.png")
        pydiffvg.imwrite(ref, path, gamma=2.2)

    # Merge all per-digit vector scenes into a single canvas.
    all_shapes = []
    all_shape_groups = []
    cur_id = 0
    for idx, s in enumerate(scenes):
        shapes, shape_groups, _ = s
        # width, height = sizes

        # Shift each digit to its cell on the n x n canvas (in place).
        center_x = idx % n
        center_y = idx // n
        for shape in shapes:
            shape.points[:, 0] += center_x * model.imsize
            shape.points[:, 1] += center_y * model.imsize
            all_shapes.append(shape)
        for grp in shape_groups:
            # Re-index shape ids so groups remain valid after merging.
            grp.shape_ids[:] = cur_id
            cur_id += 1
            all_shape_groups.append(grp)

    LOG.info("Generated %d shapes", len(all_shapes))

    fname = os.path.join(chkpt, "digits.svg")
    pydiffvg.save_svg(fname, n*model.imsize, n*model.imsize, all_shapes,
                      all_shape_groups, use_gamma=False)

    LOG.info("Results saved to %s", chkpt)
Ejemplo n.º 9
0
def train(args):
    """Train a vector MNIST generator (VAE or AE).

    Builds the model selected by ``args.generator``, resumes from the
    matching checkpoint directory when the stored hyper-parameters agree
    with the CLI, and runs the training loop with visdom logging and
    periodic checkpointing.
    """
    # Deterministic runs.
    th.manual_seed(0)
    np.random.seed(0)

    pydiffvg.set_use_gpu(args.cuda)

    # Initialize datasets
    imsize = 28
    dataset = Dataset(args.data_dir, imsize)
    dataloader = DataLoader(dataset, batch_size=args.bs,
                            num_workers=4, shuffle=True)

    if args.generator in ["vae", "ae"]:
        LOG.info("Vector config:\n  samples %d\n"
                 "  paths: %d\n  segments: %d\n"
                 "  zdim: %d\n"
                 "  conditional: %d\n"
                 "  fc: %d\n",
                 args.samples, args.paths, args.segments,
                 args.zdim, args.conditional, args.fc)

    model_params = dict(samples=args.samples, paths=args.paths,
                        segments=args.segments, conditional=args.conditional,
                        zdim=args.zdim, fc=args.fc)

    if args.generator == "vae":
        model = VectorMNISTVAE(variational=True, **model_params)
        chkpt = VAE_OUTPUT
        name = "mnist_vae"
    elif args.generator == "ae":
        model = VectorMNISTVAE(variational=False, **model_params)
        chkpt = AE_OUTPUT
        name = "mnist_ae"
    else:
        raise ValueError("unknown generator")

    # Model variants get their own env/checkpoint suffixes.
    if args.conditional:
        name += "_conditional"
        chkpt += "_conditional"

    if args.fc:
        name += "_fc"
        chkpt += "_fc"

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(
        chkpt, model, meta=model_params, prefix="g_")
    extras, meta = checkpointer.load_latest()

    if meta is not None and meta != model_params:
        LOG.info("Checkpoint's metaparams differ from CLI, aborting: %s and %s",
                 meta, model_params)
        # BUG FIX: the message promised an abort but the original fell
        # through and trained anyway, clobbering the mismatched checkpoint.
        return

    # Hook interface
    if args.generator in ["vae", "ae"]:
        variational = args.generator == "vae"
        if variational:
            LOG.info("Using a VAE")
        else:
            LOG.info("Using an AE")
        interface = VAEInterface(model, lr=args.lr, cuda=args.cuda,
                                 variational=variational, w_kld=args.kld_weight)

    trainer = ttools.Trainer(interface)

    # Add callbacks; the logged keys depend on the objective.
    keys = ["loss_g", "loss_d"]
    if args.generator == "vae":
        keys = ["kld", "data_loss", "loss"]
    elif args.generator == "ae":
        keys = ["data_loss", "loss"]
    port = 8097
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=keys, val_keys=keys))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=keys, val_keys=keys, env=name, port=port))
    trainer.add_callback(MNISTCallback(
        env=name, win="samples", port=port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=50))

    # Start training
    trainer.train(dataloader, num_epochs=args.num_epochs)
Ejemplo n.º 10
0
def main(args):
    """Denoise an input rendering with a pretrained sample-based denoiser.

    Loads model metadata and weights from ``args.checkpoint``, runs the
    denoiser tile-by-tile over every scene found under ``args.input`` and
    writes the result to ``args.output`` (an .exr plus a .png preview).

    Raises:
        ValueError: if ``args.input`` does not exist.
    """
    start = time.time()
    if not os.path.exists(args.input):
        raise ValueError("input {} does not exist".format(args.input))

    # The dataset loader expects a directory of scene folders: expose the
    # input through a symlink inside a throwaway temp directory.
    data_root = os.path.abspath(args.input)
    name = os.path.basename(data_root)
    tmpdir = tempfile.mkdtemp()
    os.symlink(data_root, os.path.join(tmpdir, name))

    # BUG FIX: wrapped in try/finally so the temp directory (and symlink)
    # is removed even when loading or denoising raises.
    try:
        LOG.info("Loading model {}".format(args.checkpoint))
        meta_params = ttools.Checkpointer.load_meta(args.checkpoint)

        LOG.info("Setting up dataloader")
        data_params = meta_params["data_params"]
        if args.spp:
            data_params["spp"] = args.spp  # optional sample-count override
        data = sbmc.FullImagesDataset(tmpdir, **data_params)
        dataloader = DataLoader(data, batch_size=1, shuffle=False,
                                num_workers=0)

        LOG.info("Denoising input with {} spp".format(data_params["spp"]))

        # Pick the architecture recorded in the checkpoint metadata.
        kpcn_mode = meta_params["kpcn_mode"]
        if kpcn_mode:
            LOG.info("Using [Bako2017] denoiser.")
            model = sbmc.KPCN(data.num_features)
        else:
            model = sbmc.Multisteps(data.num_features,
                                    data.num_global_features)

        model.train(False)
        device = "cpu"
        cuda = th.cuda.is_available()
        if cuda:
            LOG.info("Using CUDA")
            model.cuda()
            device = "cuda"

        checkpointer = ttools.Checkpointer(args.checkpoint, model, None)
        extras, meta = checkpointer.load_latest()
        LOG.info("Loading latest checkpoint {}".format(
            "failed" if meta is None else "success"))

        elapsed = (time.time() - start) * 1000
        LOG.info("setup time {:.1f} ms".format(elapsed))

        LOG.info("starting the denoiser")
        for scene_id, batch in enumerate(dataloader):
            for k in batch.keys():
                batch[k] = batch[k].to(device)
            scene = os.path.basename(data.scenes[scene_id])
            LOG.info("  scene {}".format(scene))
            # Split the frame into padded tiles to bound peak memory.
            tile_sz = args.tile_size
            tile_pad = args.tile_pad
            batch_parts = _split_tiles(batch, max_sz=tile_sz, pad=tile_pad)
            out_radiance = th.zeros_like(batch["low_spp"])

            if cuda:
                th.cuda.synchronize()  # accurate GPU timing
            start = time.time()
            for part, start_y, end_y, start_x, end_x, pad_ in batch_parts:
                with th.no_grad():
                    out_ = model(part)
                    out_ = _pad(part, out_["radiance"], kpcn_mode)
                    # Trim the tile padding before stitching back.
                    out_ = out_[..., pad_[0]:out_.shape[-2] - pad_[1],
                                pad_[2]:out_.shape[-1] - pad_[3]]
                    out_radiance[..., start_y:end_y, start_x:end_x] = out_
            if cuda:
                th.cuda.synchronize()
            elapsed = (time.time() - start) * 1000
            LOG.info("    denoising time {:.1f} ms".format(elapsed))

            # CHW -> HWC layout for image output.
            out_radiance = out_radiance[0, ...].cpu().numpy().transpose(
                [1, 2, 0])
            outdir = os.path.dirname(args.output)
            os.makedirs(outdir, exist_ok=True)
            pyexr.write(args.output, out_radiance)

            # Clipped 8-bit preview alongside the EXR.
            png = args.output.replace(".exr", ".png")
            skio.imsave(png,
                        (np.clip(out_radiance, 0, 1) * 255).astype(np.uint8))
    finally:
        shutil.rmtree(tmpdir)
Ejemplo n.º 11
0
def train(args):
    """Train raster (and optionally vector) GANs on MNIST or QuickDraw.

    Builds a raster generator/discriminator pair and, unless
    ``args.raster_only`` is set, a vector generator/discriminator pair as
    well; resumes all models from checkpoints under ``OUTPUT/<env_name>``
    and runs the training loop with visdom logging.
    """
    # Deterministic runs.
    th.manual_seed(0)
    np.random.seed(0)

    # Both supported tasks are grayscale.
    color_output = False
    if args.task == "mnist":
        dataset = data.MNISTDataset(args.raster_resolution, train=True)
    elif args.task == "quickdraw":
        dataset = data.QuickDrawImageDataset(
            args.raster_resolution, train=True)
    else:
        raise NotImplementedError()

    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)

    # No validation set is used for this GAN training.
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "num_strokes": args.num_strokes,
        "imsize": args.raster_resolution,
        "stroke_width": args.stroke_width,
        "color_output": color_output,
    }
    gen = models.Generator(**model_params)
    gen.train()

    discrim = models.Discriminator(color_output=color_output)
    discrim.train()

    # Optional vector branch: the generator architecture is chosen by name.
    if args.raster_only:
        vect_gen = None
        vect_discrim = None
    else:
        if args.generator == "fc":
            vect_gen = models.VectorGenerator(**model_params)
        elif args.generator == "bezier_fc":
            vect_gen = models.BezierVectorGenerator(**model_params)
        elif args.generator in ["rnn"]:
            vect_gen = models.RNNVectorGenerator(**model_params)
        elif args.generator in ["chain_rnn"]:
            vect_gen = models.ChainRNNVectorGenerator(**model_params)
        else:
            raise NotImplementedError()
        vect_gen.train()

        vect_discrim = models.Discriminator(color_output=color_output)
        vect_discrim.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(gen, vect_gen, discrim, vect_discrim,
                          raster_resolution=args.raster_resolution, lr=args.lr,
                          wgan_gp=args.wgan_gp,
                          lr_decay=args.lr_decay, device=device)

    # The visdom environment name encodes the experiment configuration.
    env_name = args.task + "_gan"

    if args.raster_only:
        env_name += "_raster"
    else:
        env_name += "_vector"

    env_name += "_" + args.generator

    if args.wgan_gp:
        env_name += "_wgan"

    chkpt = os.path.join(OUTPUT, env_name)

    meta = {
        "model_params": model_params,
        "task": args.task,
        "generator": args.generator,
    }
    # One checkpointer per model; only the generators carry metadata.
    checkpointer = ttools.Checkpointer(
        chkpt, gen, meta=meta,
        optimizers=interface.optimizers,
        schedulers=interface.schedulers,
        prefix="g_")
    checkpointer_d = ttools.Checkpointer(
        chkpt, discrim, 
        prefix="d_")

    # Resume from checkpoint, if any
    extras, _ = checkpointer.load_latest()
    checkpointer_d.load_latest()

    if not args.raster_only:
        checkpointer_vect = ttools.Checkpointer(
            chkpt, vect_gen, meta=meta,
            optimizers=interface.optimizers,
            schedulers=interface.schedulers,
            prefix="vect_g_")
        checkpointer_d_vect = ttools.Checkpointer(
            chkpt, vect_discrim, 
            prefix="vect_d_")
        # NOTE(review): this overwrites `extras` from the raster
        # checkpointer, so the resume epoch below comes from the vector
        # generator whenever both branches are trained — confirm intended.
        extras, _ = checkpointer_vect.load_latest()
        checkpointer_d_vect.load_latest()

    # Resume epoch defaults to 0 when no checkpoint (or no epoch) exists.
    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    # if meta is not None and meta["model_parameters"] != model_params:
    #     LOG.info("Checkpoint's metaparams differ "
    #              "from CLI, aborting: %s and %s", meta, model_params)

    trainer = ttools.Trainer(interface)

    # Add callbacks
    losses = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect", "gp",
              "gp_vect"]
    training_debug = ["lr"]

    trainer.add_callback(Callback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=losses, val_keys=None))
    trainer.add_callback(ttools.callbacks.MultiPlotCallback(
        keys=losses, val_keys=None, env=env_name, port=args.port,
        server=args.server, base_url=args.base_url,
        win="losses", frequency=args.freq))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        server=args.server, base_url=args.base_url,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer_d, max_files=2, interval=600, max_epochs=10))

    if not args.raster_only:
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_vect, max_files=2, interval=600, max_epochs=10))
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_d_vect, max_files=2, interval=600, max_epochs=10))

    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
Ejemplo n.º 12
0
def main(args):
    """Train a sample-based Monte Carlo denoiser."""
    # Deterministic runs.
    th.manual_seed(0)
    np.random.seed(0)

    # Dataset options shared by the training and validation sets.
    tile_mode = (sbmc.TilesDataset.KPCN_MODE if args.kpcn_mode
                 else sbmc.TilesDataset.SBMC_MODE)
    data_args = dict(spp=args.spp,
                     mode=tile_mode,
                     load_coords=args.load_coords,
                     load_gbuffer=args.load_gbuffer,
                     load_p=args.load_p,
                     load_ld=args.load_ld,
                     load_bt=args.load_bt)

    if args.randomize_spp:
        # Variable sample counts only work with unit batches.
        if args.bs != 1:
            LOG.error(
                "Training with randomized spp is only valid for"
                "batch_size=1, got %d", args.bs)
            raise RuntimeError("Incorrect batch size")
        data = sbmc.MultiSampleCountDataset(args.data, **data_args)
        LOG.info("Training with randomized sample count in [%d, %d]" %
                 (2, args.spp))
    else:
        data = sbmc.TilesDataset(args.data, **data_args)
        LOG.info("Training with a single sample count: %dspp" % args.spp)

    # Select the architecture: pixel-based baseline or sample-based model.
    if args.kpcn_mode:
        LOG.info("Model: pixel-based comparison from [Bako2017]")
        model = sbmc.KPCN(data.num_features, ksize=args.ksize)
        model_params = dict(ksize=args.ksize)
    else:
        LOG.info("Model: sample-based [Gharbi2019]")
        model = sbmc.Multisteps(data.num_features, data.num_global_features,
                                ksize=args.ksize, splat=not args.gather,
                                pixel=args.pixel)
        model_params = dict(ksize=args.ksize, gather=args.gather,
                            pixel=args.pixel)

    dataloader = DataLoader(data, batch_size=args.bs,
                            num_workers=args.num_worker_threads,
                            shuffle=True)

    # Validation (when provided) uses a constant sample count.
    val_dataloader = None
    if args.val_data is not None:
        LOG.info("Validation set with %dspp" % args.spp)
        val_data = sbmc.TilesDataset(args.val_data, **data_args)
        val_dataloader = DataLoader(val_data, batch_size=args.bs,
                                    num_workers=1, shuffle=False)
    else:
        LOG.info("No validation set provided")

    # Metadata saved alongside the checkpoints.
    meta = dict(model_params=model_params, kpcn_mode=args.kpcn_mode,
                data_params=data_args)

    LOG.info("Model configuration: {}".format(model_params))

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)

    interface = sbmc.SampleBasedDenoiserInterface(model, lr=args.lr,
                                                  cuda=args.cuda)

    # Resume from the latest checkpoint, if any.
    extras, meta = checkpointer.load_latest()

    trainer = ttools.Trainer(interface)

    # Callbacks: progress bar, checkpointing, visdom curves and previews.
    log_keys = ["loss", "rmse"]
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(log_keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        log_keys, env=args.env, port=args.port, log=True, frequency=100))
    trainer.add_callback(sbmc.DenoisingDisplayCallback(
        env=args.env, port=args.port, win="images"))

    # Launch the training loop.
    LOG.info("Training started, 'Ctrl+C' to abort.")
    trainer.train(dataloader, num_epochs=args.num_epochs,
                  val_dataloader=val_dataloader)
Ejemplo n.º 13
0
def main(args):
    """Evaluate a demosaicking model on the test set.

    Restores the model hyper-parameters and latest weights from
    ``args.checkpoint_dir``, runs inference over the test split of
    ``args.data`` and logs per-image and average PSNR / MSE.
    """

    # Load model parameters from checkpoint; without them we do not know
    # which architecture to instantiate.
    meta = ttools.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        LOG.warning("No checkpoint found at %s, aborting.",
                    args.checkpoint_dir)
        return

    data = demosaicnet.Dataset(args.data,
                               download=False,
                               mode=meta["mode"],
                               subset=demosaicnet.TEST_SUBSET)
    dataloader = DataLoader(data,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=True,
                            shuffle=True)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)
    else:
        # Previously an unrecognized mode fell through and crashed later
        # with a NameError on `model`; fail early with a clear message.
        raise ValueError("Unknown demosaicking mode: %s" % meta["mode"])

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)
    checkpointer.load_latest()  # Resume from checkpoint, if any.

    # Inference only: freeze parameters and switch off training-mode layers.
    model.eval()
    for p in model.parameters():
        p.requires_grad = False

    mse_fn = th.nn.MSELoss()
    psnr_fn = PSNR()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")
    # Inputs are moved to `device` below, so the model must live there too
    # (this was missing and crashed on CUDA machines).
    model.to(device)

    count = 0
    mse = 0.0
    psnr = 0.0
    for idx, batch in enumerate(dataloader):
        mosaic = batch[0].to(device)
        target = batch[1].to(device)
        with th.no_grad():  # no autograd bookkeeping needed at eval time
            output = model(mosaic)

        # The model runs without padding, so the output is smaller than the
        # target; crop the target to match before comparing.
        target = crop_like(target, output)

        output = th.clamp(output, 0, 1)

        psnr_ = psnr_fn(output, target).item()
        mse_ = mse_fn(output, target).item()

        psnr += psnr_
        mse += mse_
        count += 1

        LOG.info("Image %04d, PSNR = %.1f dB, MSE = %.5f", idx, psnr_, mse_)

    if count == 0:
        # Guard against a division by zero on an empty test set.
        LOG.warning("Empty test set, nothing to report.")
        return

    mse /= count
    psnr /= count

    LOG.info("-----------------------------------")
    LOG.info("Average, PSNR = %.1f dB, MSE = %.5f", psnr, mse)
Ejemplo n.º 14
0
def main(args):
  """Visualize per-sample denoising kernels for one scene.

  Restores a trained model from ``args.checkpoint``, runs it on the first
  scene of ``args.data`` and saves crops of the low-spp input, reference,
  output and kernel visualizations to ``args.output``.
  """
  log.info("Loading model {}".format(args.checkpoint))
  meta_params = ttools.Checkpointer.load_meta(args.checkpoint)

  # Rendering/feature configuration comes from the checkpoint metadata.
  spp = meta_params["spp"]
  use_p = meta_params["use_p"]
  use_ld = meta_params["use_ld"]
  use_bt = meta_params["use_bt"]

  # Map the stored preprocessor class name to a dataloading mode.
  mode = "sample"
  if "DisneyPreprocessor" == meta_params["preprocessor"]:
    mode = "disney_pixel"
  elif "SampleDisneyPreprocessor" == meta_params["preprocessor"]:
    mode = "disney_sample"

  log.info("Rendering at {} spp".format(spp))

  log.info("Setting up dataloader, p:{} bt:{} ld:{}".format(use_p, use_bt, use_ld))
  data = dset.FullImageDataset(args.data, dset.RenderDataset, spp=spp, use_p=use_p, use_ld=use_ld, use_bt=use_bt)
  preprocessor = pre.get(meta_params["preprocessor"])(data)
  xforms = transforms.Compose([dset.ToTensor(), preprocessor])
  data.transform = xforms
  dataloader = DataLoader(data, batch_size=1,
                          shuffle=False, num_workers=0,
                          pin_memory=True)

  model = models.get(preprocessor, meta_params["model_params"])
  model.cuda()
  model.train(False)  # inference mode

  checkpointer = ttools.Checkpointer(args.checkpoint, model, None)
  extras, meta = checkpointer.load_latest()
  log.info("Loading latest checkpoint {}".format("failed" if meta is None else "success"))

  for scene_id, batch in enumerate(dataloader):
    batch_v = make_variable(batch, cuda=True)
    with th.no_grad():
      klist = []  # the model appends per-iteration kernel data here
      out_ = model(batch_v, kernel_list=klist)
    lowspp = batch["radiance"]
    target = batch["target_image"]
    out = out_["radiance"]

    # Crop window (top-left corner and side length) used for all outputs.
    cx = 70
    cy = 20
    c = 128

    target = crop_like(target, out)
    lowspp = crop_like(lowspp.squeeze(), out)
    lowspp = lowspp[..., cy:cy+c, cx:cx+c]

    # Lay the individual samples out side by side:
    # (s, chan, h, w) -> (chan, h, w, s) -> (chan, h, w*s).
    lowspp = lowspp.permute(1, 2, 0, 3)
    chan, h, w, s = lowspp.shape
    lowspp = lowspp.contiguous().view(chan, h, w*s)

    # Kernels are stored as log-weights; subtract the running max before
    # exponentiating for numerical stability.
    maxi = crop_like(klist[-1]["max_w"].unsqueeze(1), out)
    kernels = []
    for k in klist:
      kernels.append(th.exp(crop_like(k["kernels"], out)-maxi))
    # NOTE(review): the original also computed `updated_kernels` from
    # k["updated_kernels"] but never saved them; dropped as dead code.

    out = out[..., cy:cy+c, cx:cx+c]
    target = target[..., cy:cy+c, cx:cx+c]
    kernels = [k[..., cy:cy+c, cx:cx+c] for k in kernels]

    u_kernels_im = viz.kernels2im(kernels)
    # Per-pixel spread of the kernel weights across iterations.
    kvar = u_kernels_im.std(0)

    # Tile the per-iteration kernel images horizontally.
    n, h, w = u_kernels_im.shape
    u_kernels_im = u_kernels_im.permute(1, 0, 2).contiguous().view(h, w*n)

    fname = os.path.join(args.output, "lowspp.png")
    save(fname, lowspp)
    fname = os.path.join(args.output, "target.png")
    save(fname, target)
    fname = os.path.join(args.output, "output.png")
    save(fname, out)
    fname = os.path.join(args.output, "kernels_gather.png")
    save(fname, u_kernels_im)
    fname = os.path.join(args.output, "kernels_variance.png")
    log.info("Max kernel variance: {}".format(kvar.max()))
    save(fname, kvar)
    # Only the first scene is visualized.  (A leftover
    # `import ipdb; ipdb.set_trace()` debugger breakpoint was removed here.)
    break
Ejemplo n.º 15
0
def train(args):
    """Train a SketchVAE on fixed-length QuickDraw sequences.

    Resumes from the latest checkpoint when one exists, and aborts if the
    checkpointed model hyper-parameters disagree with the CLI ones.
    """
    # Fix seeds for reproducibility.
    th.manual_seed(0)
    np.random.seed(0)

    dataset = data.FixedLengthQuickDrawDataset(
        args.dataset,
        max_seq_length=args.sequence_length,
        canvas_size=args.raster_resolution)
    dataloader = DataLoader(dataset,
                            batch_size=args.bs,
                            num_workers=args.workers,
                            shuffle=True)

    # No validation set for now.
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "sequence_length": args.sequence_length,
        "image_size": args.raster_resolution,
    }
    model = SketchVAE(**model_params)
    model.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model,
                          raster_resolution=args.raster_resolution,
                          lr=args.lr,
                          lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay,
                          kl_weight=args.kl_weight,
                          absolute_coords=args.absolute_coordinates,
                          device=device)

    # Visdom environment / checkpoint folder name, derived from the config.
    env_name = "sketch_vae"
    if args.custom_name is not None:
        env_name += "_" + args.custom_name

    if args.absolute_coordinates:
        env_name += "_abs_coords"

    chkpt = os.path.join(OUTPUT, env_name)

    # Resume from checkpoint, if any.
    checkpointer = ttools.Checkpointer(chkpt,
                                       model,
                                       meta=model_params,
                                       optimizers=interface.optimizers(),
                                       schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    # Resume the epoch counter if the checkpoint recorded one.
    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    if meta is not None and meta != model_params:
        LOG.error(
            "Checkpoint's metaparams differ "
            "from CLI, aborting: %s and %s", meta, model_params)
        # The log message said "aborting" but the code previously fell
        # through and kept training with mismatched parameters; actually
        # abort here.
        return

    trainer = ttools.Trainer(interface)

    # Add callbacks: progress bar, visdom loss/debug plots, periodic
    # checkpointing, LR scheduling and sample visualization.
    losses = ["loss", "kl_loss", "vae_im_loss", "sketch_im_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=losses, val_keys=None))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=losses,
                                               val_keys=None,
                                               env=env_name,
                                               port=args.port))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=training_debug,
                                               smoothing=0,
                                               val_keys=None,
                                               env=env_name,
                                               port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(checkpointer,
                                               max_files=2,
                                               interval=600,
                                               max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    trainer.add_callback(
        SketchVAECallback(env=env_name,
                          win="samples",
                          port=args.port,
                          frequency=args.freq))

    # Start training
    trainer.train(dataloader,
                  starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
Ejemplo n.º 16
0
def run(args):
    """Evaluate a trained vector generator.

    Restores the generator from ``args.model`` (prefix ``vect_g_``) and
    writes latent-space interpolations (images, SVGs and an mp4 per pair)
    plus random sample grids under ``args.output_dir`` (defaults to
    ``<args.model>/eval``).
    """
    # Fix seeds so the sampled latents are reproducible across runs.
    th.manual_seed(0)
    np.random.seed(0)

    meta = ttools.Checkpointer.load_meta(args.model, "vect_g_")

    if meta is None:
        LOG.warning("Could not load metadata at %s, aborting.", args.model)
        return

    LOG.info("Loaded model %s with metadata:\n %s", args.model, meta)

    if args.output_dir is None:
        outdir = os.path.join(args.model, "eval")
    else:
        outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    model_params = meta["model_params"]
    if args.imsize is not None:
        LOG.info("Overriding output image size to: %dx%d", args.imsize,
                 args.imsize)
        old_size = model_params["imsize"]
        # Stroke widths scale with image resolution, so rescale them to
        # keep their relative size at the new output resolution.
        scale = args.imsize * 1.0 / old_size
        model_params["imsize"] = args.imsize
        model_params["stroke_width"] = [
            w * scale for w in model_params["stroke_width"]
        ]
        LOG.info("Overriding width to: %s", model_params["stroke_width"])

    # task = meta["task"]
    # Instantiate the generator variant recorded in the checkpoint metadata.
    generator = meta["generator"]
    if generator == "fc":
        model = models.VectorGenerator(**model_params)
    elif generator == "bezier_fc":
        model = models.BezierVectorGenerator(**model_params)
    elif generator in ["rnn"]:
        model = models.RNNVectorGenerator(**model_params)
    elif generator in ["chain_rnn"]:
        model = models.ChainRNNVectorGenerator(**model_params)
    else:
        raise NotImplementedError()
    model.eval()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"

    model.to(device)

    checkpointer = ttools.Checkpointer(args.model,
                                       model,
                                       meta=meta,
                                       prefix="vect_g_")
    checkpointer.load_latest()

    LOG.info("Computing latent space interpolation")
    for i in range(args.nsamples):
        # Two random latent endpoints to interpolate between.
        z0 = model.sample_z(1)
        z1 = model.sample_z(1)

        # Interpolation weights: a coarse set for the image strip and a
        # dense set for the video frames.
        alpha = th.linspace(0, 1, args.nsteps).view(args.nsteps, 1).to(device)
        alpha_video = th.linspace(0, 1, args.nframes).view(args.nframes, 1)
        alpha_video = alpha_video.to(device)

        length = [args.nsteps, args.nframes]
        for idx, a in enumerate([alpha, alpha_video]):
            # Linear interpolation in latent space, batched over all steps.
            _z0 = z0.repeat(length[idx], 1).to(device)
            _z1 = z1.repeat(length[idx], 1).to(device)
            batch = _z0 * (1 - a) + _z1 * a
            out = model(batch)
            if idx == 0:  # image viz
                # Concatenate the interpolation steps into one wide strip.
                n, c, h, w = out.shape
                out = out.permute(1, 2, 0, 3)
                out = out.contiguous().view(1, c, h, w * n)
                out = postprocess(out, invert=args.invert)
                imsave(out,
                       os.path.join(outdir, "latent_interp", "%03d.png" % i))

                # Also export each interpolation step as an SVG scene.
                scenes = model.get_vector(batch)
                for scn_idx, scn in enumerate(scenes):
                    save_scene(
                        scn,
                        os.path.join(outdir, "latent_interp_svg", "%03d" % i,
                                     "%03d.svg" % scn_idx))
            else:  # video viz
                anim_root = os.path.join(outdir, "latent_interp_video",
                                         "%03d" % i)
                LOG.info("Rendering animation %d", i)
                # Write individual frames, then assemble them with ffmpeg.
                for frame_idx, frame in enumerate(out):
                    LOG.info("frame %d", frame_idx)
                    frame = frame.unsqueeze(0)
                    frame = postprocess(frame, invert=args.invert)
                    imsave(
                        frame,
                        os.path.join(anim_root, "frame%04d.png" % frame_idx))
                call([
                    "ffmpeg", "-framerate", "30", "-i",
                    os.path.join(anim_root, "frame%04d.png"), "-vb", "20M",
                    os.path.join(outdir, "latent_interp_video", "%03d.mp4" % i)
                ])
        LOG.info("  saved %d", i)

    LOG.info("Sampling latent space")

    for i in range(args.nsamples):
        # Draw an n x n grid of random samples.
        n = 8
        bs = n * n
        z = model.sample_z(bs).to(device)
        out = model(z)
        # Tile the batch into an n x n mosaic: (bs, c, h, w) -> (1, c, h*n, w*n).
        _, c, h, w = out.shape
        out = out.view(n, n, c, h, w).permute(2, 0, 3, 1, 4)
        out = out.contiguous().view(1, c, h * n, w * n)
        # NOTE(review): unlike the interpolation path above, `invert` is not
        # forwarded to postprocess here — confirm whether that is intended.
        out = postprocess(out)
        imsave(out, os.path.join(outdir, "samples_%03d.png" % i))
        LOG.info("  saved %d", i)

    LOG.info("output images saved to %s", outdir)
Ejemplo n.º 17
0
def main(args):
    """Entrypoint to the training.

    Restores hyper-parameters from a checkpoint when available (falling
    back to the CLI arguments), builds the train/validation dataloaders
    and runs the training loop with progress, visdom and checkpointing
    callbacks.
    """

    # Load model parameters from checkpoint, if any; otherwise fall back
    # to the command-line configuration.
    meta = ttools.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        LOG.info("No metadata or checkpoint, "
                 "parsing model parameters from command line.")
        meta = {
            "depth": args.depth,
            "width": args.width,
            "mode": args.mode,
        }

    data = demosaicnet.Dataset(args.data,
                               download=False,
                               mode=meta["mode"],
                               subset=demosaicnet.TRAIN_SUBSET)
    dataloader = DataLoader(data,
                            batch_size=args.bs,
                            num_workers=args.num_worker_threads,
                            pin_memory=True,
                            shuffle=True)

    val_dataloader = None
    if args.val_data:
        # NOTE(review): the validation set reads from `args.data` (the
        # VAL_SUBSET split); `args.val_data` only toggles validation on.
        # Confirm this is intended and not a typo for `args.val_data`.
        val_data = demosaicnet.Dataset(args.data,
                                       download=False,
                                       mode=meta["mode"],
                                       subset=demosaicnet.VAL_SUBSET)
        val_dataloader = DataLoader(val_data,
                                    batch_size=args.bs,
                                    num_workers=1,
                                    pin_memory=True,
                                    shuffle=False)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)
    else:
        # Previously an unrecognized mode fell through and crashed later
        # with a NameError on `model`; fail early with a clear message.
        raise ValueError("Unknown demosaicking mode: %s" % meta["mode"])

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)

    interface = DemosaicnetInterface(model, lr=args.lr, cuda=args.cuda)

    checkpointer.load_latest()  # Resume from checkpoint, if any.

    trainer = ttools.Trainer(interface)

    # Quantities to log/display during training and validation.
    keys = ["loss", "psnr"]
    val_keys = ["psnr"]

    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=keys, val_keys=val_keys))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=keys,
                                               val_keys=val_keys,
                                               server=args.server,
                                               env=args.env,
                                               port=args.port))
    trainer.add_callback(
        ImageCallback(server=args.server,
                      env=args.env,
                      win="images",
                      port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(checkpointer,
                                               max_files=8,
                                               interval=3600,
                                               max_epochs=10))

    if args.cuda:
        LOG.info("Training with CUDA enabled")
    else:
        LOG.info("Training on CPU")

    trainer.train(dataloader,
                  num_epochs=args.num_epochs,
                  val_dataloader=val_dataloader)