Example #1
def configure_lr_scheduler(args, optimizer):
    lr_scheduler = None

    with logger.LoggingBlock("Learning Rate Scheduler", emph=True):
        logging.info("class: %s" % args.lr_scheduler)

        if args.lr_scheduler is not None:

            # ----------------------------------------------
            # Figure out lr_scheduler arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "lr_scheduler")

            # -------------------------------------------
            # Print arguments
            # -------------------------------------------
            for param, default in sorted(kwargs.items()):
                logging.info("%s: %s" % (param, default))

            # -------------------------------------------
            # Add optimizer
            # -------------------------------------------
            kwargs["optimizer"] = optimizer

            # -------------------------------------------
            # Create lr_scheduler instance
            # -------------------------------------------
            lr_scheduler = tools.instance_from_kwargs(args.lr_scheduler_class,
                                                      kwargs)

    return lr_scheduler
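
Note: tools.kwargs_from_args is project-specific and not shown on this page. Below is a minimal sketch of the prefix convention it appears to implement, assuming scheduler options arrive as lr_scheduler_[param] attributes on args; the helper name is hypothetical and the real implementation may differ (e.g., in type handling).

# Hypothetical sketch of the prefix-to-kwargs convention.
def kwargs_from_args_sketch(args, prefix):
    prefix_ = prefix + "_"
    return {
        name[len(prefix_):]: value
        for name, value in vars(args).items()
        if name.startswith(prefix_) and name != prefix_ + "class"
    }

# e.g. Namespace(lr_scheduler="MultiStepLR", lr_scheduler_milestones=[54, 72],
#                lr_scheduler_gamma=0.5)
#      -> {"milestones": [54, 72], "gamma": 0.5}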
Example #2
def setup_logging_and_parse_arguments(blocktitle, yaml_conf=None):
    # ----------------------------------------------------------------------------
    # Parse commandline arguments and collect their defaults
    # ----------------------------------------------------------------------------
    args, defaults = _parse_arguments(yaml_conf=yaml_conf)

    # ----------------------------------------------------------------------------
    # Setup logbook before everything else
    # ----------------------------------------------------------------------------
    logger.configure_logging(os.path.join(args.save, 'logbook.txt'))

    # ----------------------------------------------------------------------------
    # Write arguments to file, as txt
    # ----------------------------------------------------------------------------
    tools.write_dictionary_to_file(
        sorted(vars(args).items()),
        filename=os.path.join(args.save, 'args.txt'))

    # ----------------------------------------------------------------------------
    # Log arguments
    # ----------------------------------------------------------------------------
    with logger.LoggingBlock(blocktitle, emph=True):
        for argument, value in sorted(vars(args).items()):
            reset = colorama.Style.RESET_ALL
            color = reset if (argument in defaults and value == defaults[argument]) else colorama.Fore.CYAN
            logging.info('{}{}: {}{}'.format(color, argument, value, reset))

    # ----------------------------------------------------------------------------
    # Postprocess
    # ----------------------------------------------------------------------------
    args = postprocess_args(args)

    return args
Example #3
def _log_statistics(dataset, prefix, name):
    with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
        example_dict = dataset[0]  # get sizes from first dataset example
        for key, value in sorted(example_dict.items()):
            if key in ["index", "basename"]:  # no need to display these
                continue
            if isinstance(value, str):
                logging.info("{}: {}".format(key, value))
            else:
                logging.info("%s: %s" % (key, _sizes_to_str(value)))
        logging.info("num_examples: %i" % len(dataset))
Example #4
def configure_random_seed(args):
    with logger.LoggingBlock("Random Seeds", emph=True):
        # python
        seed = args.seed
        random.seed(seed)
        logging.info("Python seed: %i" % seed)
        # numpy
        seed += 1
        np.random.seed(seed)
        logging.info("Numpy seed: %i" % seed)
        # torch
        seed += 1
        torch.manual_seed(seed)
        logging.info("Torch CPU seed: %i" % seed)
        # torch cuda
        seed += 1
        torch.cuda.manual_seed(seed)
        logging.info("Torch CUDA seed: %i" % seed)
Example #5
def load_checkpoint_saver(args, model_and_loss):
    with logger.LoggingBlock("Checkpoint", emph=True):
        checkpoint_saver = CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info("No checkpoint given.")
            logging.info("Starting from scratch with random initialization.")

        elif os.path.isfile(args.checkpoint):
            logging.info("Loading checkpoint in %s" % args.checkpoint)
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params)

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode in ["resume_from_best"]:
                logging.info("Loading best checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)

            elif args.checkpoint_mode in ["resume_from_latest"]:
                logging.info("Loading latest checkpoint in %s" %
                             args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            else:
                logging.info("Unknown checkpoint_mode '%s' given!" %
                             args.checkpoint_mode)
                quit()
        else:
            logging.info("Could not find checkpoint file or directory '%s'" %
                         args.checkpoint)
            quit()

    return checkpoint_saver, checkpoint_stats
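
CheckpointSaver is project-specific, but its restore/restore_best/restore_latest methods presumably wrap the standard torch.save / torch.load state-dict round trip. A minimal self-contained sketch (filenames and dict keys below are made up):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for model_and_loss
torch.save({"state_dict": model.state_dict(), "epoch": 3}, "checkpoint.pth")

checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])  # resume from saved weights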
Example #6
def configure_runtime_augmentations(args):
    with logger.LoggingBlock("Runtime Augmentations", emph=True):

        training_augmentation = None
        validation_augmentation = None

        # ----------------------------------------------------
        # Training Augmentation
        # ----------------------------------------------------
        if args.training_augmentation is not None:
            kwargs = tools.kwargs_from_args(args, "training_augmentation")
            logging.info("training_augmentation: %s" %
                         args.training_augmentation)
            for param, default in sorted(kwargs.items()):
                logging.info("  %s: %s" % (param, default))
            kwargs["args"] = args
            training_augmentation = tools.instance_from_kwargs(
                args.training_augmentation_class, kwargs)
            if args.cuda:
                training_augmentation = training_augmentation.cuda()

        else:
            logging.info("training_augmentation: None")

        # ----------------------------------------------------
        # Validation Augmentation
        # ----------------------------------------------------
        if args.validation_augmentation is not None:
            kwargs = tools.kwargs_from_args(args, "validation_augmentation")
            logging.info("validation_augmentation: %s" %
                         args.validation_augmentation)
            for param, default in sorted(kwargs.items()):
                logging.info("  %s: %s" % (param, default))
            kwargs["args"] = args
            validation_augmentation = tools.instance_from_kwargs(
                args.validation_augmentation_class, kwargs)
            if args.cuda:
                validation_augmentation = validation_augmentation.cuda()

        else:
            logging.info("validation_augmentation: None")

    return training_augmentation, validation_augmentation
Example #7
def configure_optimizer(args, model_and_loss):
    optimizer = None
    with logger.LoggingBlock("Optimizer", emph=True):
        if args.optimizer is not None:
            if model_and_loss.num_parameters() == 0:
                logging.info("No trainable parameters detected.")
                logging.info("Setting optimizer to None.")
            else:
                logging.info(args.optimizer)

                # -------------------------------------------
                # Figure out all optimizer arguments
                # -------------------------------------------
                all_kwargs = tools.kwargs_from_args(args, "optimizer")

                # -------------------------------------------
                # Get the split of param groups
                # -------------------------------------------
                kwargs_without_groups = {
                    key: value
                    for key, value in all_kwargs.items() if key != "group"
                }
                param_groups = all_kwargs["group"]

                # ----------------------------------------------------------------------
                # Print arguments (without groups)
                # ----------------------------------------------------------------------
                for param, default in sorted(kwargs_without_groups.items()):
                    logging.info("%s: %s" % (param, default))

                # ----------------------------------------------------------------------
                # Construct actual optimizer params
                # ----------------------------------------------------------------------
                kwargs = dict(kwargs_without_groups)
                if param_groups is None:
                    # ---------------------------------------------------------
                    # Add all trainable parameters if there are no param groups
                    # ---------------------------------------------------------
                    all_trainable_parameters = _generate_trainable_params(
                        model_and_loss)
                    kwargs["params"] = all_trainable_parameters
                else:
                    # -------------------------------------------
                    # Add list of parameter groups instead
                    # -------------------------------------------
                    trainable_parameter_groups = []
                    dnames, dparams = _param_names_and_trainable_generator(
                        model_and_loss)
                    dnames = set(dnames)
                    dparams = set(list(dparams))
                    with logger.LoggingBlock("parameter_groups:"):
                        for group in param_groups:
                            #  log group settings
                            group_match = group["params"]
                            group_args = {
                                key: value
                                for key, value in group.items()
                                if key != "params"
                            }

                            with logger.LoggingBlock(
                                    "%s: %s" % (group_match, group_args)):
                                # retrieve parameters by matching name
                                gnames, gparams = _param_names_and_trainable_generator(
                                    model_and_loss, match=group_match)
                                # log all names affected
                                for n in sorted(gnames):
                                    logging.info(n)
                                # set generator for group
                                group_args["params"] = gparams
                                # append parameter group
                                trainable_parameter_groups.append(group_args)
                                # update remaining trainable parameters
                                dnames -= set(gnames)
                                dparams -= set(list(gparams))

                        # append default parameter group
                        trainable_parameter_groups.append(
                            {"params": list(dparams)})
                        # and log its parameter names
                        with logger.LoggingBlock("default:"):
                            for dname in sorted(dnames):
                                logging.info(dname)

                    # set params in optimizer kwargs
                    kwargs["params"] = trainable_parameter_groups

                # -------------------------------------------
                # Create optimizer instance
                # -------------------------------------------
                optimizer = tools.instance_from_kwargs(args.optimizer_class,
                                                       kwargs)

    return optimizer
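
For reference, the parameter-group structure assembled above is plain torch.optim API: an optimizer accepts a list of dicts, each holding a "params" entry plus optional per-group overrides such as "lr". A minimal self-contained sketch; the model and values are illustrative, not from this project.

import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))  # toy model

optimizer = optim.SGD(
    [
        {"params": model[0].parameters(), "lr": 1e-2},  # matched group override
        {"params": model[1].parameters()},              # falls back to default lr
    ],
    lr=1e-3,  # default for groups without an explicit "lr"
)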
Example #8
def configure_data_loaders(args):
    with logger.LoggingBlock("Datasets", emph=True):

        def _sizes_to_str(value):
            if np.isscalar(value):
                return '[1L]'
            else:
                return str(list(value.size()))

        def _log_statistics(dataset, prefix, name):
            with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
                example_dict = dataset[0]  # get sizes from first dataset example
                for key, value in sorted(example_dict.items()):
                    if key in ["index", "basename"]:  # no need to display these
                        continue
                    if isinstance(value, str):
                        logging.info("{}: {}".format(key, value))
                    else:
                        logging.info("%s: %s" % (key, _sizes_to_str(value)))
                logging.info("num_examples: %i" % len(dataset))

        # -----------------------------------------------------------------------------------------
        # GPU loader parameters; pin_memory is disabled here to work around a data-loader deadlock
        # -----------------------------------------------------------------------------------------
        gpuargs = {
            "num_workers": args.num_workers,
            "pin_memory": False
        } if args.cuda else {}

        train_loader = None
        validation_loader = None
        inference_loader = None

        # -----------------------------------------------------------------------------------------
        # Training dataset
        # -----------------------------------------------------------------------------------------
        if args.training_dataset is not None:
            # ----------------------------------------------
            # Figure out training_dataset arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "training_dataset")
            kwargs["is_cropped"] = True
            kwargs["args"] = args

            # ----------------------------------------------
            # Create training dataset
            # ----------------------------------------------
            train_dataset = tools.instance_from_kwargs(
                args.training_dataset_class, kwargs)

            # ----------------------------------------------
            # Create training loader
            # ----------------------------------------------
            train_loader = DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=False,
                                      **gpuargs)

            _log_statistics(train_dataset,
                            prefix="Training",
                            name=args.training_dataset)

        # -----------------------------------------------------------------------------------------
        # Validation dataset
        # -----------------------------------------------------------------------------------------
        if args.validation_dataset is not None:
            # ----------------------------------------------
            # Figure out validation_dataset arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "validation_dataset")
            kwargs["is_cropped"] = True
            kwargs["args"] = args

            # ----------------------------------------------
            # Create validation dataset
            # ----------------------------------------------
            validation_dataset = tools.instance_from_kwargs(
                args.validation_dataset_class, kwargs)

            # ----------------------------------------------
            # Create validation loader
            # ----------------------------------------------
            validation_loader = DataLoader(validation_dataset,
                                           batch_size=args.batch_size_val,
                                           shuffle=False,
                                           drop_last=False,
                                           **gpuargs)

            _log_statistics(validation_dataset,
                            prefix="Validation",
                            name=args.validation_dataset)

    return train_loader, validation_loader, inference_loader
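
The loaders built above are standard torch.utils.data.DataLoader objects. A minimal sketch with a toy tensor dataset; note the real datasets on this page yield dicts rather than tuples, and all values below are made up.

import torch
from torch.utils.data import DataLoader, TensorDataset

toy_dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
toy_loader = DataLoader(toy_dataset,
                        batch_size=8,
                        shuffle=True,
                        drop_last=False,
                        num_workers=0,
                        pin_memory=False)  # mirrors gpuargs above

for inputs, targets in toy_loader:
    pass  # one training step per batch would go here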
Example #9
def configure_model_and_loss(args):
    # ----------------------------------------------------
    # Dynamically load model and loss class with parameters
    # passed in via "--model_[param]=[value]" or "--loss_[param]=[value]" arguments
    # ----------------------------------------------------
    with logger.LoggingBlock("Model and Loss", emph=True):

        # ----------------------------------------------------
        # Model
        # ----------------------------------------------------
        kwargs = tools.kwargs_from_args(args, "model")
        kwargs["args"] = args

        if type(args.checkpoint) == list and len(args.checkpoint) > 1:
            models = nn.ModuleList([
                tools.instance_from_kwargs(args.model_class, kwargs)
                for _ in args.checkpoint
            ])
        else:
            models = tools.instance_from_kwargs(args.model_class, kwargs)

        if hasattr(args, 'avoid_list') and args.avoid_list:
            # `models` may be a single module rather than a ModuleList;
            # plain modules are not iterable, so normalize to a list first.
            model_list = models if isinstance(models, nn.ModuleList) else [models]
            for model in model_list:
                model.avoid_list = args.avoid_list.split(',')

        # ----------------------------------------------------
        # Training loss
        # ----------------------------------------------------
        training_loss = None
        if args.training_loss is not None:
            kwargs = tools.kwargs_from_args(args, "training_loss")
            kwargs["args"] = args
            training_loss = tools.instance_from_kwargs(
                args.training_loss_class, kwargs)

        # ----------------------------------------------------
        # Validation loss
        # ----------------------------------------------------
        validation_loss = None
        if args.validation_loss is not None:
            kwargs = tools.kwargs_from_args(args, "validation_loss")
            kwargs["args"] = args
            validation_loss = tools.instance_from_kwargs(
                args.validation_loss_class, kwargs)

        # ----------------------------------------------------
        # Model and loss
        # ----------------------------------------------------
        model_and_loss = ModelAndLoss(args, models, training_loss,
                                      validation_loss)

        # -----------------------------------------------------------
        # If Cuda, transfer model to Cuda and wrap with DataParallel.
        # -----------------------------------------------------------
        if args.cuda:
            model_and_loss = model_and_loss.cuda()

        # ---------------------------------------------------------------
        # Report some network statistics
        # ---------------------------------------------------------------
        logging.info("Batch Size: %i" % args.batch_size)
        logging.info("GPGPU: Cuda") if args.cuda else logging.info(
            "GPGPU: off")
        logging.info("Network: %s" % args.model)
        logging.info("Number of parameters: %i" %
                     tools.x2module(model_and_loss).num_parameters())
        if training_loss is not None:
            logging.info("Training Key: %s" % args.training_key)
            logging.info("Training Loss: %s" % args.training_loss)
        if validation_loss is not None:
            logging.info("Validation Key: %s" % args.validation_key)
            logging.info("Validation Loss: %s" % args.validation_loss)

    return model_and_loss
Example #10
def train(opt, netG):
    # Re-generate dataset frames
    fps, td, fps_index = utils.get_fps_td_by_index(opt.scale_idx, opt)
    opt.fps = fps
    opt.td = td
    opt.fps_index = fps_index

    with logger.LoggingBlock("Updating dataset", emph=True):
        logging.info("{}FPS :{} {}{}".format(green, clear, opt.fps, clear))
        logging.info("{}Time-Depth :{} {}{}".format(green, clear, opt.td,
                                                    clear))
        logging.info("{}Sampling-Ratio :{} {}{}".format(
            green, clear, opt.sampling_rates[opt.fps_index], clear))
        opt.dataset.generate_frames(opt.scale_idx)

    # Initialize noise
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [
            opt.batch_size, opt.latent_dim, opt.td, *initial_size
        ]

    if opt.vae_levels < opt.scale_idx + 1:
        D_curr = getattr(networks_3d, opt.discriminator)(opt).to(opt.device)

        if (opt.netG != '') and (opt.resumed_idx == opt.scale_idx):
            D_curr.load_state_dict(
                torch.load('{}/netD_{}.pth'.format(
                    opt.resume_dir, opt.scale_idx - 1))['state_dict'])
        elif opt.vae_levels < opt.scale_idx:
            D_curr.load_state_dict(
                torch.load(
                    '{}/netD_{}.pth'.format(opt.saver.experiment_dir,
                                            opt.scale_idx - 1))['state_dict'])

        # Current optimizers
        optimizerD = optim.Adam(D_curr.parameters(),
                                lr=opt.lr_d,
                                betas=(opt.beta1, 0.999))

    parameter_list = []
    # Generator Adversary
    if not opt.train_all:
        if opt.vae_levels < opt.scale_idx + 1:
            train_depth = min(opt.train_depth,
                              len(netG.body) - opt.vae_levels + 1)
            parameter_list += [
                {"params": block.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-train_depth:]) - 1 - idx))}
                for idx, block in enumerate(netG.body[-train_depth:])
            ]
        else:
            # VAE
            parameter_list += [
                {"params": netG.encode.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** opt.scale_idx)},
                {"params": netG.decoder.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** opt.scale_idx)},
            ]
            parameter_list += [
                {"params": block.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-opt.train_depth:]) - 1 - idx))}
                for idx, block in enumerate(netG.body[-opt.train_depth:])
            ]
    else:
        if len(netG.body) < opt.train_depth:
            parameter_list += [
                {"params": netG.encode.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** opt.scale_idx)},
                {"params": netG.decoder.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** opt.scale_idx)},
            ]
            parameter_list += [
                {"params": block.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body) - 1 - idx))}
                for idx, block in enumerate(netG.body)
            ]
        else:
            parameter_list += [
                {"params": block.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-opt.train_depth:]) - 1 - idx))}
                for idx, block in enumerate(netG.body[-opt.train_depth:])
            ]

    optimizerG = optim.Adam(parameter_list,
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))

    # Parallel
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
        if opt.vae_levels < opt.scale_idx + 1:
            D_curr = torch.nn.DataParallel(D_curr)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Training scale [{}/{}]".format(opt.scale_idx + 1,
                                                opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    iterator = iter(opt.data_loader)

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            iterator = iter(opt.data_loader)
            data = next(iterator)

        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
            real_zero = real_zero.to(opt.device)
        else:
            real = data.to(opt.device)
            real_zero = real

        noise_init = utils.generate_noise(size=opt.Z_init_size,
                                          device=opt.device)

        ############################
        # calculate noise_amp
        ###########################
        if iteration == 0:
            if opt.const_amp:
                opt.Noise_Amps.append(1)
            else:
                with torch.no_grad():
                    if opt.scale_idx == 0:
                        opt.noise_amp = 1
                        opt.Noise_Amps.append(opt.noise_amp)
                    else:
                        opt.Noise_Amps.append(0)
                        z_reconstruction, _, _ = G_curr(real_zero,
                                                        opt.Noise_Amps,
                                                        mode="rec")

                        RMSE = torch.sqrt(F.mse_loss(real, z_reconstruction))
                        opt.noise_amp = (opt.noise_amp_init * RMSE.item()
                                         / opt.batch_size)
                        opt.Noise_Amps[-1] = opt.noise_amp

        ############################
        # (1) Update VAE network
        ###########################
        total_loss = 0

        generated, generated_vae, (mu, logvar) = G_curr(real_zero,
                                                        opt.Noise_Amps,
                                                        mode="rec")

        if opt.vae_levels >= opt.scale_idx + 1:
            rec_vae_loss = opt.rec_loss(generated, real) + opt.rec_loss(
                generated_vae, real_zero)
            kl_loss = kl_criterion(mu, logvar)
            vae_loss = opt.rec_weight * rec_vae_loss + opt.kl_weight * kl_loss

            total_loss += vae_loss
        else:
            ############################
            # (2) Update D network: maximize D(x) + D(G(z))
            ###########################
            # train with real
            #################

            # Train 3D Discriminator
            D_curr.zero_grad()
            output = D_curr(real)
            errD_real = -output.mean()

            # train with fake
            #################
            fake, _ = G_curr(noise_init,
                             opt.Noise_Amps,
                             noise_init=noise_init,
                             mode="rand")

            # Train 3D Discriminator
            output = D_curr(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = calc_gradient_penalty(D_curr, real, fake,
                                                     opt.lambda_grad,
                                                     opt.device)
            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

            ############################
            # (3) Update G network: maximize D(G(z))
            ###########################
            errG_total = 0
            rec_loss = opt.rec_loss(generated, real)
            errG_total += opt.rec_weight * rec_loss

            # Train with 3D Discriminator
            output = D_curr(fake)
            errG = -output.mean() * opt.disc_loss_weight
            errG_total += errG

            total_loss += errG_total

        G_curr.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(G_curr.parameters(), opt.grad_clip)
        optimizerG.step()

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))

        if opt.visualize:
            # Tensorboard
            opt.summary.add_scalar(
                'Video/Scale {}/noise_amp'.format(opt.scale_idx),
                opt.noise_amp, iteration)
            if opt.vae_levels >= opt.scale_idx + 1:
                opt.summary.add_scalar(
                    'Video/Scale {}/KLD'.format(opt.scale_idx), kl_loss.item(),
                    iteration)
            else:
                opt.summary.add_scalar(
                    'Video/Scale {}/rec loss'.format(opt.scale_idx),
                    rec_loss.item(), iteration)
            opt.summary.add_scalar(
                'Video/Scale {}/noise_amp'.format(opt.scale_idx),
                opt.noise_amp, iteration)
            if opt.vae_levels < opt.scale_idx + 1:
                opt.summary.add_scalar(
                    'Video/Scale {}/errG'.format(opt.scale_idx), errG.item(),
                    iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errD_fake'.format(opt.scale_idx),
                    errD_fake.item(), iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errD_real'.format(opt.scale_idx),
                    errD_real.item(), iteration)
            else:
                opt.summary.add_scalar(
                    'Video/Scale {}/Rec VAE'.format(opt.scale_idx),
                    rec_vae_loss.item(), iteration)

            if iteration % opt.print_interval == 0:
                with torch.no_grad():
                    fake_var = []
                    fake_vae_var = []
                    for _ in range(3):
                        noise_init = utils.generate_noise(ref=noise_init)
                        fake, fake_vae = G_curr(noise_init,
                                                opt.Noise_Amps,
                                                noise_init=noise_init,
                                                mode="rand")
                        fake_var.append(fake)
                        fake_vae_var.append(fake_vae)
                    fake_var = torch.cat(fake_var, dim=0)
                    fake_vae_var = torch.cat(fake_vae_var, dim=0)

                opt.summary.visualize_video(opt, iteration, real, 'Real')
                opt.summary.visualize_video(opt, iteration, generated,
                                            'Generated')
                opt.summary.visualize_video(opt, iteration, generated_vae,
                                            'Generated VAE')
                opt.summary.visualize_video(opt, iteration, fake_var,
                                            'Fake var')
                opt.summary.visualize_video(opt, iteration, fake_vae_var,
                                            'Fake VAE var')

    epoch_iterator.close()

    # Save data
    opt.saver.save_checkpoint({'data': opt.Noise_Amps}, 'Noise_Amps.pth')
    opt.saver.save_checkpoint(
        {
            'scale': opt.scale_idx,
            'state_dict': netG.state_dict(),
            'optimizer': optimizerG.state_dict(),
            'noise_amps': opt.Noise_Amps,
        }, 'netG.pth')
    if opt.vae_levels < opt.scale_idx + 1:
        opt.saver.save_checkpoint(
            {
                'scale': opt.scale_idx,
                'state_dict': (D_curr.module.state_dict()
                               if opt.device == 'cuda' else D_curr.state_dict()),
                'optimizer': optimizerD.state_dict(),
            }, 'netD_{}.pth'.format(opt.scale_idx))
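
calc_gradient_penalty is imported from elsewhere in this project; a common WGAN-GP style implementation looks like the hedged sketch below. The project's version may differ in details, e.g., how alpha is broadcast over 5-D video tensors.

import torch

def calc_gradient_penalty_sketch(netD, real, fake, lambda_grad, device):
    # Interpolate between real and fake with a per-sample mixing factor.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=device)
    interpolates = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    disc_out = netD(interpolates)
    # Gradient of the critic output w.r.t. the interpolated input.
    grads = torch.autograd.grad(outputs=disc_out, inputs=interpolates,
                                grad_outputs=torch.ones_like(disc_out),
                                create_graph=True, retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    # Penalize deviation of the gradient norm from 1.
    return lambda_grad * ((grads.norm(2, dim=1) - 1) ** 2).mean()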
Example #11
                             batch_size=opt.batch_size,
                             num_workers=4)

    if opt.stop_scale_time == -1:
        opt.stop_scale_time = opt.stop_scale

    opt.dataset = dataset
    opt.data_loader = data_loader

    with open(os.path.join(opt.saver.experiment_dir, 'args.txt'),
              'w') as args_file:
        for argument, value in sorted(vars(opt).items()):
            if type(value) in (str, int, float, tuple, list, bool):
                args_file.write('{}: {}\n'.format(argument, value))

    with logger.LoggingBlock("Commandline Arguments", emph=True):
        for argument, value in sorted(vars(opt).items()):
            if type(value) in (str, int, float, tuple, list):
                logging.info('{}: {}'.format(argument, value))

    with logger.LoggingBlock("Experiment Summary", emph=True):
        video_file_name, checkname, experiment = opt.saver.experiment_dir.split(
            '/')[-3:]
        logging.info("{}Video file :{} {}{}".format(magenta, clear,
                                                    video_file_name, clear))
        logging.info("{}Checkname  :{} {}{}".format(magenta, clear, checkname,
                                                    clear))
        logging.info("{}Experiment :{} {}{}".format(magenta, clear, experiment,
                                                    clear))

        with logger.LoggingBlock("Commandline Summary", emph=True):
Example #12
def exec_runtime(args, checkpoint_saver, model_and_loss, optimizer,
                 lr_scheduler, train_loader, validation_loader,
                 inference_loader, training_augmentation,
                 validation_augmentation):
    # ----------------------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # they want to be called with a validation loss.
    # ----------------------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None
                            and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logger.LoggingBlock("Runtime", emph=True):
        logging.info("start_epoch: %i" % args.start_epoch)
        logging.info("total_epochs: %i" % args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Progress",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "postfix": False,
        "unit": "ep"
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print('')
    logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    print("\n")

    # --------------------------------------------------------
    # Remember validation loss
    # --------------------------------------------------------
    best_validation_loss = (float("inf") if args.validation_key_minimize
                            else -float("inf"))
    store_as_best = False

    for epoch in range(args.start_epoch, args.total_epochs + 1):
        with logger.LoggingBlock("Epoch %i/%i" % (epoch, args.total_epochs),
                                 emph=True):

            # --------------------------------------------------------
            # Update standard learning scheduler
            # --------------------------------------------------------
            if lr_scheduler is not None and not validation_scheduler:
                lr_scheduler.step(epoch)

            # --------------------------------------------------------
            # Always report learning rate
            # --------------------------------------------------------
            if lr_scheduler is None:
                logging.info("lr: %s" %
                             format_learning_rate(args.optimizer_lr))
            else:
                logging.info("lr: %s" %
                             format_learning_rate(lr_scheduler.get_lr()))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                avg_loss_dict = TrainingEpoch(
                    args,
                    desc="   Train",
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    loader=train_loader,
                    augmentation=training_augmentation).run()

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:

                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                avg_loss_dict = EvaluationEpoch(
                    args,
                    desc="Validate",
                    model_and_loss=model_and_loss,
                    loader=validation_loader,
                    augmentation=validation_augmentation).run()

                # ----------------------------------------------------------------
                # Evaluate whether this is the best validation_loss
                # ----------------------------------------------------------------
                validation_loss = avg_loss_dict[args.validation_key]
                if args.validation_key_minimize:
                    store_as_best = validation_loss < best_validation_loss
                else:
                    store_as_best = validation_loss > best_validation_loss
                if store_as_best:
                    best_validation_loss = validation_loss

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # ----------------------------------------------------------------
                if lr_scheduler is not None and validation_scheduler:
                    lr_scheduler.step(validation_loss, epoch=epoch)

            # ----------------------------------------------------------------
            # Also show best loss on total_progress
            # ----------------------------------------------------------------
            total_progress_stats = {
                "best_" + args.validation_key + "_avg":
                "%1.4f" % best_validation_loss
            }
            total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            if checkpoint_saver is not None:
                if args.max_save and (args.min_save <= epoch <= args.max_save):
                    checkpoint_saver.save_latest(
                        directory=args.save,
                        model_and_loss=model_and_loss,
                        stats_dict=dict(avg_loss_dict, epoch=epoch),
                        suffix="_epoch_{}".format(epoch))

                checkpoint_saver.save_latest(directory=args.save,
                                             model_and_loss=model_and_loss,
                                             stats_dict=dict(avg_loss_dict,
                                                             epoch=epoch),
                                             store_as_best=store_as_best)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print('')
            logging.logbook('')

    # ----------------------------------------------------------------
    # Finish
    # ----------------------------------------------------------------
    total_progress.close()
    logging.info("Finished.")

    return avg_loss_dict
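
The validation_scheduler special case above matches PyTorch's ReduceLROnPlateau, which must be stepped with a metric value rather than an epoch index. A minimal sketch with made-up values:

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(4, 2)  # stand-in model
optimizer = optim.SGD(model.parameters(), lr=1e-2)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

for epoch in range(10):
    val_loss = 1.0 / (epoch + 1)  # pretend validation loss
    scheduler.step(val_loss)      # stepped with the metric, not the epoch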