Example #1
0
def configure_lr_scheduler(args, optimizer):
    """Instantiate the learning-rate scheduler selected on the command line.

    Scheduler class and constructor parameters are taken from
    "--lr_scheduler_[param]=[value]" arguments. Returns None when no
    scheduler was requested.
    """
    scheduler = None

    with logger.LoggingBlock("Learning Rate Scheduler", emph=True):
        logging.info("class: %s" % args.lr_scheduler)

        if args.lr_scheduler is not None:
            # Collect all "--lr_scheduler_[param]" arguments as kwargs.
            scheduler_kwargs = tools.kwargs_from_args(args, "lr_scheduler")

            # Echo the resolved arguments in a stable, sorted order.
            for key in sorted(scheduler_kwargs):
                logging.info("%s: %s" % (key, scheduler_kwargs[key]))

            # The scheduler constructor additionally needs the optimizer.
            scheduler_kwargs["optimizer"] = optimizer

            # Build the instance from the dynamically resolved class.
            scheduler = tools.instance_from_kwargs(
                args.lr_scheduler_class, scheduler_kwargs)

    return scheduler
Example #2
0
def configure_attack(args):
    """Build the adversarial attack instance requested on the command line.

    Attack class and constructor parameters are taken from
    "--attack_[param]=[value]" arguments. Returns None (and logs "None")
    when no attack was configured.
    """
    attack_instance = None
    with logger.LoggingBlock("Adversarial Attack:", emph=True):

        if args.attack is None:
            logging.info("None")
        else:
            # Collect all "--attack_[param]" arguments as kwargs.
            attack_kwargs = typeinf.kwargs_from_args(args, "attack")

            # Echo the attack name followed by its arguments, sorted.
            logging.info("%s" % args.attack)
            for key in sorted(attack_kwargs):
                logging.info("%s: %s" % (key, attack_kwargs[key]))

            # The attack constructor also receives the full args namespace.
            attack_kwargs["args"] = args
            attack_instance = typeinf.instance_from_kwargs(
                args.attack_class, attack_kwargs)

    return attack_instance
Example #3
0
def setup_logging_and_parse_arguments(blocktitle):
    """Parse command-line arguments, set up logging, and record them.

    Logged values that differ from their parser defaults are highlighted
    in cyan. Returns the post-processed argument namespace.
    """
    # Parse the command line; also receive per-argument default values.
    args, defaults = _parse_arguments()

    # Configure the logbook first so all subsequent output is captured.
    logger.configure_logging(os.path.join(args.save, 'logbook.txt'))

    # Persist the arguments as a sorted text file next to the logbook.
    tools.write_dictionary_to_file(
        sorted(vars(args).items()),
        filename=os.path.join(args.save, 'args.txt'))

    # Echo every argument; non-default values are colored for visibility.
    with logger.LoggingBlock(blocktitle, emph=True):
        reset = colorama.Style.RESET_ALL
        for name, value in sorted(vars(args).items()):
            color = colorama.Fore.CYAN if value != defaults[name] else reset
            logging.info('{}{}: {}{}'.format(color, name, value, reset))

    # Apply project-specific post-processing before handing args back.
    return postprocess_args(args)
Example #4
0
def configure_model_and_loss(args):
    """Dynamically load model and loss classes and combine them.

    Classes and their constructor parameters are passed in via
    "--model_[param]=[value]" or "--loss_[param]=[value]" arguments.
    The combined ModelAndLoss is moved to CUDA when args.cuda is set.

    Args:
        args: parsed argument namespace (provides model_class, optional
            training_loss/validation_loss settings, cuda, batch_size, ...).

    Returns:
        The (possibly CUDA-resident) ModelAndLoss instance.
    """
    with logger.LoggingBlock("Model and Loss", emph=True):

        # ----------------------------------------------------
        # Model: instantiate from "--model_[param]" arguments.
        # ----------------------------------------------------
        kwargs = tools.kwargs_from_args(args, "model")
        kwargs["args"] = args
        model = tools.instance_from_kwargs(args.model_class, kwargs)

        # ----------------------------------------------------
        # Training loss (optional)
        # ----------------------------------------------------
        training_loss = None
        if args.training_loss is not None:
            kwargs = tools.kwargs_from_args(args, "training_loss")
            kwargs["args"] = args
            training_loss = tools.instance_from_kwargs(args.training_loss_class, kwargs)

        # ----------------------------------------------------
        # Validation loss (optional)
        # ----------------------------------------------------
        validation_loss = None
        if args.validation_loss is not None:
            kwargs = tools.kwargs_from_args(args, "validation_loss")
            kwargs["args"] = args
            validation_loss = tools.instance_from_kwargs(args.validation_loss_class, kwargs)

        # ----------------------------------------------------
        # Bundle model and losses into a single module.
        # ----------------------------------------------------
        model_and_loss = ModelAndLoss(args, model, training_loss, validation_loss)

        # -----------------------------------------------------------
        # If Cuda, transfer model to Cuda and wrap with DataParallel.
        # -----------------------------------------------------------
        if args.cuda:
            model_and_loss = model_and_loss.cuda()

        # ---------------------------------------------------------------
        # Report some network statistics
        # ---------------------------------------------------------------
        logging.info("Batch Size: %i" % args.batch_size)
        # Fix: plain if/else instead of a conditional expression evaluated
        # purely for its logging side effect.
        if args.cuda:
            logging.info("GPGPU: Cuda")
        else:
            logging.info("GPGPU: off")
        logging.info("Network: %s" % args.model)
        logging.info("Number of parameters: %i" % tools.x2module(model_and_loss).num_parameters())
        if training_loss is not None:
            logging.info("Training Key: %s" % args.training_key)
            logging.info("Training Loss: %s" % args.training_loss)
        if validation_loss is not None:
            logging.info("Validation Key: %s" % args.validation_key)
            logging.info("Validation Loss: %s" % args.validation_loss)

    return model_and_loss
Example #5
0
 def _log_statistics(dataset, prefix, name):
     """Log the key/size pairs of the first example plus the dataset length."""
     with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
         sample = dataset[0]  # sizes are taken from the first example
         for key in sorted(sample):
             if key in ("index", "basename"):
                 # bookkeeping entries -- not worth displaying
                 continue
             value = sample[key]
             if isinstance(value, str):
                 logging.info("{}: {}".format(key, value))
             else:
                 logging.info("%s: %s" % (key, _sizes_to_str(value)))
         logging.info("num_examples: %i" % len(dataset))
Example #6
0
def configure_checkpoint_saver(args, model_and_loss):
    """Create a CheckpointSaver and optionally restore a checkpoint.

    args.checkpoint may be None (start from scratch with random
    initialization), a checkpoint file (also resumes the epoch counter
    from the checkpoint's sidecar json), or a directory (restored
    according to args.checkpoint_mode: "resume_from_best" or
    "resume_from_latest").

    Returns:
        (checkpoint_saver, checkpoint_stats) -- checkpoint_stats is None
        unless a checkpoint was restored.
    """
    with logger.LoggingBlock("Checkpoint", emph=True):
        checkpoint_saver = CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info("No checkpoint given.")
            logging.info("Starting from scratch with random initialization.")

        elif os.path.isfile(args.checkpoint):
            logging.info("Loading checkpoint file %s" % args.checkpoint)
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params)

            # Resume epoch counting from the checkpoint's sidecar json.
            # Fix: use splitext instead of split('.')[0], which returned an
            # empty string for paths containing dots (e.g. "./dir/f.ckpt").
            no_extension = os.path.splitext(filename)[0]
            statistics_filename = no_extension + ".json"
            statistics = tools.read_json(statistics_filename)
            args.start_epoch = statistics['epoch'] + 1

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode in ["resume_from_best"]:
                logging.info("Loading best checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)

            elif args.checkpoint_mode in ["resume_from_latest"]:
                logging.info("Loading latest checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            else:
                # Fix: the branch tests args.checkpoint_mode, but the message
                # read the nonexistent args.checkpoint_restore (would raise
                # AttributeError before it could even be printed).
                logging.info("Unknown checkpoint_mode '%s' given!" % args.checkpoint_mode)
                quit()
        else:
            logging.info("Could not find checkpoint file or directory '%s'. Starting with random initialization." % args.checkpoint)

    return checkpoint_saver, checkpoint_stats
Example #7
0
def configure_random_seed(args):
    """Seed the python, numpy, torch-CPU and torch-CUDA RNGs.

    Consecutive seeds (args.seed, args.seed+1, ...) are used so the
    generators do not all start from the identical state.
    """
    with logger.LoggingBlock("Random Seeds", emph=True):
        base = args.seed
        # python's built-in RNG
        random.seed(base)
        logging.info("Python seed: %i" % base)
        # numpy
        np.random.seed(base + 1)
        logging.info("Numpy seed: %i" % (base + 1))
        # torch (CPU)
        torch.manual_seed(base + 2)
        logging.info("Torch CPU seed: %i" % (base + 2))
        # torch (CUDA)
        torch.cuda.manual_seed(base + 3)
        logging.info("Torch CUDA seed: %i" % (base + 3))
Example #8
0
def setup_logging_and_parse_arguments(blocktitle):
    """Parse arguments, configure logging, persist and echo the arguments.

    Arguments whose values differ from their parser defaults are
    highlighted in cyan; dict-valued arguments are expanded one sub-entry
    per line. Returns the post-processed argument namespace.
    """
    # Parse the command line; also receive per-argument default values.
    args, defaults = _parse_arguments()

    # Configure the logbook first so all subsequent output is captured.
    logger.configure_logging(os.path.join(args.save, "logbook.txt"))

    # Persist the arguments twice: once as json, once as txt.
    for args_filename in ("args.json", "args.txt"):
        json.write_dictionary_to_file(
            vars(args),
            filename=os.path.join(args.save, args_filename),
            sortkeys=True)

    # Echo every argument; non-default values are colored for visibility.
    with logger.LoggingBlock(blocktitle, emph=True):
        reset = colorama.Style.RESET_ALL
        for name, value in sorted(vars(args).items()):
            color = colorama.Fore.CYAN if value != defaults[name] else reset
            if isinstance(value, dict):
                # Expand dictionaries as "argument_subkey: subvalue" lines.
                for sub_name, sub_value in collections.OrderedDict(
                        value).items():
                    logging.info("{}{}_{}: {}{}".format(
                        color, name, sub_name, sub_value, reset))
            else:
                logging.info("{}{}: {}{}".format(color, name, value, reset))

    # Apply project-specific post-processing before handing args back.
    return postprocess_args(args)
Example #9
0
def _make_runtime_augmentation(args, prefix):
    """Instantiate and log one runtime augmentation.

    prefix is "training_augmentation" or "validation_augmentation"; the
    class and its constructor parameters come from "--[prefix]_[param]"
    arguments. Returns None (and logs "[prefix]: None") when the
    augmentation was not configured; moves the instance to CUDA when
    args.cuda is set.
    """
    if getattr(args, prefix) is None:
        logging.info("%s: None" % prefix)
        return None

    # Collect all "--[prefix]_[param]" arguments as kwargs and log them.
    kwargs = tools.kwargs_from_args(args, prefix)
    logging.info("%s: %s" % (prefix, getattr(args, prefix)))
    for param, default in sorted(kwargs.items()):
        logging.info("  %s: %s" % (param, default))

    # The constructor also receives the full args namespace.
    kwargs["args"] = args
    augmentation = tools.instance_from_kwargs(
        getattr(args, prefix + "_class"), kwargs)
    if args.cuda:
        augmentation = augmentation.cuda()
    return augmentation


def configure_runtime_augmentations(args):
    """Build the (optional) training and validation runtime augmentations.

    Fix: the original duplicated the training/validation branches almost
    verbatim; both now share _make_runtime_augmentation.

    Returns:
        (training_augmentation, validation_augmentation); each is None
        when the corresponding argument was not set.
    """
    with logger.LoggingBlock("Runtime Augmentations", emph=True):
        training_augmentation = _make_runtime_augmentation(
            args, "training_augmentation")
        validation_augmentation = _make_runtime_augmentation(
            args, "validation_augmentation")

    return training_augmentation, validation_augmentation
Example #10
0
def configure_checkpoint_saver(args, model_and_loss):
    """Create a CheckpointSaver and restore from args.checkpoint if given.

    args.checkpoint may be None (start from scratch), a checkpoint file,
    or a directory; for a directory, args.checkpoint_mode selects either
    "resume_from_best" or "resume_from_latest". Exits the program on an
    unknown mode or a nonexistent checkpoint path.

    Returns:
        (checkpoint_saver, checkpoint_stats) -- checkpoint_stats is None
        unless a checkpoint was restored.
    """
    with logger.LoggingBlock("Checkpoint", emph=True):
        checkpoint_saver = CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info("No checkpoint given.")
            logging.info("Starting from scratch with random initialization.")

        elif os.path.isfile(args.checkpoint):
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params)

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode in ["resume_from_best"]:
                logging.info("Loading best checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)

            elif args.checkpoint_mode in ["resume_from_latest"]:
                logging.info("Loading latest checkpoint in %s" %
                             args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            else:
                # Fix: the branch tests args.checkpoint_mode, but the message
                # read the nonexistent args.checkpoint_restore (would raise
                # AttributeError before it could even be printed).
                logging.info("Unknown checkpoint_mode '%s' given!" %
                             args.checkpoint_mode)
                quit()
        else:
            logging.info("Could not find checkpoint file or directory '%s'" %
                         args.checkpoint)
            quit()

    return checkpoint_saver, checkpoint_stats
Example #11
0
def configure_model_and_loss(args):
    """Dynamically instantiate the model (and optional loss) and wrap them.

    Classes and constructor parameters are passed in via
    "--model_[param]=[value]" or "--loss_[param]=[value]" arguments.

    Returns:
        The combined ModelAndLoss instance.
    """
    with logger.LoggingBlock("Model and Loss", emph=True):

        # Model: instantiate from "--model_[param]" arguments.
        model_kwargs = typeinf.kwargs_from_args(args, "model")
        model_kwargs["args"] = args
        model = typeinf.instance_from_kwargs(args.model_class, model_kwargs)

        # Training loss (optional): instantiate from "--loss_[param]" args.
        loss = None
        if args.loss is not None:
            loss_kwargs = typeinf.kwargs_from_args(args, "loss")
            loss_kwargs["args"] = args
            loss = typeinf.instance_from_kwargs(args.loss_class, loss_kwargs)

        # Bundle model and loss into a single module.
        model_and_loss = ModelAndLoss(args, model, loss)

        # Report some network statistics.
        logging.info("Batch Size: %i" % args.batch_size)
        logging.info("Network: %s" % args.model)
        logging.info("Number of parameters: %i" % model_and_loss.num_parameters())
        if loss is not None:
            logging.info("Training Key: %s" % args.training_key)
            logging.info("Training Loss: %s" % args.loss)
        logging.info("Validation Keys: %s" % args.validation_keys)
        logging.info("Validation Keys Minimize: %s" % args.validation_keys_minimize)

    return model_and_loss
Example #12
0
def exec_runtime(args, device, checkpoint_saver, model_and_loss, optimizer,
                 attack, lr_scheduler, train_loader, validation_loader,
                 inference_loader, training_augmentation,
                 validation_augmentation):
    """Run the main training/validation loop over all epochs.

    Per epoch: step the (non-validation) lr scheduler, optionally run a
    training pass, optionally run a validation pass (tracking the best
    value for every key in args.validation_keys, and stepping a
    ReduceLROnPlateau-style scheduler on the first validation loss),
    update the overall progress bar, and save a checkpoint via
    checkpoint_saver (marking it as "best" per validation key when that
    tracked loss improved).

    NOTE(review): inference_loader is accepted but never used in this
    function body.
    """
    # ----------------------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # They want to be called with a validation loss..
    # ----------------------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None
                            and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logger.LoggingBlock("Runtime", emph=True):
        logging.info("start_epoch: %i" % args.start_epoch)
        logging.info("total_epochs: %i" % args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Progress",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "postfix": False,
        "unit": "ep"
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    print("\n")

    # --------------------------------------------------------
    # Remember validation losses
    # --------------------------------------------------------
    # One "best so far" slot per validation key; initialized to +inf for
    # keys that are minimized and -inf for keys that are maximized.
    num_validation_losses = len(args.validation_keys)
    best_validation_losses = [
        float("inf") if args.validation_keys_minimize[i] else -float("inf")
        for i in range(num_validation_losses)
    ]
    store_as_best = [False for i in range(num_validation_losses)]

    # --------------------------------------------------------
    # Transfer model to device once before training/evaluation
    # --------------------------------------------------------
    model_and_loss = model_and_loss.to(device)

    # Holds the most recent epoch's averaged losses for checkpoint stats.
    avg_loss_dict = {}
    for epoch in range(args.start_epoch, args.total_epochs + 1):
        with logger.LoggingBlock("Epoch %i/%i" % (epoch, args.total_epochs),
                                 emph=True):

            # --------------------------------------------------------
            # Update standard learning scheduler
            # --------------------------------------------------------
            if lr_scheduler is not None and not validation_scheduler:
                lr_scheduler.step(epoch)

            # --------------------------------------------------------
            # Always report learning rate and model
            # --------------------------------------------------------
            if lr_scheduler is None:
                logging.info(
                    "model: %s  lr: %s" %
                    (args.model, format_learning_rate(args.optimizer_lr)))
            else:
                logging.info(
                    "model: %s  lr: %s" %
                    (args.model, format_learning_rate(lr_scheduler.get_lr())))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                avg_loss_dict, _ = TrainingEpoch(
                    args,
                    desc="   Train",
                    device=device,
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    loader=train_loader,
                    augmentation=training_augmentation).run()

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                epoch_recorder = configure_holistic_epoch_recorder(
                    args, epoch=epoch, loader=validation_loader)

                with torch.no_grad():
                    avg_loss_dict, output_dict = EvaluationEpoch(
                        args,
                        desc="Validate",
                        device=device,
                        model_and_loss=model_and_loss,
                        attack=attack,
                        loader=validation_loader,
                        recorder=epoch_recorder,
                        augmentation=validation_augmentation).run()

                # ----------------------------------------------------------------
                # Evaluate valdiation losses
                # ----------------------------------------------------------------
                # Compare each tracked key against its best-so-far value,
                # honoring the per-key minimize/maximize direction.
                validation_losses = [
                    avg_loss_dict[vkey] for vkey in args.validation_keys
                ]
                for i, (vkey, vminimize) in enumerate(
                        zip(args.validation_keys,
                            args.validation_keys_minimize)):
                    if vminimize:
                        store_as_best[i] = validation_losses[
                            i] < best_validation_losses[i]
                    else:
                        store_as_best[i] = validation_losses[
                            i] > best_validation_losses[i]
                    if store_as_best[i]:
                        best_validation_losses[i] = validation_losses[i]

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # We use the first key in validation keys as the relevant one
                # ----------------------------------------------------------------
                if lr_scheduler is not None and validation_scheduler:
                    lr_scheduler.step(validation_losses[0], epoch=epoch)

            # ----------------------------------------------------------------
            # Also show best loss on total_progress
            # ----------------------------------------------------------------
            total_progress_stats = {
                "best_" + vkey + "_avg": "%1.4f" % best_validation_losses[i]
                for i, vkey in enumerate(args.validation_keys)
            }
            total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            # store_as_best marks which validation keys improved this epoch;
            # the saver duplicates the checkpoint per improved key prefix.
            if checkpoint_saver is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best,
                    store_prefixes=args.validation_keys)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print(''), logging.logbook('')

    # ----------------------------------------------------------------
    # Finish
    # ----------------------------------------------------------------
    total_progress.close()
    logging.info("Finished.")
Example #13
0
def configure_optimizer(args, model_and_loss):
    """Build the optimizer selected via "--optimizer_[param]" arguments.

    Supports optional parameter groups through the "group" optimizer
    argument: each group matches trainable parameters by name and may
    override hyperparameters; all remaining trainable parameters form a
    default group. Returns None when no optimizer was requested or the
    model has no trainable parameters.
    """
    optimizer = None
    with logger.LoggingBlock("Optimizer", emph=True):
        if args.optimizer is not None:
            if model_and_loss.num_parameters() == 0:
                logging.info("No trainable parameters detected.")
                logging.info("Setting optimizer to None.")
            else:
                logging.info(args.optimizer)

                # -------------------------------------------
                # Figure out all optimizer arguments
                # -------------------------------------------
                all_kwargs = tools.kwargs_from_args(args, "optimizer")

                # -------------------------------------------
                # Get the split of param groups
                # -------------------------------------------
                kwargs_without_groups = {
                    key: value
                    for key, value in all_kwargs.items() if key != "group"
                }
                param_groups = all_kwargs["group"]

                # ----------------------------------------------------------------------
                # Print arguments (without groups)
                # ----------------------------------------------------------------------
                for param, default in sorted(kwargs_without_groups.items()):
                    logging.info("%s: %s" % (param, default))

                # ----------------------------------------------------------------------
                # Construct actual optimizer params
                # ----------------------------------------------------------------------
                kwargs = dict(kwargs_without_groups)
                if param_groups is None:
                    # ---------------------------------------------------------
                    # Add all trainable parameters if there is no param groups
                    # ---------------------------------------------------------
                    all_trainable_parameters = _generate_trainable_params(
                        model_and_loss)
                    kwargs["params"] = all_trainable_parameters
                else:
                    # -------------------------------------------
                    # Add list of parameter groups instead
                    # -------------------------------------------
                    trainable_parameter_groups = []
                    # dnames/dparams track the not-yet-grouped ("default")
                    # parameter names and tensors; matched groups are
                    # subtracted from these sets below.
                    dnames, dparams = _param_names_and_trainable_generator(
                        model_and_loss)
                    dnames = set(dnames)
                    dparams = set(list(dparams))
                    with logger.LoggingBlock("parameter_groups:"):
                        for group in param_groups:
                            #  log group settings
                            group_match = group["params"]
                            group_args = {
                                key: value
                                for key, value in group.items()
                                if key != "params"
                            }

                            with logger.LoggingBlock(
                                    "%s: %s" % (group_match, group_args)):
                                # retrieve parameters by matching name
                                gnames, gparams = _param_names_and_trainable_generator(
                                    model_and_loss, match=group_match)
                                # log all names affected
                                for n in sorted(gnames):
                                    logging.info(n)
                                # set generator for group
                                group_args["params"] = gparams
                                # append parameter group
                                trainable_parameter_groups.append(group_args)
                                # update remaining trainable parameters
                                dnames -= set(gnames)
                                dparams -= set(list(gparams))

                        # append default parameter group
                        trainable_parameter_groups.append(
                            {"params": list(dparams)})
                        # and log its parameter names
                        with logger.LoggingBlock("default:"):
                            for dname in sorted(dnames):
                                logging.info(dname)

                    # set params in optimizer kwargs
                    kwargs["params"] = trainable_parameter_groups

                # -------------------------------------------
                # Create optimizer instance
                # -------------------------------------------
                optimizer = tools.instance_from_kwargs(args.optimizer_class,
                                                       kwargs)

    return optimizer
Example #14
0
def configure_data_loaders(args):
    """Build the train/validation DataLoaders configured in *args*.

    Returns a ``(train_loader, validation_loader, inference_loader)``
    tuple; each entry may be ``None`` when the corresponding dataset
    is not configured (inference is never built here).
    """
    with logger.LoggingBlock("Datasets", emph=True):

        def _shape_str(value):
            # Scalars carry no size() information
            if np.isscalar(value):
                return '[1L]'
            return str([d for d in value.size()])

        def _describe(dataset, prefix, name):
            # Log key/shape pairs taken from the first example, plus length
            with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
                sample = dataset[0]
                for key, value in sorted(sample.items()):
                    if key in ["index", "basename"]:
                        # bookkeeping entries; not worth displaying
                        continue
                    if isinstance(value, str):
                        logging.info("{}: {}".format(key, value))
                    else:
                        logging.info("%s: %s" % (key, _shape_str(value)))
                logging.info("num_examples: %i" % len(dataset))

        # -----------------------------------------------------------------------------------------
        # GPU parameters -- pin_memory stays off (was disabled to resolve a deadlock)
        # -----------------------------------------------------------------------------------------
        if args.cuda:
            gpuargs = {"num_workers": args.num_workers, "pin_memory": False}
        else:
            gpuargs = {}

        train_loader = None
        validation_loader = None
        inference_loader = None

        # -----------------------------------------------------------------------------------------
        # Training dataset
        # -----------------------------------------------------------------------------------------
        if args.training_dataset is not None:
            # Collect constructor keyword arguments from the commandline
            kwargs = tools.kwargs_from_args(args, "training_dataset")
            kwargs["is_cropped"] = True
            kwargs["args"] = args

            # Instantiate the dataset, then wrap it in a shuffling loader
            train_dataset = tools.instance_from_kwargs(
                args.training_dataset_class, kwargs)
            train_loader = DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=False,
                                      **gpuargs)

            _describe(train_dataset,
                      prefix="Training",
                      name=args.training_dataset)

        # -----------------------------------------------------------------------------------------
        # Validation dataset
        # -----------------------------------------------------------------------------------------
        if args.validation_dataset is not None:
            # Collect constructor keyword arguments from the commandline
            kwargs = tools.kwargs_from_args(args, "validation_dataset")
            kwargs["is_cropped"] = True
            kwargs["args"] = args

            # Instantiate the dataset, then wrap it in a deterministic loader
            validation_dataset = tools.instance_from_kwargs(
                args.validation_dataset_class, kwargs)
            validation_loader = DataLoader(validation_dataset,
                                           batch_size=args.batch_size_val,
                                           shuffle=False,
                                           drop_last=False,
                                           **gpuargs)

            _describe(validation_dataset,
                      prefix="Validation",
                      name=args.validation_dataset)

    return train_loader, validation_loader, inference_loader
Example #15
0
def main():
    """Entry point: configure data, model, attack, optimizer and run training.

    Parses commandline arguments, archives the source tree for
    reproducibility, builds data loaders / augmentations / model-and-loss /
    adversarial attack / checkpointing / optimizer / lr scheduler, then
    hands everything to ``runtime.exec_runtime`` and returns its result.
    """
    # ----------------------------------------------------
    # Change working directory to the script's own folder so
    # relative paths (e.g. for zipping sources) are stable
    # ----------------------------------------------------
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # ----------------------------------------------------
    # Parse commandline arguments
    # ----------------------------------------------------
    args = commandline.setup_logging_and_parse_arguments(
        blocktitle="Commandline Arguments")

    with logger.LoggingBlock("Source Code", emph=True):
        # ----------------------------------------------------
        # Also archive source code
        # ----------------------------------------------------
        dst = os.path.join(args.save, "src.zip")
        # Reuse the precomputed destination path (it was joined twice before)
        zipsource.create_zip(filename=dst, directory=os.getcwd())
        logging.info("Archived code: %s" % dst)

    # ----------------------------------------------------
    # Set random seed, possibly on Cuda
    # ----------------------------------------------------
    config.configure_random_seed(args)

    # ----------------------------------------------------
    # Change process title for `top` and `pkill` commands
    # This is more informative in `nvidia-smi`
    # ----------------------------------------------------
    setproctitle.setproctitle(args.proctitle)

    # ------------------------------------------------------
    # Fetch data loaders. Quit if no data loader is present
    # ------------------------------------------------------
    train_loader, validation_loader, inference_loader = config.configure_data_loaders(
        args)

    # -------------------------------------------------------------------------
    # Check whether any dataset could be found
    # -------------------------------------------------------------------------
    success = any(
        loader is not None
        for loader in [train_loader, validation_loader, inference_loader])
    if not success:
        logging.info(
            "No dataset could be loaded successfully. Please check dataset paths!"
        )
        quit()

    # -------------------------------------------------------------------------
    # Configure runtime augmentations
    # -------------------------------------------------------------------------
    training_augmentation, validation_augmentation = config.configure_runtime_augmentations(
        args)

    # ----------------------------------------------------------
    # Configure model and loss.
    # ----------------------------------------------------------
    model_and_loss = config.configure_model_and_loss(args)

    # -----------------------------------------------------------
    # Cuda: pick the device everything will run on
    # -----------------------------------------------------------
    with logger.LoggingBlock("Device", emph=True):
        if args.cuda:
            device = torch.device("cuda")
            logging.info("GPU")
        else:
            device = torch.device("cpu")
            logging.info("CPU")

    # ----------------------------------------------------------
    # Configure adversarial attack
    # ----------------------------------------------------------
    attack = config.configure_attack(args)

    # --------------------------------------------------------
    # Print model visualization
    # --------------------------------------------------------
    if args.logging_model_graph:
        with logger.LoggingBlock("Model Graph", emph=True):
            logger.log_module_info(model_and_loss.model)
    if args.logging_loss_graph:
        with logger.LoggingBlock("Loss Graph", emph=True):
            logger.log_module_info(model_and_loss.loss)

    # -------------------------------------------------------------------------
    # Possibly resume from checkpoint
    # -------------------------------------------------------------------------
    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(
        args, model_and_loss)
    if checkpoint_stats is not None:
        logging.info("  Checkpoint Statistics:")
        for key, value in checkpoint_stats.items():
            logging.info("    {}: {}".format(key, value))

        # ---------------------------------------------------------------------
        # Set checkpoint stats: continue counting epochs from the checkpoint
        # ---------------------------------------------------------------------
        if args.checkpoint_mode in ["resume_from_best", "resume_from_latest"]:
            args.start_epoch = checkpoint_stats["epoch"]

    # ---------------------------------------------------------------------
    # Checkpoint and save directory
    # ---------------------------------------------------------------------
    with logger.LoggingBlock("Save Directory", emph=True):
        logging.info("Save directory: %s" % args.save)
        if not os.path.exists(args.save):
            os.makedirs(args.save)

    # ----------------------------------------------------------
    # Configure optimizer
    # ----------------------------------------------------------
    optimizer = config.configure_optimizer(args, model_and_loss)

    # ----------------------------------------------------------
    # Configure learning rate
    # ----------------------------------------------------------
    lr_scheduler = config.configure_lr_scheduler(args, optimizer)

    # ------------------------------------------------------------
    # If this is just an evaluation: overwrite savers and epochs
    # ------------------------------------------------------------
    if args.evaluation:
        args.start_epoch = 1
        args.total_epochs = 1
        train_loader = None
        checkpoint_saver = None
        optimizer = None
        lr_scheduler = None

    # ----------------------------------------------------------
    # Cuda optimization: let cudnn benchmark kernels for the
    # (assumed fixed) input sizes
    # ----------------------------------------------------------
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # ----------------------------------------------------------
    # Kickoff training, validation and/or testing
    # ----------------------------------------------------------
    return runtime.exec_runtime(
        args,
        device=device,
        checkpoint_saver=checkpoint_saver,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        attack=attack,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        validation_loader=validation_loader,
        inference_loader=inference_loader,
        training_augmentation=training_augmentation,
        validation_augmentation=validation_augmentation)
Example #16
0
def exec_runtime(args, checkpoint_saver, model_and_loss, optimizer,
                 lr_scheduler, param_scheduler, train_loader,
                 validation_loader, training_augmentation,
                 validation_augmentation, visualizer):
    """Run the full train/validation loop from args.start_epoch to args.total_epochs.

    Per epoch: optionally steps the parameter scheduler, runs a training
    epoch (if train_loader is set), runs a validation epoch (if
    validation_loader is set), tracks the best validation losses, steps
    the lr scheduler, updates the progress bar / Telegram status, and
    stores a checkpoint.

    Args:
        args: parsed commandline arguments (epoch range, validation
            keys/modes, save directory, proctitle, ...).
        checkpoint_saver: saver for latest/best checkpoints, or None.
        model_and_loss: combined model + loss wrapper run each epoch.
        optimizer: optimizer used during training epochs.
        lr_scheduler: per-epoch learning-rate scheduler, or None.
        param_scheduler: optional parameter scheduler stepped before each
            epoch, or None.
        train_loader: training data loader, or None to skip training.
        validation_loader: validation data loader, or None to skip
            validation (this also disables checkpoint storing below).
        training_augmentation: augmentation applied in training epochs.
        validation_augmentation: augmentation applied in validation epochs.
        visualizer: optional visualizer notified on epoch init/finish,
            or None.
    """
    # --------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # They need special treatment as they want to be called with a validation loss..
    # --------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None
                            and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # NOTE(review): logging.block/value/telegram/logbook look like project
    # extensions of the stdlib logging module — confirm they are installed.
    # --------------------------------------------------------
    with logging.block("Runtime", emph=True):
        logging.value("start_epoch: ", args.start_epoch)
        logging.value("total_epochs: ", args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Total",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "unit": "ep",
        "track_eta": True
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    total_progress_stats = {}
    print("\n")

    # --------------------------------------------------------
    # Remember validation losses
    # (initialized to +/-inf depending on min/max mode per key)
    # --------------------------------------------------------
    best_validation_losses = None
    store_as_best = None
    if validation_loader is not None:
        num_validation_losses = len(args.validation_keys)
        best_validation_losses = [
            float("inf")
            if args.validation_modes[i] == 'min' else -float("inf")
            for i in range(num_validation_losses)
        ]
        store_as_best = [False for _ in range(num_validation_losses)]

    # ----------------------------------------------------------------
    # Send Telegram message
    # ----------------------------------------------------------------
    logging.telegram(format_telegram_status_update(args, epoch=0))

    avg_loss_dict = {}
    for epoch in range(args.start_epoch, args.total_epochs + 1):

        # --------------------------------
        # Make Epoch %i/%i header message
        # --------------------------------
        epoch_header = "Epoch {}/{}{}{}".format(
            epoch, args.total_epochs, " " * 24,
            format_epoch_header_machine_stats(args))

        with logger.LoggingBlock(epoch_header, emph=True):

            # -------------------------------------------------------------------------------
            # Let TensorBoard know where we are..
            # -------------------------------------------------------------------------------
            summary.set_global_step(epoch)

            # -----------------------------------------------------------------
            # Update standard learning scheduler and get current learning rate
            # -----------------------------------------------------------------
            #  Starting with PyTorch 1.1 the expected validation order is:
            #       optimize(...)
            #       validate(...)
            #       scheduler.step()..

            # ---------------------------------------------------------------------
            # Update parameter schedule before the epoch
            # Note: Parameter schedulers are tuples of (optimizer, schedule)
            # ---------------------------------------------------------------------
            if param_scheduler is not None:
                param_scheduler.step(epoch=epoch)

            # -----------------------------------------------------------------
            # Get current learning rate from either optimizer or scheduler
            # -----------------------------------------------------------------
            lr = args.optimizer_lr if args.optimizer is not None else "None"
            if lr_scheduler is not None:
                lr = [group['lr'] for group in optimizer.param_groups] \
                    if args.optimizer is not None else "None"

            # --------------------------------------------------------
            # Current Epoch header stats
            # --------------------------------------------------------
            logging.info(format_epoch_header_stats(args, lr))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr,
                                             train=True,
                                             epoch=epoch,
                                             total_epochs=args.total_epochs)

                ema_loss_dict = RuntimeEpoch(
                    args,
                    desc="Train",
                    augmentation=training_augmentation,
                    loader=train_loader,
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    visualizer=visualizer).run(train=True)

                if visualizer is not None:
                    visualizer.on_epoch_finished(
                        ema_loss_dict,
                        train=True,
                        epoch=epoch,
                        total_epochs=args.total_epochs)

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr,
                                             train=False,
                                             epoch=epoch,
                                             total_epochs=args.total_epochs)

                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                epoch_recorder = configure_holistic_epoch_recorder(
                    args, epoch=epoch, loader=validation_loader)

                with torch.no_grad():
                    avg_loss_dict = RuntimeEpoch(
                        args,
                        desc="Valid",
                        augmentation=validation_augmentation,
                        loader=validation_loader,
                        model_and_loss=model_and_loss,
                        recorder=epoch_recorder,
                        visualizer=visualizer).run(train=False)

                    # best-effort: recorder may not support scalar groups
                    try:
                        epoch_recorder.add_scalars("evaluation_losses",
                                                   avg_loss_dict)
                    except Exception:
                        pass

                    if visualizer is not None:
                        visualizer.on_epoch_finished(
                            avg_loss_dict,
                            train=False,
                            epoch=epoch,
                            total_epochs=args.total_epochs)

                # ----------------------------------------------------------------
                # Evaluate validation losses
                # ----------------------------------------------------------------
                validation_losses = [
                    avg_loss_dict[vkey] for vkey in args.validation_keys
                ]
                for i, (vkey, vmode) in enumerate(
                        zip(args.validation_keys, args.validation_modes)):
                    if vmode == 'min':
                        store_as_best[i] = validation_losses[
                            i] < best_validation_losses[i]
                    else:
                        store_as_best[i] = validation_losses[
                            i] > best_validation_losses[i]
                    if store_as_best[i]:
                        best_validation_losses[i] = validation_losses[i]

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # We use the first key in validation keys as the relevant one
                # ----------------------------------------------------------------
                if lr_scheduler is not None:
                    if validation_scheduler:
                        lr_scheduler.step(validation_losses[0], epoch=epoch)
                    else:
                        lr_scheduler.step(epoch=epoch)

                # ----------------------------------------------------------------
                # Also show best loss on total_progress
                # ----------------------------------------------------------------
                total_progress_stats = {
                    "best_" + vkey + "_avg":
                    "%1.4f" % best_validation_losses[i]
                    for i, vkey in enumerate(args.validation_keys)
                }
                total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Get ETA string for display in loggers
            # ----------------------------------------------------------------
            eta_str = total_progress.eta_str()

            # ----------------------------------------------------------------
            # Send Telegram status update
            # ----------------------------------------------------------------
            total_progress_stats['lr'] = format_learning_rate(lr)
            logging.telegram(
                format_telegram_status_update(
                    args,
                    eta_str=eta_str,
                    epoch=epoch,
                    total_progress_stats=total_progress_stats))

            # ----------------------------------------------------------------
            # Update ETA in progress title
            # NOTE(review): other code in this file calls
            # setproctitle.setproctitle — confirm `proctitles` is a valid
            # module alias and not a typo.
            # ----------------------------------------------------------------
            eta_proctitle = "{} finishes in {}".format(args.proctitle, eta_str)
            proctitles.setproctitle(eta_proctitle)

            # ----------------------------------------------------------------
            # Store checkpoint (only when validation produced fresh stats)
            # ----------------------------------------------------------------
            if checkpoint_saver is not None and validation_loader is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best,
                    store_prefixes=args.validation_keys)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print(''), logging.logbook('')

    # ----------------------------------------------------------------
    # Finish up
    # ----------------------------------------------------------------
    logging.telegram_flush()
    total_progress.close()
    logging.info("Finished.")
Example #17
0
def main():
    """Entry point: configure data, model and optimizer, then run training.

    Same pipeline as the full runtime setup, with the model wrapped in
    torch.nn.DataParallel for multi-GPU execution.
    """

    # Run relative to this script's own directory
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # Commandline arguments
    args = commandline.setup_logging_and_parse_arguments(
        blocktitle="Commandline Arguments")

    # Random seed (possibly on Cuda)
    config.configure_random_seed(args)

    # Data loaders; bail out early when none could be constructed
    train_loader, validation_loader, inference_loader = config.configure_data_loaders(
        args)
    if all(loader is None
           for loader in (train_loader, validation_loader, inference_loader)):
        logging.info(
            "No dataset could be loaded successfully. Please check dataset paths!"
        )
        quit()

    # Runtime data augmentation
    training_augmentation, validation_augmentation = config.configure_runtime_augmentations(
        args)

    # Model and loss
    model_and_loss = config.configure_model_and_loss(args)

    # Wrap the model for multi-GPU execution
    with logger.LoggingBlock("Multi GPU", emph=True):
        logging.info("Let's use %d GPUs!" % torch.cuda.device_count())
        model_and_loss._model = torch.nn.DataParallel(model_and_loss._model)

    # Resume from checkpoint if available
    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(
        args, model_and_loss)

    # Make sure the save directory exists
    with logger.LoggingBlock("Save Directory", emph=True):
        logging.info("Save directory: %s" % args.save)
        if not os.path.exists(args.save):
            os.makedirs(args.save)

    # Optimizer and learning-rate schedule
    optimizer = config.configure_optimizer(args, model_and_loss)
    lr_scheduler = config.configure_lr_scheduler(args, optimizer)

    # Pure evaluation runs: single epoch, no training machinery
    if args.evaluation:
        args.start_epoch = 1
        args.total_epochs = 1
        train_loader = None
        checkpoint_saver = None
        optimizer = None
        lr_scheduler = None

    # Let cudnn benchmark kernels for (assumed fixed) input sizes
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Hand everything over to the runtime loop
    return runtime.exec_runtime(
        args,
        checkpoint_saver=checkpoint_saver,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        validation_loader=validation_loader,
        inference_loader=inference_loader,
        training_augmentation=training_augmentation,
        validation_augmentation=validation_augmentation)
Example #18
0
def exec_runtime(args,
                 checkpoint_saver,
                 model_and_loss,
                 optimizer,
                 lr_scheduler,
                 train_loader,
                 validation_loader,
                 inference_loader,
                 training_augmentation,
                 validation_augmentation):
    """Run the full training / validation loop from start_epoch to total_epochs.

    Args:
        args: parsed configuration namespace; this function reads
            start_epoch, total_epochs, save, lr_scheduler, optimizer_lr,
            validation_key and validation_key_minimize.
        checkpoint_saver: checkpoint saver used to persist the latest/best
            model, or None (e.g. in pure evaluation runs).
        model_and_loss: combined model + loss wrapper passed to the epochs.
        optimizer: optimizer instance, or None in evaluation mode.
        lr_scheduler: learning-rate scheduler, or None for a fixed rate.
        train_loader: training data loader, or None to skip training.
        validation_loader: validation data loader, or None to skip validation.
        inference_loader: inference data loader (currently unused here but
            kept for interface compatibility with the caller).
        training_augmentation: augmentation module applied in training epochs.
        validation_augmentation: augmentation module applied in validation.

    Returns:
        None. Progress, metrics and checkpoints are emitted as side effects.
    """

    # ----------------------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # ReduceLROnPlateau must be stepped with the current validation loss
    # instead of the epoch index, so remember whether we have one.
    # ----------------------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logger.LoggingBlock("Runtime", emph=True):
        logging.info("start_epoch: %i" % args.start_epoch)
        logging.info("total_epochs: %i" % args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Progress",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "postfix": False,
        "unit": "ep"
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print('')
    logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    print("\n")

    # --------------------------------------------------------
    # Remember the best validation loss seen so far. Depending on whether
    # the validation key is minimized or maximized, start from +/- infinity.
    # --------------------------------------------------------
    best_validation_loss = float("inf") if args.validation_key_minimize else -float("inf")
    store_as_best = False

    # Defensive default: if both train_loader and validation_loader are None,
    # the checkpoint step below would otherwise hit an unbound name.
    avg_loss_dict = {}

    for epoch in range(args.start_epoch, args.total_epochs + 1):
        with logger.LoggingBlock("Epoch %i/%i" % (epoch, args.total_epochs), emph=True):

            # --------------------------------------------------------
            # Update standard (epoch-driven) learning rate scheduler.
            # Validation-driven schedulers are stepped after validation.
            # --------------------------------------------------------
            if lr_scheduler is not None and not validation_scheduler:
                lr_scheduler.step(epoch)

            # --------------------------------------------------------
            # Always report the current learning rate
            # --------------------------------------------------------
            if lr_scheduler is None:
                logging.info("lr: %s" % format_learning_rate(args.optimizer_lr))
            else:
                logging.info("lr: %s" % format_learning_rate(lr_scheduler.get_lr()))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                avg_loss_dict = TrainingEpoch(
                    args,
                    desc="   Train",
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    loader=train_loader,
                    augmentation=training_augmentation,
                    checkpoint_saver=checkpoint_saver,
                    checkpoint_args={
                        'directory': args.save,
                        'model_and_loss': model_and_loss,
                        'stats_dict': dict({'epe': 0, 'F1': 0}, epoch=epoch),
                        'store_as_best': False
                    }
                ).run()

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:

                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                avg_loss_dict = EvaluationEpoch(
                    args,
                    desc="Validate",
                    model_and_loss=model_and_loss,
                    loader=validation_loader,
                    augmentation=validation_augmentation).run()

                # ----------------------------------------------------------------
                # Evaluate whether this is the best validation_loss so far
                # ----------------------------------------------------------------
                validation_loss = avg_loss_dict[args.validation_key]
                if args.validation_key_minimize:
                    store_as_best = validation_loss < best_validation_loss
                else:
                    store_as_best = validation_loss > best_validation_loss
                if store_as_best:
                    best_validation_loss = validation_loss

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # (ReduceLROnPlateau takes the validation loss, not the epoch)
                # ----------------------------------------------------------------
                if lr_scheduler is not None and validation_scheduler:
                    lr_scheduler.step(validation_loss, epoch=epoch)

            # ----------------------------------------------------------------
            # Also show best loss on total_progress
            # ----------------------------------------------------------------
            total_progress_stats = {
                "best_" + args.validation_key + "_avg": "%1.4f" % best_validation_loss
            }
            total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            if checkpoint_saver is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print('')
            logging.logbook('')

            # NOTE(review): a debug-leftover quit() used to sit here, which
            # killed the process after the first epoch and made the cleanup
            # below unreachable. It has been removed so all epochs run.

    # ----------------------------------------------------------------
    # Finish
    # ----------------------------------------------------------------
    total_progress.close()
    logging.info("Finished.")