Example #1
def main(args):
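    # Training set and an asynchronous, shuffled dataloader.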
    data = datasets.ShapenetDataset(args.data, args.canvas_size)
    dataloader = DataLoader(data, batch_size=args.bs, num_workers=args.num_worker_threads,
                            worker_init_fn=_worker_init_fn, shuffle=True, drop_last=True)
    LOG.info(data)

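    # Validation split, loaded with DataLoader defaults (batch size 1, no shuffling).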
    val_data = datasets.ShapenetDataset(args.data, args.canvas_size, val=True)
    val_dataloader = DataLoader(val_data)

    model = PrimsModel(output_dim=(11 if args.rounded else 10) * args.n_primitives)

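    # Checkpointer to save/recall model parameters; resume from the latest checkpoint.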
    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model)
    checkpointer.load_latest()

    interface = VectorizerInterface(model, args.lr, args.n_primitives, args.canvas_size, args.w_surface,
                                    args.w_alignment, args.csg, args.rounded, cuda=args.cuda)

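    # Quantities to log during training.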
    keys = ['loss', 'surfaceloss', 'alignmentloss']

    writer = SummaryWriter(os.path.join(args.checkpoint_dir, 'summaries',
                                        datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')), flush_secs=1)
    val_writer = SummaryWriter(os.path.join(args.checkpoint_dir, 'summaries',
                                            datetime.datetime.now().strftime('val-%m%d%y-%H%M%S')), flush_secs=1)

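    # Training loop with TensorBoard logging, progress bar, and checkpointing callbacks.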
    trainer = ttools.Trainer(interface)
    trainer.add_callback(ttools.callbacks.TensorBoardLoggingCallback(keys=keys, writer=writer,
                                                                     val_writer=val_writer, frequency=5))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(keys=keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer, max_files=1))
    trainer.train(dataloader, num_epochs=args.num_epochs, val_dataloader=val_dataloader)
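A minimal sketch of the `_worker_init_fn` that these scripts hand to `DataLoader` (its body is not shown in the examples; reseeding NumPy per worker is an assumption, but it is the usual purpose of this hook):

import numpy as np
import torch as th

def _worker_init_fn(worker_id):
    # Hypothetical body: derive a distinct NumPy seed per dataloader worker so
    # random augmentations are decorrelated across worker processes.
    np.random.seed((th.initial_seed() + worker_id) % (2 ** 32))
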
Example #2
def train(args):
    """Train a MNIST classifier."""

    # Setup the training data
    _xform = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    data = MNIST("data/mnist", train=True, download=True, transform=_xform)

    # Initialize an asynchronous dataloader
    loader = DataLoader(data, batch_size=args.bs, num_workers=2)

    # Instantiate the models
    gen = Generator()
    discrim = Discriminator()

    gen.apply(weights_init_normal)
    discrim.apply(weights_init_normal)

    # Checkpointer to save/recall model parameters
    checkpointer_gen = ttools.Checkpointer(os.path.join(args.out, "checkpoints"), model=gen, prefix="gen_")
    checkpointer_discrim = ttools.Checkpointer(os.path.join(args.out, "checkpoints"), model=discrim, prefix="discrim_")

    # resume from a previous checkpoint, if any
    checkpointer_gen.load_latest()
    checkpointer_discrim.load_latest()

    # Setup a training interface for the model
    interface = MNISTInterface(gen, discrim, lr=args.lr)

    # Create a training looper with the interface we defined
    trainer = ttools.Trainer(interface)

    # Add several callbacks that will be called by the trainer ----------------
    # A periodic checkpointing operation
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_gen))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_discrim))
    # A simple progress bar
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=["loss_g", "loss_d", "loss"]))
    # Volatile (non-persistent) logging via Visdom
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=["loss_g", "loss_d", "loss"],
        port=8080, env="mnist_demo"))
    # Image visualization via Visdom
    trainer.add_callback(VisdomImageCallback(port=8080, env="mnist_demo"))
    # -------------------------------------------------------------------------

    # Start the training
    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(loader, num_epochs=args.epochs)
Example #3
def train(args):
    """Train a MNIST classifier."""

    # Setup train and val data
    _xform = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    data = MNIST("data/mnist", train=True, download=True, transform=_xform)
    val_data = MNIST("data/mnist", train=False, transform=_xform)

    # Initialize asynchronous dataloaders
    loader = DataLoader(data, batch_size=args.bs, num_workers=2)
    val_loader = DataLoader(val_data, batch_size=16, num_workers=1)

    # Instantiate a model
    model = MNISTClassifier()

    # Checkpointer to save/recall model parameters
    checkpointer = ttools.Checkpointer(os.path.join(args.out, "checkpoints"), model=model, prefix="classifier_")

    # resume from a previous checkpoint, if any
    checkpointer.load_latest()

    # Setup a training interface for the model
    if th.cuda.is_available():
        device = th.device("cuda")
    else:
        device = th.device("cpu")
    interface = MNISTInterface(model, device, lr=args.lr)

    # Create a training looper with the interface we defined
    trainer = ttools.Trainer(interface)

    # Add several callbacks that will be called by the trainer ----------------
    LOG.info("This demo uses Visdom to display the loss and accuracy; "
             "make sure you have a Visdom server running! ('make visdom_server')")
    # A periodic checkpointing operation
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer))
    # A simple progress bar
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=["loss", "accuracy"], val_keys=["loss", "accuracy"]))
    # Volatile (non-persistent) logging via Visdom
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=["loss", "accuracy"], val_keys=["loss", "accuracy"],
        port=8080, env="mnist_demo"))
    # -------------------------------------------------------------------------

    # Start the training
    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(loader, num_epochs=args.epochs, val_dataloader=val_loader)
Example #4
def main(args):
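    # Training data and an asynchronous, shuffled dataloader.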
    data = datasets.FontsDataset(args.data, args.chamfer, args.n_samples_per_curve)
    dataloader = DataLoader(data, batch_size=args.bs, num_workers=args.num_worker_threads,
                            worker_init_fn=_worker_init_fn, shuffle=True, drop_last=True)
    LOG.info(data)

    val_data = datasets.FontsDataset(args.data, args.chamfer, args.n_samples_per_curve, val=True)
    val_dataloader = DataLoader(val_data)

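    # Curve-prediction model sized to the template topology.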
    model = CurvesModel(n_curves=sum(templates.topology))

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model)
    checkpointer.load_latest()

    interface = VectorizerInterface(model, args.simple_templates, args.lr, args.max_stroke, args.canvas_size,
                                    args.chamfer, args.n_samples_per_curve, args.w_surface, args.w_template,
                                    args.w_alignment, cuda=args.cuda)

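    # Logged quantities depend on whether Chamfer training is enabled.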
    keys = (['loss', 'chamferloss', 'templateloss'] if args.chamfer
            else ['loss', 'surfaceloss', 'alignmentloss', 'templateloss'])

    writer = SummaryWriter(os.path.join(args.checkpoint_dir, 'summaries',
                                        datetime.datetime.now().strftime('train-%m%d%y-%H%M%S')), flush_secs=1)
    val_writer = SummaryWriter(os.path.join(args.checkpoint_dir, 'summaries',
                                            datetime.datetime.now().strftime('val-%m%d%y-%H%M%S')), flush_secs=1)

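    # Training loop with TensorBoard logging, visualization, and checkpointing callbacks.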
    trainer = ttools.Trainer(interface)
    trainer.add_callback(ttools.callbacks.TensorBoardLoggingCallback(keys=keys, writer=writer,
                                                                     val_writer=val_writer, frequency=5))
    trainer.add_callback(callbacks.InputImageCallback(writer=writer, val_writer=val_writer, frequency=100))
    trainer.add_callback(callbacks.CurvesCallback(writer=writer, val_writer=val_writer, frequency=100))
    if not args.chamfer:
        trainer.add_callback(callbacks.RenderingCallback(writer=writer, val_writer=val_writer, frequency=100))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(keys=keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer, max_files=1))
    trainer.train(dataloader, num_epochs=args.num_epochs, val_dataloader=val_dataloader)
Example #5
def train(args):
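    # Fix random seeds for reproducibility.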
    th.manual_seed(0)
    np.random.seed(0)

    dataset = data.QuickDrawDataset(args.dataset)
    dataloader = DataLoader(dataset,
                            batch_size=args.bs,
                            num_workers=4,
                            shuffle=True,
                            pin_memory=False)

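    # Small fixed validation subset: the first 8 samples.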
    val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
    val_dataloader = DataLoader(val_dataset,
                                batch_size=8,
                                num_workers=4,
                                shuffle=False,
                                pin_memory=False)

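    # Model hyper-parameters, also stored as checkpoint metadata.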
    model_params = {
        "zdim": args.zdim,
        "num_gaussians": args.num_gaussians,
        "encoder_dim": args.encoder_dim,
        "decoder_dim": args.decoder_dim,
    }
    model = SketchRNN(**model_params)
    model.train()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model,
                          lr=args.lr,
                          lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay,
                          kl_weight=args.kl_weight,
                          sampling_temperature=args.sampling_temperature,
                          device=device)

    chkpt = OUTPUT_BASELINE
    env_name = "sketch_rnn"

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(chkpt,
                                       model,
                                       meta=model_params,
                                       optimizers=interface.optimizers(),
                                       schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    epoch = extras["epoch"] if extras and "epoch" in extras else 0

    if meta is not None and meta != model_params:
        LOG.error(
            "Checkpoint's metaparams differ "
            "from CLI, aborting: %s and %s", meta, model_params)
        raise ValueError("Checkpoint/CLI metaparameter mismatch")

    trainer = ttools.Trainer(interface)

    # Add callbacks
    losses = ["loss", "kl_loss", "recons_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=losses, val_keys=None))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=losses,
                                               val_keys=None,
                                               env=env_name,
                                               port=args.port))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=training_debug,
                                               smoothing=0,
                                               val_keys=None,
                                               env=env_name,
                                               port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(checkpointer,
                                               max_files=2,
                                               interval=600,
                                               max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    trainer.add_callback(
        SketchRNNCallback(env=env_name,
                          win="samples",
                          port=args.port,
                          frequency=args.freq))

    # Start training
    trainer.train(dataloader,
                  starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
Example #6
def train(args):
    th.manual_seed(0)
    np.random.seed(0)

    pydiffvg.set_use_gpu(args.cuda)

    # Initialize datasets
    imsize = 28
    dataset = Dataset(args.data_dir, imsize)
    dataloader = DataLoader(dataset, batch_size=args.bs,
                            num_workers=4, shuffle=True)

    if args.generator in ["vae", "ae"]:
        LOG.info("Vector config:\n  samples %d\n"
                 "  paths: %d\n  segments: %d\n"
                 "  zdim: %d\n"
                 "  conditional: %d\n"
                 "  fc: %d\n",
                 args.samples, args.paths, args.segments,
                 args.zdim, args.conditional, args.fc)

    model_params = dict(samples=args.samples, paths=args.paths,
                        segments=args.segments, conditional=args.conditional,
                        zdim=args.zdim, fc=args.fc)

    if args.generator == "vae":
        model = VectorMNISTVAE(variational=True, **model_params)
        chkpt = VAE_OUTPUT
        name = "mnist_vae"
    elif args.generator == "ae":
        model = VectorMNISTVAE(variational=False, **model_params)
        chkpt = AE_OUTPUT
        name = "mnist_ae"
    else:
        raise ValueError("unknown generator")

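    # Encode the configuration in the experiment name and checkpoint path.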
    if args.conditional:
        name += "_conditional"
        chkpt += "_conditional"

    if args.fc:
        name += "_fc"
        chkpt += "_fc"

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(
        chkpt, model, meta=model_params, prefix="g_")
    extras, meta = checkpointer.load_latest()

    if meta is not None and meta != model_params:
        LOG.error("Checkpoint's metaparams differ from CLI, aborting: %s and %s",
                  meta, model_params)
        raise ValueError("Checkpoint/CLI metaparameter mismatch")

    # Hook interface
    if args.generator in ["vae", "ae"]:
        variational = args.generator == "vae"
        if variational:
            LOG.info("Using a VAE")
        else:
            LOG.info("Using an AE")
        interface = VAEInterface(model, lr=args.lr, cuda=args.cuda,
                                 variational=variational, w_kld=args.kld_weight)

    trainer = ttools.Trainer(interface)

    # Add callbacks
    keys = ["loss_g", "loss_d"]
    if args.generator == "vae":
        keys = ["kld", "data_loss", "loss"]
    elif args.generator == "ae":
        keys = ["data_loss", "loss"]
    port = 8097
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=keys, val_keys=keys))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=keys, val_keys=keys, env=name, port=port))
    trainer.add_callback(MNISTCallback(
        env=name, win="samples", port=port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=50))

    # Start training
    trainer.train(dataloader, num_epochs=args.num_epochs)
Example #7
def main(args):
    # Fix seed
    np.random.seed(0)
    th.manual_seed(0)

    # Parameterization of the dataset (shared between train/val)
    data_args = dict(spp=args.spp,
                     mode=sbmc.TilesDataset.KPCN_MODE
                     if args.kpcn_mode else sbmc.TilesDataset.SBMC_MODE,
                     load_coords=args.load_coords,
                     load_gbuffer=args.load_gbuffer,
                     load_p=args.load_p,
                     load_ld=args.load_ld,
                     load_bt=args.load_bt)

    if args.randomize_spp:
        if args.bs != 1:
            LOG.error(
                "Training with randomized spp is only valid for "
                "batch_size=1, got %d", args.bs)
            raise RuntimeError("Incorrect batch size")
        data = sbmc.MultiSampleCountDataset(args.data, **data_args)
        LOG.info("Training with randomized sample count in [%d, %d]" %
                 (2, args.spp))
    else:
        data = sbmc.TilesDataset(args.data, **data_args)
        LOG.info("Training with a single sample count: %dspp" % args.spp)

    if args.kpcn_mode:
        LOG.info("Model: pixel-based comparison from [Bako2017]")
        model = sbmc.KPCN(data.num_features, ksize=args.ksize)
        model_params = dict(ksize=args.ksize)
    else:
        LOG.info("Model: sample-based [Gharbi2019]")
        model = sbmc.Multisteps(data.num_features,
                                data.num_global_features,
                                ksize=args.ksize,
                                splat=not args.gather,
                                pixel=args.pixel)
        model_params = dict(ksize=args.ksize,
                            gather=args.gather,
                            pixel=args.pixel)

    dataloader = DataLoader(data,
                            batch_size=args.bs,
                            num_workers=args.num_worker_threads,
                            shuffle=True)

    # Validation set uses a constant spp
    val_dataloader = None
    if args.val_data is not None:
        LOG.info("Validation set with %dspp" % args.spp)
        val_data = sbmc.TilesDataset(args.val_data, **data_args)
        val_dataloader = DataLoader(val_data,
                                    batch_size=args.bs,
                                    num_workers=1,
                                    shuffle=False)
    else:
        LOG.info("No validation set provided")

    meta = dict(model_params=model_params,
                kpcn_mode=args.kpcn_mode,
                data_params=data_args)

    LOG.info("Model configuration: {}".format(model_params))

    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)

    interface = sbmc.SampleBasedDenoiserInterface(model,
                                                  lr=args.lr,
                                                  cuda=args.cuda)

    extras, meta = checkpointer.load_latest()

    trainer = ttools.Trainer(interface)

    # Hook-up some callbacks to the training loop
    log_keys = ["loss", "rmse"]
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(log_keys))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(log_keys,
                                               env=args.env,
                                               port=args.port,
                                               log=True,
                                               frequency=100))
    trainer.add_callback(
        sbmc.DenoisingDisplayCallback(env=args.env,
                                      port=args.port,
                                      win="images"))

    # Launch the training
    LOG.info("Training started, 'Ctrl+C' to abort.")
    trainer.train(dataloader,
                  num_epochs=args.num_epochs,
                  val_dataloader=val_dataloader)
Example #8
def train(args):
    th.manual_seed(0)
    np.random.seed(0)

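    # Fixed-length QuickDraw strokes, rasterized at the requested resolution.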
    dataset = data.FixedLengthQuickDrawDataset(
        args.dataset,
        max_seq_length=args.sequence_length,
        canvas_size=args.raster_resolution)
    dataloader = DataLoader(dataset,
                            batch_size=args.bs,
                            num_workers=args.workers,
                            shuffle=True)

    # val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
    # val_dataloader = DataLoader(
    #     val_dataset, batch_size=8, num_workers=4, shuffle=False)

    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "sequence_length": args.sequence_length,
        "image_size": args.raster_resolution,
        # "encoder_dim": args.encoder_dim,
        # "decoder_dim": args.decoder_dim,
    }
    model = SketchVAE(**model_params)
    model.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model,
                          raster_resolution=args.raster_resolution,
                          lr=args.lr,
                          lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay,
                          kl_weight=args.kl_weight,
                          absolute_coords=args.absolute_coordinates,
                          device=device)

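    # Encode run options in the Visdom environment name and checkpoint path.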
    env_name = "sketch_vae"
    if args.custom_name is not None:
        env_name += "_" + args.custom_name

    if args.absolute_coordinates:
        env_name += "_abs_coords"

    chkpt = os.path.join(OUTPUT, env_name)

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(chkpt,
                                       model,
                                       meta=model_params,
                                       optimizers=interface.optimizers(),
                                       schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    epoch = extras["epoch"] if extras and "epoch" in extras else 0

    if meta is not None and meta != model_params:
        LOG.error(
            "Checkpoint's metaparams differ "
            "from CLI, aborting: %s and %s", meta, model_params)
        raise ValueError("Checkpoint/CLI metaparameter mismatch")

    trainer = ttools.Trainer(interface)

    # Add callbacks
    losses = ["loss", "kl_loss", "vae_im_loss", "sketch_im_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=losses, val_keys=None))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=losses,
                                               val_keys=None,
                                               env=env_name,
                                               port=args.port))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=training_debug,
                                               smoothing=0,
                                               val_keys=None,
                                               env=env_name,
                                               port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(checkpointer,
                                               max_files=2,
                                               interval=600,
                                               max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    trainer.add_callback(
        SketchVAECallback(env=env_name,
                          win="samples",
                          port=args.port,
                          frequency=args.freq))

    # Start training
    trainer.train(dataloader,
                  starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
Example #9
def train(args):
    th.manual_seed(0)
    np.random.seed(0)

    color_output = False
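    # Select the raster dataset for the requested task.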
    if args.task == "mnist":
        dataset = data.MNISTDataset(args.raster_resolution, train=True)
    elif args.task == "quickdraw":
        dataset = data.QuickDrawImageDataset(
            args.raster_resolution, train=True)
    else:
        raise NotImplementedError()

    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)

    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "num_strokes": args.num_strokes,
        "imsize": args.raster_resolution,
        "stroke_width": args.stroke_width,
        "color_output": color_output,
    }
    gen = models.Generator(**model_params)
    gen.train()

    discrim = models.Discriminator(color_output=color_output)
    discrim.train()

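    # Optionally instantiate the vector generator/discriminator pair.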
    if args.raster_only:
        vect_gen = None
        vect_discrim = None
    else:
        if args.generator == "fc":
            vect_gen = models.VectorGenerator(**model_params)
        elif args.generator == "bezier_fc":
            vect_gen = models.BezierVectorGenerator(**model_params)
        elif args.generator in ["rnn"]:
            vect_gen = models.RNNVectorGenerator(**model_params)
        elif args.generator in ["chain_rnn"]:
            vect_gen = models.ChainRNNVectorGenerator(**model_params)
        else:
            raise NotImplementedError()
        vect_gen.train()

        vect_discrim = models.Discriminator(color_output=color_output)
        vect_discrim.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(gen, vect_gen, discrim, vect_discrim,
                          raster_resolution=args.raster_resolution, lr=args.lr,
                          wgan_gp=args.wgan_gp,
                          lr_decay=args.lr_decay, device=device)

    env_name = args.task + "_gan"

    if args.raster_only:
        env_name += "_raster"
    else:
        env_name += "_vector"

    env_name += "_" + args.generator

    if args.wgan_gp:
        env_name += "_wgan"

    chkpt = os.path.join(OUTPUT, env_name)

    meta = {
        "model_params": model_params,
        "task": args.task,
        "generator": args.generator,
    }
    checkpointer = ttools.Checkpointer(
        chkpt, gen, meta=meta,
        optimizers=interface.optimizers,
        schedulers=interface.schedulers,
        prefix="g_")
    checkpointer_d = ttools.Checkpointer(chkpt, discrim, prefix="d_")

    # Resume from checkpoint, if any
    extras, _ = checkpointer.load_latest()
    checkpointer_d.load_latest()

    if not args.raster_only:
        checkpointer_vect = ttools.Checkpointer(
            chkpt, vect_gen, meta=meta,
            optimizers=interface.optimizers,
            schedulers=interface.schedulers,
            prefix="vect_g_")
        checkpointer_d_vect = ttools.Checkpointer(
            chkpt, vect_discrim, prefix="vect_d_")
        extras, _ = checkpointer_vect.load_latest()
        checkpointer_d_vect.load_latest()

    epoch = extras["epoch"] if extras and "epoch" in extras else 0

    # if meta is not None and meta["model_parameters"] != model_params:
    #     LOG.info("Checkpoint's metaparams differ "
    #              "from CLI, aborting: %s and %s", meta, model_params)

    trainer = ttools.Trainer(interface)

    # Add callbacks
    losses = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect", "gp",
              "gp_vect"]
    training_debug = ["lr"]

    trainer.add_callback(Callback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=losses, val_keys=None))
    trainer.add_callback(ttools.callbacks.MultiPlotCallback(
        keys=losses, val_keys=None, env=env_name, port=args.port,
        server=args.server, base_url=args.base_url,
        win="losses", frequency=args.freq))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        server=args.server, base_url=args.base_url,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer_d, max_files=2, interval=600, max_epochs=10))

    if not args.raster_only:
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_vect, max_files=2, interval=600, max_epochs=10))
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_d_vect, max_files=2, interval=600, max_epochs=10))

    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
Example #10
def main(args):
    """Entrypoint to the training."""

    # Load model parameters from checkpoint, if any
    meta = ttools.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        LOG.info("No metadata or checkpoint, "
                 "parsing model parameters from command line.")
        meta = {
            "depth": args.depth,
            "width": args.width,
            "mode": args.mode,
        }

    data = demosaicnet.Dataset(args.data,
                               download=False,
                               mode=meta["mode"],
                               subset=demosaicnet.TRAIN_SUBSET)
    dataloader = DataLoader(data,
                            batch_size=args.bs,
                            num_workers=args.num_worker_threads,
                            pin_memory=True,
                            shuffle=True)

    val_dataloader = None
    if args.val_data:
        val_data = demosaicnet.Dataset(args.data,
                                       download=False,
                                       mode=meta["mode"],
                                       subset=demosaicnet.VAL_SUBSET)
        val_dataloader = DataLoader(val_data,
                                    batch_size=args.bs,
                                    num_workers=1,
                                    pin_memory=True,
                                    shuffle=False)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)
    checkpointer = ttools.Checkpointer(args.checkpoint_dir, model, meta=meta)

    interface = DemosaicnetInterface(model, lr=args.lr, cuda=args.cuda)

    checkpointer.load_latest()  # Resume from checkpoint, if any.

    trainer = ttools.Trainer(interface)

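    # Quantities to log during training and validation.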
    keys = ["loss", "psnr"]
    val_keys = ["psnr"]

    trainer.add_callback(
        ttools.callbacks.ProgressBarCallback(keys=keys, val_keys=val_keys))
    trainer.add_callback(
        ttools.callbacks.VisdomLoggingCallback(keys=keys,
                                               val_keys=val_keys,
                                               server=args.server,
                                               env=args.env,
                                               port=args.port))
    trainer.add_callback(
        ImageCallback(server=args.server,
                      env=args.env,
                      win="images",
                      port=args.port))
    trainer.add_callback(
        ttools.callbacks.CheckpointingCallback(checkpointer,
                                               max_files=8,
                                               interval=3600,
                                               max_epochs=10))

    if args.cuda:
        LOG.info("Training with CUDA enabled")
    else:
        LOG.info("Training on CPU")

    trainer.train(dataloader,
                  num_epochs=args.num_epochs,
                  val_dataloader=val_dataloader)