Code Example #1
File: configuration.py Project: visinf/deblur-devil
def configure_lr_scheduler(args, optimizer):
    with logging.block("Learning Rate Scheduler", emph=True):
        logging.value(
            "Scheduler: ",
            args.lr_scheduler if args.lr_scheduler is not None else "None")
        lr_scheduler = None
        if args.lr_scheduler is not None:
            kwargs = typeinf.kwargs_from_args(args, "lr_scheduler")
            with logging.block():
                logging.values(kwargs)
            kwargs["optimizer"] = optimizer
            lr_scheduler = typeinf.instance_from_kwargs(
                args.lr_scheduler_class, kwargs=kwargs)
    return lr_scheduler
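
The typeinf helpers that appear throughout these examples are not part of the snippets. As a rough, hypothetical sketch of the pattern they suggest (an assumption for illustration, not the actual visinf/deblur-devil implementation), kwargs_from_args could gather all argparse options sharing a prefix, and instance_from_kwargs could construct the selected class from them:

def kwargs_from_args(args, prefix):
    # Collect "<prefix>_<name>" options from the argparse namespace and strip
    # the prefix, e.g. lr_scheduler_milestones -> {"milestones": ...}.
    prefix = prefix + "_"
    return {
        key[len(prefix):]: value
        for key, value in vars(args).items()
        if key.startswith(prefix) and not key.endswith("_class")
    }


def instance_from_kwargs(a_class, kwargs):
    # Instantiate the resolved class with the collected keyword arguments.
    return a_class(**kwargs)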
Code Example #2
File: configuration.py Project: visinf/deblur-devil
def configure_model_and_loss(args):
    with logging.block("Model and Loss", emph=True):

        kwargs = typeinf.kwargs_from_args(args, "model")
        kwargs["args"] = args
        model = typeinf.instance_from_kwargs(args.model_class, kwargs=kwargs)

        loss = None
        if args.loss is not None:
            kwargs = typeinf.kwargs_from_args(args, "loss")
            kwargs["args"] = args
            loss = typeinf.instance_from_kwargs(args.loss_class, kwargs=kwargs)
        else:
            logging.info("Loss is None; you need to pick a loss!")
            quit()

        model_and_loss = facade.ModelAndLoss(args, model, loss)

        logging.value("Batch Size: ", args.batch_size)
        if loss is not None:
            logging.value("Loss: ", args.loss)
        logging.value("Network: ", args.model)
        logging.value("Number of parameters: ",
                      model_and_loss.num_parameters())
        if loss is not None:
            logging.value("Training Key: ", args.training_key)
        if args.validation_dataset is not None:
            logging.value("Validation Keys: ", args.validation_keys)
            logging.value("Validation Modes: ", args.validation_modes)

    return model_and_loss
Code Example #3
File: checkpoints.py Project: visinf/deblur-devil
def restore_module_from_filename(module,
                                 filename,
                                 key='state_dict',
                                 include_params='*',
                                 exclude_params=(),
                                 translations=(),
                                 fuzzy_translation_keys=()):
    include_params = list(include_params)
    exclude_params = list(exclude_params)
    fuzzy_translation_keys = list(fuzzy_translation_keys)
    translations = dict(translations)

    # ------------------------------------------------------------------------------
    # Make sure file exists
    # ------------------------------------------------------------------------------
    if not os.path.isfile(filename):
        logging.info("Could not find checkpoint file '%s'!" % filename)
        quit()

    # ------------------------------------------------------------------------------
    # Load checkpoint from file including the state_dict
    # ------------------------------------------------------------------------------
    cpu_device = torch.device('cpu')
    checkpoint_dict = torch.load(filename, map_location=cpu_device)
    checkpoint_state_dict = checkpoint_dict[key]

    try:
        restore_keys, actual_translations = restore_module_from_state_dict(
            module,
            checkpoint_state_dict,
            include_params=include_params,
            exclude_params=exclude_params,
            translations=translations,
            fuzzy_translation_keys=fuzzy_translation_keys)
    except KeyError:
        with logging.block('Checkpoint keys:'):
            logging.value(checkpoint_state_dict.keys())
        with logging.block('Module keys:'):
            logging.value(module.state_dict().keys())
        logging.info(
            "Could not load checkpoint because of key errors. Checkpoint translations gone wrong?"
        )
        quit()

    return checkpoint_dict, restore_keys, actual_translations
Code Example #4
def try_register(name, module_class, registry, calling_frame):
    if name in registry:
        block_info = "Warning in {}[{}]:".format(calling_frame.filename,
                                                 calling_frame.lineno)
        with logging.block(block_info):
            code_info = "{} yields duplicate factory entry!".format(
                calling_frame.code_context[0][0:-1])
            logging.value(code_info)

    registry[name] = module_class
Code Example #5
def import_submodules(package_name):
    with logging.block(package_name + '...'):
        content = _package_contents(package_name)
        for name in content:
            if name != "__init__":
                import_target = "%s.%s" % (package_name, name)
                try:
                    __import__(import_target)
                except Exception as err:
                    logging.info("ImportError in {}: {}".format(
                        import_target, str(err)))
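
The _package_contents helper is not shown above. A minimal stand-in built on the standard library (a guess at the behavior, not necessarily the project's version) could be:

import importlib
import pkgutil


def _package_contents(package_name):
    # List the module names found directly inside the package (non-recursive).
    package = importlib.import_module(package_name)
    return sorted(name for _, name, _ in pkgutil.iter_modules(package.__path__))

With something like this in place, import_submodules("models") would import models.<name> for every module in the package, logging any import error instead of aborting.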
Code Example #6
File: configuration.py Project: visinf/deblur-devil
def configure_runtime_augmentations(args):
    with logging.block("Runtime Augmentations", emph=True):

        training_augmentation = None
        validation_augmentation = None

        # ----------------------------------------------------
        # Training Augmentation
        # ----------------------------------------------------
        if args.training_augmentation is not None:
            kwargs = typeinf.kwargs_from_args(args, "training_augmentation")
            logging.value("training_augmentation: ",
                          args.training_augmentation)
            with logging.block():
                logging.values(kwargs)
            kwargs["args"] = args
            training_augmentation = typeinf.instance_from_kwargs(
                args.training_augmentation_class, kwargs=kwargs)
            training_augmentation = training_augmentation.to(args.device)
        else:
            logging.info("training_augmentation: None")

        # ----------------------------------------------------
        # Validation Augmentation
        # ----------------------------------------------------
        if args.validation_augmentation is not None:
            kwargs = typeinf.kwargs_from_args(args, "validation_augmentation")
            logging.value("validation_augmentation: ",
                          args.validation_augmentation)
            with logging.block():
                logging.values(kwargs)
            kwargs["args"] = args
            validation_augmentation = typeinf.instance_from_kwargs(
                args.validation_augmentation_class, kwargs=kwargs)
            validation_augmentation = validation_augmentation.to(args.device)

        else:
            logging.info("validation_augmentation: None")

    return training_augmentation, validation_augmentation
Code Example #7
File: configuration.py Project: visinf/deblur-devil
def configure_checkpoint_saver(args, model_and_loss):
    with logging.block('Checkpoint', emph=True):
        checkpoint_saver = checkpoints.CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info('No checkpoint given.')
            logging.info('Starting from scratch with random initialization.')

        elif os.path.isfile(args.checkpoint):
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params,
                translations=args.checkpoint_translations,
                fuzzy_translation_keys=args.checkpoint_fuzzy_translation_keys)

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode == 'best':
                logging.info('Loading best checkpoint in %s' % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params,
                    translations=args.checkpoint_translations,
                    fuzzy_translation_keys=args.
                    checkpoint_fuzzy_translation_keys)

            elif args.checkpoint_mode == 'latest':
                logging.info('Loading latest checkpoint in %s' %
                             args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params,
                    translations=args.checkpoint_translations,
                    fuzzy_translation_keys=args.
                    checkpoint_fuzzy_translation_keys)
            else:
                logging.info('Unknown checkpoint_mode \'%s\' given!' %
                             args.checkpoint_mode)
                quit()
        else:
            logging.info('Could not find checkpoint file or directory \'%s\'' %
                         args.checkpoint)
            quit()

    return checkpoint_saver, checkpoint_stats
Code Example #8
File: commandline.py Project: visinf/deblur-devil
def parse_arguments(blocktitle):
    # ----------------------------------------------------------------------------
    # Parse commandline and get default arguments
    # ----------------------------------------------------------------------------
    args, defaults = _parse_arguments()

    # ----------------------------------------------------------------------------
    # Write arguments to file, as json and txt
    # ----------------------------------------------------------------------------
    json.write_dictionary_to_file(vars(args),
                                  filename=os.path.join(
                                      args.save, "args.json"),
                                  sortkeys=True)
    json.write_dictionary_to_file(vars(args),
                                  filename=os.path.join(args.save, "args.txt"),
                                  sortkeys=True)

    # ----------------------------------------------------------------------------
    # Log arguments
    # ----------------------------------------------------------------------------
    non_default_args = []
    with logging.block(blocktitle, emph=True):
        for argument, value in sorted(vars(args).items()):
            reset = constants.COLOR_RESET
            if value == defaults[argument]:
                color = reset
            else:
                non_default_args.append((argument, value))
                color = constants.COLOR_NON_DEFAULT_ARGUMENT
            if isinstance(value, dict):
                dict_string = strings.dict_as_string(value)
                logging.info("{}{}: {}{}".format(color, argument, dict_string,
                                                 reset))
            else:
                logging.info("{}{}: {}{}".format(color, argument, value,
                                                 reset))

    # ----------------------------------------------------------------------------
    # Remember non defaults
    # ----------------------------------------------------------------------------
    args.non_default_args = dict(
        (pair[0], pair[1]) for pair in non_default_args)

    # ----------------------------------------------------------------------------
    # Postprocess
    # ----------------------------------------------------------------------------
    args = postprocess_args(args)

    return args
Code Example #9
File: configuration.py Project: visinf/deblur-devil
def configure_parameter_scheduler(args, model_and_loss):
    param_groups = args.param_scheduler_group
    with logging.block("Parameter Scheduler", emph=True):
        if param_groups is None:
            logging.info("None")
        else:
            logging.value("Info: ",
                          "Please set lr=0 for scheduled parameters!")
            scheduled_parameter_groups = []
            with logging.block("parameter_groups:"):
                for group_kwargs in param_groups:
                    group_match = group_kwargs["params"]
                    group_args = {
                        key: value
                        for key, value in group_kwargs.items()
                        if key != "params"
                    }
                    with logging.block("%s: %s" % (group_match, group_args)):
                        gnames, gparams = _param_names_and_trainable_generator(
                            model_and_loss, match=group_match)
                        for n in sorted(gnames):
                            logging.info(n)
                        group_args['params'] = gparams
                        scheduled_parameter_groups.append(group_args)

            # create schedulers for every parameter group
            schedulers = [
                _configure_parameter_scheduler_group(kwargs)
                for kwargs in scheduled_parameter_groups
            ]

            # create container of parameter schedulers
            scheduler = facade.ParameterSchedulerContainer(schedulers)
            return scheduler

    return None
Code Example #10
File: configuration.py Project: visinf/deblur-devil
def configure_visualizers(args, model_and_loss, optimizer, param_scheduler,
                          lr_scheduler, train_loader, validation_loader):
    with logging.block("Runtime Visualizers", emph=True):
        logging.value(
            "Visualizer: ",
            args.visualizer if args.visualizer is not None else "None")
        visualizer = None
        if args.visualizer is not None:
            kwargs = typeinf.kwargs_from_args(args, "visualizer")
            logging.values(kwargs)
            kwargs["args"] = args
            kwargs["model_and_loss"] = model_and_loss
            kwargs["optimizer"] = optimizer
            kwargs["param_scheduler"] = param_scheduler
            kwargs["lr_scheduler"] = lr_scheduler
            kwargs["train_loader"] = train_loader
            kwargs["validation_loader"] = validation_loader
            visualizer = typeinf.instance_from_kwargs(args.visualizer_class,
                                                      kwargs=kwargs)
    return visualizer
Code Example #11
File: configuration.py Project: visinf/deblur-devil
def configure_random_seed(args):
    with logging.block("Random Seeds", emph=True):
        seed = args.seed
        if seed is not None:
            # python
            random.seed(seed)
            logging.value("Python seed: ", seed)
            # numpy
            seed += 1
            np.random.seed(seed)
            logging.value("Numpy seed: ", seed)
            # torch
            seed += 1
            torch.manual_seed(seed)
            logging.value("Torch CPU seed: ", seed)
            # torch cuda
            seed += 1
            torch.cuda.manual_seed(seed)
            logging.value("Torch CUDA seed: ", seed)
        else:
            logging.info("None")
Code Example #12
File: configuration.py Project: visinf/deblur-devil
def configure_optimizer(args, model_and_loss):
    optimizer = None
    with logging.block("Optimizer", emph=True):
        logging.value("Algorithm: ",
                      args.optimizer if args.optimizer is not None else "None")
        if args.optimizer is not None:
            if model_and_loss.num_parameters() == 0:
                logging.info("No trainable parameters detected.")
                logging.info("Setting optimizer to None.")
            else:
                with logging.block():
                    # -------------------------------------------
                    # Figure out all optimizer arguments
                    # -------------------------------------------
                    all_kwargs = typeinf.kwargs_from_args(args, "optimizer")

                    # -------------------------------------------
                    # Get the split of param groups
                    # -------------------------------------------
                    kwargs_without_groups = {
                        key: value
                        for key, value in all_kwargs.items() if key != "group"
                    }
                    param_groups = all_kwargs["group"]

                    # ----------------------------------------------------------------------
                    # Print arguments (without groups)
                    # ----------------------------------------------------------------------
                    logging.values(kwargs_without_groups)

                    # ----------------------------------------------------------------------
                    # Construct actual optimizer params
                    # ----------------------------------------------------------------------
                    kwargs = dict(kwargs_without_groups)
                    if param_groups is None:
                        # ---------------------------------------------------------
                        # Add all trainable parameters if there are no param groups
                        # ---------------------------------------------------------
                        all_trainable_parameters = _generate_trainable_params(
                            model_and_loss)
                        kwargs["params"] = all_trainable_parameters
                    else:
                        # -------------------------------------------
                        # Add list of parameter groups instead
                        # -------------------------------------------
                        trainable_parameter_groups = []
                        dnames, dparams = _param_names_and_trainable_generator(
                            model_and_loss)
                        dnames = set(dnames)
                        dparams = set(list(dparams))
                        with logging.block("parameter_groups:"):
                            for group in param_groups:
                                #  log group settings
                                group_match = group["params"]
                                group_args = {
                                    key: value
                                    for key, value in group.items()
                                    if key != "params"
                                }

                                with logging.block("%s: %s" %
                                                   (group_match, group_args)):
                                    # retrieve parameters by matching name
                                    gnames, gparams = _param_names_and_trainable_generator(
                                        model_and_loss, match=group_match)
                                    # log all names affected
                                    for n in sorted(gnames):
                                        logging.info(n)
                                    # set generator for group
                                    group_args["params"] = gparams
                                    # append parameter group
                                    trainable_parameter_groups.append(
                                        group_args)
                                    # update remaining trainable parameters
                                    dnames -= set(gnames)
                                    dparams -= set(list(gparams))

                            # append default parameter group
                            trainable_parameter_groups.append(
                                {"params": list(dparams)})
                            # and log its parameter names
                            with logging.block("default:"):
                                for dname in sorted(dnames):
                                    logging.info(dname)

                        # set params in optimizer kwargs
                        kwargs["params"] = trainable_parameter_groups

                    # -------------------------------------------
                    # Create optimizer instance
                    # -------------------------------------------
                    optimizer = typeinf.instance_from_kwargs(
                        args.optimizer_class, kwargs=kwargs)

    return optimizer
Code Example #13
    def _handle_server(self):
        server = self._server_socket
        data, r_addr = server.recvfrom(BUF_SIZE)
        key = None
        iv = None
        if not data:
            logging.debug('UDP handle_server: data is empty')
        if self._stat_callback:
            self._stat_callback(self._listen_port, len(data))
        if self._is_local:
            if self._is_tunnel:
                # add ss header to data
                tunnel_remote = self.tunnel_remote
                tunnel_remote_port = self.tunnel_remote_port
                data = common.add_header(tunnel_remote, tunnel_remote_port,
                                         data)
            else:
                frag = common.ord(data[2])
                if frag != 0:
                    logging.warn('UDP drop a message since frag is not 0')
                    return
                else:
                    data = data[3:]
        else:
            # decrypt data
            try:
                data, key, iv = cryptor.decrypt_all(self._password,
                                                    self._method, data,
                                                    self._crypto_path)
            except Exception:
                logging.debug('UDP handle_server: decrypt data failed')
                return
            if not data:
                logging.debug('UDP handle_server: data is empty after decrypt')
                return
        header_result = parse_header(data)
        if header_result is None:
            return
        addrtype, dest_addr, dest_port, header_length = header_result
        # logging.info("udp data to %s:%d from %s:%d"
        #              % (dest_addr, dest_port, r_addr[0], r_addr[1]))

        if 1:
            global trust_ip_list
            if r_addr[0] not in trust_ip_list:
                import redis
                client = redis.Redis(host='127.0.0.1', port=6379, db=0)
                trust_ip_list = client.get('trust_ip_list')
                if r_addr[0] not in trust_ip_list:
                    logging.block("udp block data to %s:%d from %s:%d" %
                                  (dest_addr, dest_port, r_addr[0], r_addr[1]))
                    return

        if self._is_local:
            server_addr, server_port = self._get_a_server()
        else:
            server_addr, server_port = dest_addr, dest_port
            # spec https://shadowsocks.org/en/spec/one-time-auth.html
            self._ota_enable_session = addrtype & ADDRTYPE_AUTH
            if self._ota_enable and not self._ota_enable_session:
                logging.warn('client one time auth is required')
                return
            if self._ota_enable_session:
                if len(data) < header_length + ONETIMEAUTH_BYTES:
                    logging.warn('UDP one time auth header is too short')
                    return
                _hash = data[-ONETIMEAUTH_BYTES:]
                data = data[:-ONETIMEAUTH_BYTES]
                _key = iv + key
                if onetimeauth_verify(_hash, data, _key) is False:
                    logging.warn('UDP one time auth fail')
                    return
        addrs = self._dns_cache.get(server_addr, None)
        if addrs is None:
            addrs = socket.getaddrinfo(server_addr, server_port, 0,
                                       socket.SOCK_DGRAM, socket.SOL_UDP)
            if not addrs:
                # drop
                return
            else:
                self._dns_cache[server_addr] = addrs

        af, socktype, proto, canonname, sa = addrs[0]
        key = client_key(r_addr, af)
        client = self._cache.get(key, None)
        if not client:
            # TODO async getaddrinfo
            if self._forbidden_iplist:
                if common.to_str(sa[0]) in self._forbidden_iplist:
                    logging.debug('IP %s is in forbidden list, drop' %
                                  common.to_str(sa[0]))
                    # drop
                    return
            client = socket.socket(af, socktype, proto)
            client.setblocking(False)
            self._cache[key] = client
            self._client_fd_to_server_addr[client.fileno()] = r_addr

            self._sockets.add(client.fileno())
            self._eventloop.add(client, eventloop.POLL_IN, self)

        if self._is_local:
            key, iv, m = cryptor.gen_key_iv(self._password, self._method)
            # spec https://shadowsocks.org/en/spec/one-time-auth.html
            if self._ota_enable_session:
                data = self._ota_chunk_data_gen(key, iv, data)
            try:
                data = cryptor.encrypt_all_m(key, iv, m, self._method, data,
                                             self._crypto_path)
            except Exception:
                logging.debug("UDP handle_server: encrypt data failed")
                return
            if not data:
                return
        else:
            data = data[header_length:]
        if not data:
            return
        try:
            client.sendto(data, (server_addr, server_port))
        except IOError as e:
            err = eventloop.errno_from_exception(e)
            if err in (errno.EINPROGRESS, errno.EAGAIN):
                pass
            else:
                shell.print_exception(e)
Code Example #14
File: runtime.py Project: visinf/deblur-devil
def exec_runtime(args, checkpoint_saver, model_and_loss, optimizer,
                 lr_scheduler, param_scheduler, train_loader,
                 validation_loader, training_augmentation,
                 validation_augmentation, visualizer):
    # --------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # They need special treatment as they want to be called with a validation loss..
    # --------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None
                            and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logging.block("Runtime", emph=True):
        logging.value("start_epoch: ", args.start_epoch)
        logging.value("total_epochs: ", args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Total",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "unit": "ep",
        "track_eta": True
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    total_progress_stats = {}
    print("\n")

    # --------------------------------------------------------
    # Remember validation losses
    # --------------------------------------------------------
    best_validation_losses = None
    store_as_best = None
    if validation_loader is not None:
        num_validation_losses = len(args.validation_keys)
        best_validation_losses = [
            float("inf")
            if args.validation_modes[i] == 'min' else -float("inf")
            for i in range(num_validation_losses)
        ]
        store_as_best = [False for _ in range(num_validation_losses)]

    # ----------------------------------------------------------------
    # Send Telegram message
    # ----------------------------------------------------------------
    logging.telegram(format_telegram_status_update(args, epoch=0))

    avg_loss_dict = {}
    for epoch in range(args.start_epoch, args.total_epochs + 1):

        # --------------------------------
        # Make Epoch %i/%i header message
        # --------------------------------
        epoch_header = "Epoch {}/{}{}{}".format(
            epoch, args.total_epochs, " " * 24,
            format_epoch_header_machine_stats(args))

        with logger.LoggingBlock(epoch_header, emph=True):

            # -------------------------------------------------------------------------------
            # Let TensorBoard know where we are..
            # -------------------------------------------------------------------------------
            summary.set_global_step(epoch)

            # -----------------------------------------------------------------
            # Update standard learning scheduler and get current learning rate
            # -----------------------------------------------------------------
            #  Starting with PyTorch 1.1 the expected validation order is:
            #       optimize(...)
            #       validate(...)
            #       scheduler.step()..

            # ---------------------------------------------------------------------
            # Update parameter schedule before the epoch
            # Note: Parameter schedulers are tuples of (optimizer, schedule)
            # ---------------------------------------------------------------------
            if param_scheduler is not None:
                param_scheduler.step(epoch=epoch)

            # -----------------------------------------------------------------
            # Get current learning rate from either optimizer or scheduler
            # -----------------------------------------------------------------
            lr = args.optimizer_lr if args.optimizer is not None else "None"
            if lr_scheduler is not None:
                lr = [group['lr'] for group in optimizer.param_groups] \
                    if args.optimizer is not None else "None"

            # --------------------------------------------------------
            # Current Epoch header stats
            # --------------------------------------------------------
            logging.info(format_epoch_header_stats(args, lr))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr,
                                             train=True,
                                             epoch=epoch,
                                             total_epochs=args.total_epochs)

                ema_loss_dict = RuntimeEpoch(
                    args,
                    desc="Train",
                    augmentation=training_augmentation,
                    loader=train_loader,
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    visualizer=visualizer).run(train=True)

                if visualizer is not None:
                    visualizer.on_epoch_finished(
                        ema_loss_dict,
                        train=True,
                        epoch=epoch,
                        total_epochs=args.total_epochs)

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr,
                                             train=False,
                                             epoch=epoch,
                                             total_epochs=args.total_epochs)

                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                epoch_recorder = configure_holistic_epoch_recorder(
                    args, epoch=epoch, loader=validation_loader)

                with torch.no_grad():
                    avg_loss_dict = RuntimeEpoch(
                        args,
                        desc="Valid",
                        augmentation=validation_augmentation,
                        loader=validation_loader,
                        model_and_loss=model_and_loss,
                        recorder=epoch_recorder,
                        visualizer=visualizer).run(train=False)

                    try:
                        epoch_recorder.add_scalars("evaluation_losses",
                                                   avg_loss_dict)
                    except Exception:
                        pass

                    if visualizer is not None:
                        visualizer.on_epoch_finished(
                            avg_loss_dict,
                            train=False,
                            epoch=epoch,
                            total_epochs=args.total_epochs)

                # ----------------------------------------------------------------
                # Evaluate validation losses
                # ----------------------------------------------------------------
                validation_losses = [
                    avg_loss_dict[vkey] for vkey in args.validation_keys
                ]
                for i, (vkey, vmode) in enumerate(
                        zip(args.validation_keys, args.validation_modes)):
                    if vmode == 'min':
                        store_as_best[i] = validation_losses[
                            i] < best_validation_losses[i]
                    else:
                        store_as_best[i] = validation_losses[
                            i] > best_validation_losses[i]
                    if store_as_best[i]:
                        best_validation_losses[i] = validation_losses[i]

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # We use the first key in validation keys as the relevant one
                # ----------------------------------------------------------------
                if lr_scheduler is not None:
                    if validation_scheduler:
                        lr_scheduler.step(validation_losses[0], epoch=epoch)
                    else:
                        lr_scheduler.step(epoch=epoch)

                # ----------------------------------------------------------------
                # Also show best loss on total_progress
                # ----------------------------------------------------------------
                total_progress_stats = {
                    "best_" + vkey + "_avg":
                    "%1.4f" % best_validation_losses[i]
                    for i, vkey in enumerate(args.validation_keys)
                }
                total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Get ETA string for display in loggers
            # ----------------------------------------------------------------
            eta_str = total_progress.eta_str()

            # ----------------------------------------------------------------
            # Send Telegram status update
            # ----------------------------------------------------------------
            total_progress_stats['lr'] = format_learning_rate(lr)
            logging.telegram(
                format_telegram_status_update(
                    args,
                    eta_str=eta_str,
                    epoch=epoch,
                    total_progress_stats=total_progress_stats))

            # ----------------------------------------------------------------
            # Update ETA in progress title
            # ----------------------------------------------------------------
            eta_proctitle = "{} finishes in {}".format(args.proctitle, eta_str)
            proctitles.setproctitle(eta_proctitle)

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            if checkpoint_saver is not None and validation_loader is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best,
                    store_prefixes=args.validation_keys)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print(''), logging.logbook('')

    # ----------------------------------------------------------------
    # Finish up
    # ----------------------------------------------------------------
    logging.telegram_flush()
    total_progress.close()
    logging.info("Finished.")
Code Example #15
File: main.py Project: visinf/deblur-devil
def main():
    # ---------------------------------------------------
    # Set working directory to folder containing main.py
    # ---------------------------------------------------
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # ----------------------------------------------------------------
    # Activate syntax highlighting in tracebacks for better debugging
    # ----------------------------------------------------------------
    colored_traceback.add_hook()

    # -----------------------------------------------------------
    # Configure logging
    # -----------------------------------------------------------
    logging_filename = os.path.join(commandline.parse_save_dir(),
                                    constants.LOGGING_LOGBOOK_FILENAME)
    logger.configure_logging(logging_filename)

    # ----------------------------------------------------------------
    # Register type factories before parsing the commandline.
    # NOTE: We decided to explicitly call these init() functions, to
    #       have more precise control over the timeline
    # ----------------------------------------------------------------
    with logging.block("Registering factories", emph=True):
        augmentations.init()
        datasets.init()
        losses.init()
        models.init()
        optim.init()
        visualizers.init()
        logging.info('Done!')

    # -----------------------------------------------------------
    # Parse commandline after factories have been filled
    # -----------------------------------------------------------
    args = commandline.parse_arguments(blocktitle="Commandline Arguments")

    # -----------------------
    # Telegram configuration
    # -----------------------
    with logging.block("Telegram", emph=True):
        logger.configure_telegram(constants.LOGGING_TELEGRAM_MACHINES_FILENAME)

    # ----------------------------------------------------------------------
    # Log git repository hash and make a compressed copy of the source code
    # ----------------------------------------------------------------------
    with logging.block("Source Code", emph=True):
        logging.value("Git Hash: ", system.git_hash())
        # Zip source code and copy to save folder
        filename = os.path.join(args.save,
                                constants.LOGGING_ZIPSOURCE_FILENAME)
        zipsource.create_zip(filename=filename, directory=os.getcwd())
        logging.value("Archieved code: ", filename)

    # ----------------------------------------------------
    # Change process title for `top` and `pkill` commands
    # This is more "informative" in `nvidia-smi` ;-)
    # ----------------------------------------------------
    args = config.configure_proctitle(args)

    # -------------------------------------------------
    # Set random seed for python, numpy, torch, cuda..
    # -------------------------------------------------
    config.configure_random_seed(args)

    # -----------------------------------------------------------
    # Machine stats
    # -----------------------------------------------------------
    with logging.block("Machine Statistics", emph=True):
        if args.cuda:
            args.device = torch.device("cuda:0")
            logging.value("Cuda: ", torch.version.cuda)
            logging.value("Cuda device count: ", torch.cuda.device_count())
            logging.value("Cuda device name: ", torch.cuda.get_device_name(0))
            logging.value("CuDNN: ", torch.backends.cudnn.version())
            device_no = 0
            if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
                device_no = os.environ['CUDA_VISIBLE_DEVICES']
            args.actual_device = "gpu:%s" % device_no
        else:
            args.device = torch.device("cpu")
            args.actual_device = "cpu"
        logging.value("Hostname: ", system.hostname())
        logging.value("PyTorch: ", torch.__version__)
        logging.value("PyTorch device: ", args.actual_device)

    # ------------------------------------------------------
    # Fetch data loaders. Quit if no data loader is present
    # ------------------------------------------------------
    train_loader, validation_loader = config.configure_data_loaders(args)

    # -------------------------------------------------------------------------
    # Check whether any dataset could be found
    # -------------------------------------------------------------------------
    success = any(loader is not None
                  for loader in [train_loader, validation_loader])
    if not success:
        logging.info(
            "No dataset could be loaded successfully. Please check dataset paths!"
        )
        quit()

    # -------------------------------------------------------------------------
    # Configure runtime augmentations
    # -------------------------------------------------------------------------
    training_augmentation, validation_augmentation = config.configure_runtime_augmentations(
        args)

    # ----------------------------------------------------------
    # Configure model and loss.
    # ----------------------------------------------------------
    model_and_loss = config.configure_model_and_loss(args)

    # --------------------------------------------------------
    # Print model visualization
    # --------------------------------------------------------
    if args.logging_model_graph:
        with logging.block("Model Graph", emph=True):
            logger.log_module_info(model_and_loss.model)
    if args.logging_loss_graph:
        with logging.block("Loss Graph", emph=True):
            logger.log_module_info(model_and_loss.loss)

    # -------------------------------------------------------------------------
    # Possibly resume from checkpoint
    # -------------------------------------------------------------------------
    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(
        args, model_and_loss)
    if checkpoint_stats is not None:
        with logging.block():
            logging.info("Checkpoint Statistics:")
            with logging.block():
                logging.values(checkpoint_stats)
        # ---------------------------------------------------------------------
        # Set checkpoint stats
        # ---------------------------------------------------------------------
        if args.checkpoint_mode in ["resume_from_best", "resume_from_latest"]:
            args.start_epoch = checkpoint_stats["epoch"]

    # ---------------------------------------------------------------------
    # Checkpoint and save directory
    # ---------------------------------------------------------------------
    with logging.block("Save Directory", emph=True):
        if args.save is None:
            logging.info("No 'save' directory specified!")
            quit()
        logging.value("Save directory: ", args.save)
        if not os.path.exists(args.save):
            os.makedirs(args.save)

    # ------------------------------------------------------------
    # If this is just an evaluation: overwrite savers and epochs
    # ------------------------------------------------------------
    if args.training_dataset is None and args.validation_dataset is not None:
        args.start_epoch = 1
        args.total_epochs = 1
        train_loader = None
        checkpoint_saver = None
        args.optimizer = None
        args.lr_scheduler = None

    # ----------------------------------------------------
    # Tensorboard summaries
    # ----------------------------------------------------
    logger.configure_tensorboard_summaries(args.save)

    # -------------------------------------------------------------------
    # From PyTorch API:
    # If you need to move a model to GPU via .cuda(), please do so before
    # constructing optimizers for it. Parameters of a model after .cuda()
    # will be different objects with those before the call.
    # In general, you should make sure that optimized parameters live in
    # consistent locations when optimizers are constructed and used.
    # -------------------------------------------------------------------
    model_and_loss = model_and_loss.to(args.device)

    # ----------------------------------------------------------
    # Configure optimizer
    # ----------------------------------------------------------
    optimizer = config.configure_optimizer(args, model_and_loss)

    # ----------------------------------------------------------
    # Configure learning rate
    # ----------------------------------------------------------
    lr_scheduler = config.configure_lr_scheduler(args, optimizer)

    # --------------------------------------------------------------------------
    # Configure parameter scheduling
    # --------------------------------------------------------------------------
    param_scheduler = config.configure_parameter_scheduler(
        args, model_and_loss)

    # quit()

    # ----------------------------------------------------------
    # Cuda optimization
    # ----------------------------------------------------------
    if args.cuda:
        torch.backends.cudnn.benchmark = constants.CUDNN_BENCHMARK

    # ----------------------------------------------------------
    # Configure runtime visualization
    # ----------------------------------------------------------
    visualizer = config.configure_visualizers(
        args,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        param_scheduler=param_scheduler,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        validation_loader=validation_loader)
    if visualizer is not None:
        visualizer = visualizer.to(args.device)

    # ----------------------------------------------------------
    # Kickoff training, validation and/or testing
    # ----------------------------------------------------------
    return runtime.exec_runtime(
        args,
        checkpoint_saver=checkpoint_saver,
        lr_scheduler=lr_scheduler,
        param_scheduler=param_scheduler,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        train_loader=train_loader,
        training_augmentation=training_augmentation,
        validation_augmentation=validation_augmentation,
        validation_loader=validation_loader,
        visualizer=visualizer)
Code Example #16
File: logger.py Project: visinf/deblur-devil
def configure_logging(filename):
    # set global indent level
    sys.modules[__name__].global_indent = 0

    # add custom LOGBOOK logging level
    add_logging_level("LOGBOOK", 1000)

    # create logger
    root_logger = logging.getLogger("")
    root_logger.setLevel(logging.INFO)

    # create console handler
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    fmt = get_default_logging_format(colorize=True, brackets=False)
    datefmt = constants.LOGGING_TIMESTAMP_FORMAT
    formatter = ConsoleFormatter(fmt=fmt, datefmt=datefmt)
    console.setFormatter(formatter)

    # Skip logbook records for console output
    skip_logbook_filter = SkipLogbookFilter()
    console.addFilter(skip_logbook_filter)

    # add console to root_logger
    root_logger.addHandler(console)

    # Show warnings in logger
    logging.captureWarnings(True)

    def _log_key_value_pair(key, value=None):
        if value is None:
            logging.info("{}{}".format(constants.COLOR_KEY_VALUE, str(key)))
        else:
            logging.info("{}{}{}".format(key, constants.COLOR_KEY_VALUE,
                                         str(value)))

    def _log_dict(indict):
        for key, value in sorted(indict.items()):
            logging.info("{}: {}{}".format(key, constants.COLOR_KEY_VALUE,
                                           str(value)))

    # this is for logging key value pairs or dictionaries
    setattr(logging, "value", _log_key_value_pair)
    setattr(logging, "values", _log_dict)

    # this is for logging blocks
    setattr(logging, "block", LoggingBlock)

    # add logbook
    if filename is not None:
        # ensure dir
        d = os.path.dirname(filename)
        if not os.path.exists(d):
            os.makedirs(d)
        with logging.block("Creating Logbook", emph=True):
            logging.info(filename)

        # --------------------------------------------------------------------------
        # Configure handler that removes color codes from logbook
        # --------------------------------------------------------------------------
        logbook = logging.FileHandler(filename=filename,
                                      mode="a",
                                      encoding="utf-8")
        logbook.setLevel(logging.INFO)
        fmt = get_default_logging_format(colorize=False, brackets=True)
        logbook_formatter = LogbookFormatter(fmt=fmt, datefmt=datefmt)
        logbook.setFormatter(logbook_formatter)
        root_logger.addHandler(logbook)
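
The LoggingBlock class that configure_logging installs as logging.block is defined elsewhere in logger.py and is not shown here. As a simplified sketch of such a context manager (an approximation; the real class presumably also handles emphasis, colors, and the module-level global_indent), it might look like:

import logging


class LoggingBlock:
    # Hypothetical sketch; the actual class in visinf/deblur-devil may differ.
    global_indent = 0  # shared indentation level read by the log formatters

    def __init__(self, title=None, emph=False):
        self.title = title
        self.emph = emph

    def __enter__(self):
        # Log the block title, then indent everything logged inside the block.
        if self.title is not None:
            logging.info(self.title)
        LoggingBlock.global_indent += 1
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        LoggingBlock.global_indent -= 1
        return False  # do not suppress exceptions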
Code Example #17
File: configuration.py Project: visinf/deblur-devil
def configure_data_loaders(args):
    with logging.block("Datasets", emph=True):

        def _sizes_to_str(value):
            if np.isscalar(value):
                return '1L'
            else:
                sizes = str([d for d in value.size()])
                return ' '.join([strings.replace_index(sizes, 1, '#')])

        def _log_statistics(loader, dataset):
            # get sizes from first dataset example
            example_dict = loader.first_item()
            for key, value in sorted(example_dict.items()):
                if key == "index" or "name" in key:  # no need to display these
                    continue
                if isinstance(value, str):
                    logging.value("%s: " % key, value)
                elif isinstance(value, list) or isinstance(value, tuple):
                    logging.value("%s: " % key, _sizes_to_str(value[0]))
                else:
                    logging.value("%s: " % key, _sizes_to_str(value))
            logging.value("num_examples: ", len(dataset))

        # -----------------------------------------------------------------------------------------
        # GPU parameters
        # -----------------------------------------------------------------------------------------
        gpuargs = {
            "pin_memory": constants.DATALOADER_PIN_MEMORY
        } if args.cuda else {}

        train_loader_and_collation = None
        validation_loader_and_collation = None

        # -----------------------------------------------------------------
        # This figures out from the args alone whether we need batch collation
        # -----------------------------------------------------------------
        train_collation, validation_collation = configure_collation(args)

        # -----------------------------------------------------------------------------------------
        # Training dataset
        # -----------------------------------------------------------------------------------------
        if args.training_dataset is not None:
            # ----------------------------------------------
            # Figure out training_dataset arguments
            # ----------------------------------------------
            kwargs = typeinf.kwargs_from_args(args, "training_dataset")
            kwargs["args"] = args

            # ----------------------------------------------
            # Create training dataset and loader
            # ----------------------------------------------
            logging.value("Training Dataset: ", args.training_dataset)
            with logging.block():
                train_dataset = typeinf.instance_from_kwargs(
                    args.training_dataset_class, kwargs=kwargs)
                if args.batch_size > len(train_dataset):
                    logging.info(
                        "Problem: batch_size bigger than number of training dataset examples!"
                    )
                    quit()
                train_loader = DataLoader(
                    train_dataset,
                    batch_size=args.batch_size,
                    shuffle=constants.TRAINING_DATALOADER_SHUFFLE,
                    drop_last=constants.TRAINING_DATALOADER_DROP_LAST,
                    num_workers=args.training_dataset_num_workers,
                    **gpuargs)
                train_loader_and_collation = facade.LoaderAndCollation(
                    args, loader=train_loader, collation=train_collation)
                _log_statistics(train_loader_and_collation, train_dataset)

        # -----------------------------------------------------------------------------------------
        # Validation dataset
        # -----------------------------------------------------------------------------------------
        if args.validation_dataset is not None:
            # ----------------------------------------------
            # Figure out validation_dataset arguments
            # ----------------------------------------------
            kwargs = typeinf.kwargs_from_args(args, "validation_dataset")
            kwargs["args"] = args

            # ------------------------------------------------------
            # By default, batch_size is the same as for training,
            # unless a validation_batch_size is specified.
            # -----------------------------------------------------
            validation_batch_size = args.batch_size
            if args.validation_batch_size > 0:
                validation_batch_size = args.validation_batch_size

            # ----------------------------------------------
            # Create validation dataset and loader
            # ----------------------------------------------
            logging.value("Validation Dataset: ", args.validation_dataset)
            with logging.block():
                validation_dataset = typeinf.instance_from_kwargs(
                    args.validation_dataset_class, kwargs=kwargs)
                if validation_batch_size > len(validation_dataset):
                    logging.info(
                        "Problem: validation_batch_size bigger than number of validation dataset examples!"
                    )
                    quit()
                validation_loader = DataLoader(
                    validation_dataset,
                    batch_size=validation_batch_size,
                    shuffle=constants.VALIDATION_DATALOADER_SHUFFLE,
                    drop_last=constants.VALIDATION_DATALOADER_DROP_LAST,
                    num_workers=args.validation_dataset_num_workers,
                    **gpuargs)
                validation_loader_and_collation = facade.LoaderAndCollation(
                    args,
                    loader=validation_loader,
                    collation=validation_collation)
                _log_statistics(validation_loader_and_collation,
                                validation_dataset)

    return train_loader_and_collation, validation_loader_and_collation