def main():
    base_conf = OmegaConf.load(CONFIG_FILE)
    cli_conf = OmegaConf.from_cli()
    conf = OmegaConf.merge(base_conf, cli_conf)

    dir_input = Path(conf.dir_input)
    dir_output = Path(conf.dir_output)
    input_exr_ext = conf.input_exr_ext
    mask_mapping_json = conf.mask_mapping_json
    random_seed = conf.random_seed
    random.seed(random_seed)

    output_mask_ext = conf.output_mask_ext
    output_mask_rgb_ext = conf.output_mask_rgb_ext

    if not dir_input.is_dir():
        raise ValueError(f'Not a directory: {dir_input}')
    dir_output.mkdir(parents=True, exist_ok=True)
    if Path(output_mask_ext).suffix != '.png':
        raise ValueError(f'The output mask must be in .png format. Given format: {output_mask_ext}')

    exr_filenames = sorted(dir_input.glob('*' + input_exr_ext))

    for f_exr in exr_filenames:
        process_file(f_exr, dir_output, output_mask_ext, output_mask_rgb_ext, mask_mapping_json)
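For context, OmegaConf.from_cli() turns bare key=value tokens from sys.argv into a typed DictConfig, so any key from CONFIG_FILE can be overridden at invocation time. A minimal sketch with hypothetical values:

import sys
from omegaconf import OmegaConf

# simulate: python prepare_masks.py random_seed=7 dir_output=/tmp/masks
sys.argv = ["prepare_masks.py", "random_seed=7", "dir_output=/tmp/masks"]
cli = OmegaConf.from_cli()
assert cli.random_seed == 7  # "7" is parsed to an int, not kept as a string
assert cli.dir_output == "/tmp/masks"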
Example #2
def parse_args(argv=None) -> OmegaConf:
    arg_formatter = argparse.ArgumentDefaultsHelpFormatter

    description = 'Model transfer script'
    parser = argparse.ArgumentParser(formatter_class=arg_formatter,
                                     description=description)

    parser.add_argument('--exp-cfg',
                        type=str,
                        dest='exp_cfg',
                        help='The configuration of the experiment')
    parser.add_argument('--exp-opts',
                        default=[],
                        dest='exp_opts',
                        nargs='*',
                        help='Command line arguments')

    cmd_args = parser.parse_args()

    cfg = default_conf.copy()
    if cmd_args.exp_cfg:
        cfg.merge_with(OmegaConf.load(cmd_args.exp_cfg))
    if cmd_args.exp_opts:
        cfg.merge_with(OmegaConf.from_cli(cmd_args.exp_opts))

    return cfg
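Note that OmegaConf.from_cli() also accepts an explicit list of dotlist-style tokens, which is what makes the --exp-opts passthrough above work. A minimal sketch, assuming a config with an optim.lr key:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"optim": {"lr": 1e-3}})
cfg.merge_with(OmegaConf.from_cli(["optim.lr=1e-4"]))
assert cfg.optim.lr == 1e-4  # merge_with mutates cfg in place; later values win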
Example #3
def config_override(cfg):
    """Overrides with user-supplied configuration

    hourly will override its configuration using
    hourly.yaml if it is in the base git directory
    or users can set an override config:
        config_override=path/to/myconfig.yaml
    """
    # change to the git directory of the original working dir
    original_path = hydra.utils.get_original_cwd()
    change_git_dir(original_path, verbosity=cfg.verbosity)

    # get the full path of the override file if available
    override_path = os.path.abspath(cfg.config_override)

    if os.path.exists(override_path):
        if cfg.verbosity > 0:
            print("overriding config with {}".format(override_path))
        override_conf = OmegaConf.load(override_path)
        # merge: values from override_conf take priority over cfg
        cfg = OmegaConf.merge(cfg, override_conf)
    else:
        if cfg.verbosity > 0:
            print("override path does not exist: {}".format(override_path))

    # merge in command line arguments
    cli_conf = OmegaConf.from_cli()
    cfg = OmegaConf.merge(cfg, cli_conf)

    return cfg
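Both merges above rely on OmegaConf.merge giving later arguments priority while keeping keys the override does not mention. A minimal sketch of that ordering:

from omegaconf import OmegaConf

base = OmegaConf.create({"verbosity": 1, "lr": 0.1})
override = OmegaConf.create({"lr": 0.01})
merged = OmegaConf.merge(base, override)
assert merged.lr == 0.01 and merged.verbosity == 1  # override wins; untouched keys survive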
Example #4
def convert_to_attrdict(cfg: DictConfig, cmdline_args: List[Any] = None):
    """
    Given the user input Hydra Config, and some command line input options
    to override the config file:
    1. merge and override the command line options in the config
    2. Convert the Hydra OmegaConf to AttrDict structure to make it easy
       to access the keys in the config file
    3. Also check the config version used is compatible and supported in vissl.
       In future, we would want to support upgrading the old config versions if
       we make changes to the VISSL default config structure (deleting, renaming keys)
    4. We infer values of some parameters in the config file using the other
       parameter values.
    """
    if cmdline_args:
        # convert the command line args to DictConfig
        sys.argv = cmdline_args
        cli_conf = OmegaConf.from_cli(cmdline_args)

        # merge the command line args with config
        cfg = OmegaConf.merge(cfg, cli_conf)

    # convert the config to AttrDict
    cfg = OmegaConf.to_container(cfg)
    cfg = AttrDict(cfg)

    # check the cfg has valid version
    check_cfg_version(cfg)

    # assert the config and infer
    config = cfg.config
    assert_hydra_conf(config)
    return cfg, config
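Step 2 hinges on OmegaConf.to_container, which converts a DictConfig into a plain dict that AttrDict can wrap. A minimal sketch (resolve=True, shown here for illustration, also expands interpolations):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"config": {"lr": "${config.base_lr}", "base_lr": 0.1}})
plain = OmegaConf.to_container(cfg, resolve=True)
assert isinstance(plain, dict) and plain["config"]["lr"] == 0.1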
Example #5
def main(flags: DictConfig):
    if os.path.exists("config.yaml"):
        # this ignores the local config.yaml and replaces it completely with the saved one
        logging.info(
            "loading existing configuration, we're continuing a previous run")
        new_flags = OmegaConf.load("config.yaml")
        cli_conf = OmegaConf.from_cli()
        # however, you can override parameters from the cli still
        # this is useful e.g. if you did total_steps=N before and want to increase it
        flags = OmegaConf.merge(new_flags, cli_conf)

    logging.info(flags.pretty(resolve=True))
    OmegaConf.save(flags, "config.yaml")

    flags = get_common_flags(flags)

    # set flags for polybeast_env
    env_flags = get_environment_flags(flags)
    env_processes = []
    for _ in range(1):
        p = mp.Process(target=run_env, args=(env_flags, ))
        p.start()
        env_processes.append(p)

    symlink_latest(flags.savedir,
                   os.path.join(hydra.utils.get_original_cwd(), "latest"))

    lrn_flags = get_learner_flags(flags)
    run_learner(lrn_flags)

    for p in env_processes:
        p.kill()
        p.join()
Example #6
def process_configs(argv: List[str]) -> DictConfig:
    """merge default, yaml, and CLI configs, return merged `DictConfig` object
		
	reads from:
	 - `CONFIG_DEFAULT`
	 - file specified by `CONFIG_DEFAULT['config']['file_in']`
	 - command line args
	 - file specified by command line arg `--config.file_in`
	
	Merges those configs into a single `OmegaConf.DictConfig` object, in the following order:
	 - `CONFIG_DEFAULT`
	 - file specified by `CONFIG_DEFAULT['config']['file_in']`
	 - file specified by command line arg `--config.file_in`
	 - command line args

	### Parameters:
	 - `argv : List[str]`   
	   command line arguments passed to `OmegaConf.from_cli()`
	
	### Returns:
	 - `DictConfig` 
	   merged config object
	"""

    # default options
    cfg_default: DictConfig = assert_DictConfig(CONFIG_DEFAULT)

    # try to load config from default given file
    cfg_file_default: Optional[DictConfig] = assert_Optional_DictConfig(
        load_file_config(cfg_default['config']['file_in']))

    # load command line arguments
    cfg_cmd: DictConfig = OmegaConf.from_cli(argv)

    # try to load config from cmd given file
    cfg_file_cmd: Optional[DictConfig] = None
    if ('config' in cfg_cmd) and ('file_in' in cfg_cmd['config']):
        cfg_file_cmd = assert_Optional_DictConfig(
            load_file_config(cfg_cmd['config']['file_in']))

    # merge the configs (listed in increasing priority)
    tomerge: List[Optional[DictConfig]] = [
        cfg_default, cfg_file_default, cfg_file_cmd, cfg_cmd
    ]

    tomerge_filtered: List[DictConfig] = [x for x in tomerge if x is not None]

    cfg = assert_DictConfig(OmegaConf.merge(*tomerge_filtered))

    # don't read this source file or the config files
    add_default_excludes(cfg)

    if cfg['config']['file_out'] is not None:
        # save the current (merged) config to the specified file
        OmegaConf.save(cfg, cfg['config']['file_out'])
        print('> saved config to\t%s' % cfg['config']['file_out'])

    return cfg
Example #7
def load_and_validate_config(config_schema: Type) -> DictConfig:
    config_cli = OmegaConf.from_cli()
    if config_cli.config:
        config_yml = OmegaConf.load(config_cli.config)
        del config_cli["config"]
    else:
        config_yml = OmegaConf.create({})

    return OmegaConf.merge(config_schema, config_yml,
                           config_cli)  # type: ignore
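Merging against a structured schema gives validation for free: values are type-checked against the schema's annotations, and unknown keys are rejected. A minimal sketch, with a hypothetical Schema dataclass standing in for config_schema:

from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class Schema:
    lr: float = 1e-3
    epochs: int = 10

schema = OmegaConf.structured(Schema)
merged = OmegaConf.merge(schema, OmegaConf.create({"lr": 0.01}))
assert merged.epochs == 10  # defaults survive; lr stays type-checked as float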
Example #8
def parse_configs(configs, overwrite_config):
    if len(configs) == 1:
        config = OmegaConf.load(str(configs[0]))
    else:
        config = OmegaConf.merge(
            *[OmegaConf.load(str(cfg)) for cfg in configs])
    if overwrite_config is not None:
        config.merge_with(OmegaConf.from_cli(overwrite_config))
    OmegaConf.set_readonly(config, True)
    OmegaConf.set_struct(config, True)
    return config
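set_readonly freezes the merged config so downstream code cannot mutate it by accident; set_struct additionally rejects unknown keys. A minimal sketch of the write protection:

from omegaconf import OmegaConf
from omegaconf.errors import ReadonlyConfigError

cfg = OmegaConf.create({"a": 1})
OmegaConf.set_readonly(cfg, True)
try:
    cfg.a = 2
except ReadonlyConfigError:
    pass  # writes are rejected once the config is frozen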
Example #9
    def load_args(self):
        """
        There are several ways to pass arguments via the OmegaConf interface

        In general a Trainer object should have the property `default_config`
        to set the path of the default config file containing all the training
        arguments.

        - `default_config` can be overridden via cli specifying the path
            with the special `config` argument.

        """

        # retrieve module path
        dir_path = os.path.dirname(os.path.abspath(__file__))
        dir_path = os.path.split(dir_path)[0]
        # get all the default yaml configs with glob
        dir_path = os.path.join(dir_path, 'configs', '*.yml')

        # -- From default yapt configuration
        self._defaults_path = {}
        self._defaults_yapt = OmegaConf.create(dict())
        for file in glob.glob(dir_path):
            # split filename from path to create key and val
            key = os.path.splitext(os.path.split(file)[1])[0]
            self._defaults_path[key] = file
            # parse default args
            self._defaults_yapt = OmegaConf.merge(self._defaults_yapt,
                                                  OmegaConf.load(file))

        # -- From command line
        self._cli_args = OmegaConf.from_cli()
        if self._cli_args.config is not None:
            self.default_config = self._cli_args.config
            del self._cli_args['config']
            self.console_log.warning("override default config with: %s",
                                     self.default_config)

        # -- From experiment default config file
        self._default_config_args = OmegaConf.create(dict())
        if self.default_config is not None:
            self._default_config_args = OmegaConf.load(self.default_config)

        # -- Merge default args
        self._args = OmegaConf.merge(self._defaults_yapt,
                                     self._default_config_args)

        # -- Resolve interpolations to be sure all nodes are explicit
        # self._args = OmegaConf.to_container(self._args, resolve=True)
        # self._args = OmegaConf.create(self._args)

        # -- make args structured: it fails if accessing a missing key
        OmegaConf.set_struct(self._args, True)
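With set_struct enabled, reading or writing a key that is not already in the config raises instead of silently returning None. A minimal sketch:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"lr": 0.1})
OmegaConf.set_struct(cfg, True)
try:
    _ = cfg.unknown_key
except AttributeError:  # ConfigAttributeError in OmegaConf 2.x
    pass  # struct mode turns typos into hard errors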
Example #10
def get_conf(cwd, conf_file):
    OmegaConf.register_resolver('now', lambda fmt: strftime(fmt, localtime()))
    config = OmegaConf.load(str(conf_file))
    conf_cli = OmegaConf.from_cli()
    for entry in config.defaults:
        assert len(entry) == 1
        for k, v in entry.items():
            if k in conf_cli:
                v = conf_cli[k]
            entry_path = cwd.parents[0] / 'config' / k / f'{v}.yaml'
            entry_conf = OmegaConf.load(str(entry_path))
            config = OmegaConf.merge(config, entry_conf)
    config = OmegaConf.merge(config, conf_cli)
    return config
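The 'now' resolver registered above is evaluated lazily, when an interpolated value is accessed. A minimal sketch (OmegaConf 2.1+ spells the registration register_new_resolver):

from time import localtime, strftime
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("now", lambda fmt: strftime(fmt, localtime()))
cfg = OmegaConf.create({"run_dir": "outputs/${now:%Y-%m-%d}"})
print(cfg.run_dir)  # resolved on access, e.g. outputs/2024-05-01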
Example #11
def main():
    cli_conf = OmegaConf.from_cli()
    base_conf = OmegaConf.load("programmer.yaml")
    conf = OmegaConf.merge(base_conf, cli_conf)
    print("DIGIT programmer config:")
    print(conf.pretty())

    full_serial = format_serial(str(conf.digit.serial))
    set_serial(conf.digit.firmware, full_serial)

    _log.info("Flashing DIGIT firmware...")
    input("Unplug and plug DIGIT into usb then press ENTER...")
    program_digit()
    _log.info("Finished flashing firmware to DIGIT!")
Example #12
def optimize(trial):
    # Get default configuration
    YAML_CONFIG = OmegaConf.load("lstm.yaml")
    CLI_CONFIG = OmegaConf.from_cli()
    DEFAULT_CONFIG = OmegaConf.merge(YAML_CONFIG, CLI_CONFIG)

    # Get hyperparameters from Optuna
    OPTUNA_CONFIG = OmegaConf.create({
        "LR":
        trial.suggest_loguniform("LR", 1e-5, 1e-2),
    })
    CONFIG = OmegaConf.merge(DEFAULT_CONFIG, OPTUNA_CONFIG)

    # Run and return validation loss
    return lstm_main(CONFIG)
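An Optuna study would drive optimize like so (a sketch, assuming lstm_main returns the validation loss to be minimized):

import optuna

study = optuna.create_study(direction="minimize")
study.optimize(optimize, n_trials=20)
print(study.best_params)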
Example #13
def main() -> None:
    base_config = OmegaConf.structured(ExperimentConfig)
    cli_config = OmegaConf.from_cli()
    config = OmegaConf.merge(base_config, cli_config)

    # Create log directory
    time_str = time.strftime('%Y-%m-%dT%H-%M-%S')
    logdir = f'{config.logdir}/{time_str}'
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    # Save additional and overall parameters
    OmegaConf.save(config, f'{logdir}/config.yaml')
    OmegaConf.save(cli_config, f'{logdir}/override.yaml')

    run(config)
Example #14
def load_conf():
    """Quick method to load configuration (using OmegaConf). By default,
    configuration is loaded from the default config file (config.yaml).
    Another config file can be specific through command line.
    Also, configuration can be over-written by command line.

    Returns:
        OmegaConf.DictConfig: OmegaConf object representing the configuration.
    """
    default_conf = omg.create({"config": "config.yaml"})

    # strip surrounding dashes so "--key=value" parses as "key=value"
    sys.argv = [a.strip("-") for a in sys.argv]
    cli_conf = omg.from_cli()

    yaml_file = omg.merge(default_conf, cli_conf).config

    yaml_conf = omg.load(yaml_file)

    return omg.merge(default_conf, yaml_conf, cli_conf)
Example #15
def main():
    print('Initializing Training Process..')
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser(
        usage='\n' + '-' * 10 + ' Default config ' + '-' * 10 + '\n' +
        str(OmegaConf.to_yaml(OmegaConf.structured(TrainConfig))))
    parser.parse_known_args()  # invoked only so that --help prints the default config
    override_cfg = OmegaConf.from_cli()
    base_cfg = OmegaConf.structured(TrainConfig)
    cfg: TrainConfig = OmegaConf.merge(base_cfg, override_cfg)
    logging.info(f"Running with config:\n {OmegaConf.to_yaml(cfg)}")

    torch.backends.cudnn.benchmark = True
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(cfg.seed)
        if cfg.distributed.n_gpus_per_node > torch.cuda.device_count():
            raise AssertionError((
                f" Specified n_gpus_per_node ({cfg.distributed.n_gpus_per_node})"
                f" must be less than or equal to cuda device count ({torch.cuda.device_count()}) "
            ))
        with open_dict(cfg):
            cfg.batch_size_per_gpu = int(cfg.batch_size /
                                         cfg.distributed.n_gpus_per_node)
        if cfg.batch_size % cfg.distributed.n_gpus_per_node != 0:
            logging.warning(
                ("Batch size does not evenly divide among the GPUs in a node. "
                 "Unbalanced loads are likely to occur."))
        logging.info(f'Batch size per GPU : {cfg.batch_size_per_gpu}')

    if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1:
        mp.spawn(train, nprocs=cfg.distributed.n_gpus_per_node, args=(cfg, ))
    else:
        train(0, cfg)
Example #16
def run():
    app = QtWidgets.QApplication(sys.argv)
    app = set_style(app)

    default_path = os.path.join(os.path.dirname(__file__), 'default_config.yaml')

    default = OmegaConf.load(default_path)

    cli = OmegaConf.from_cli()
    if cli.user_cfg is not None:
        assert os.path.isfile(cli.user_cfg)
        user_cfg = OmegaConf.load(cli.user_cfg)
        cfg = OmegaConf.merge(default, user_cfg, cli)
    else:
        cfg = OmegaConf.merge(default, cli)

    OmegaConf.set_struct(cfg, True)

    window = MainWindow(cfg)
    window.resize(1024, 768)
    window.show()

    sys.exit(app.exec_())
Example #17
def main():
    args_base = OmegaConf.create(default_config.default_config)

    args_cli = OmegaConf.from_cli()

    # if args_cli.config_file is not None:
    #     args_cfg = OmegaConf.load(args_cli.config_file)
    #     args_base = OmegaConf.merge(args_base, args_cfg)
    #     args_base.exp_name = os.path.splitext(os.path.basename(args_cli.config_file))[0]
    # elif args_cli.exp_name is None:
    #     raise ValueError('exp_name cannot be empty without specifying a config file')
    # del args_cli['config_file']
    args = OmegaConf.merge(args_base, args_cli)

    model = Model_pl.model.load_from_checkpoint(args.model_path)
    if args.recon_model_path:
        recon_model = Model_recon.model.load_from_checkpoint(
            args.recon_model_path)
        recon_model.is_training = False
    else:
        recon_model = None

    utils.save_depp_dist(model, args, recon_model=recon_model)
Example #18
def main():
    parser = ArgumentParser('Test Step 1')
    parser.add_argument(
        '--config',
        required=False,
        help='the path of the yaml config file',
        default=None,
    )

    parser.add_argument(
        '--output',
        required=False,
        help='the path of the yaml config file',
        default=None,
    )

    args, unknown = parser.parse_known_args()

    configs = list(
        filter(
            lambda x: x is not None,
            [
                OmegaConf.structured(Config),
                OmegaConf.from_cli(unknown),
                None if args.config is None else OmegaConf.load(args.config),
            ],
        ), )

    cfg = OmegaConf.merge(*configs)
    print(OmegaConf.to_yaml(cfg))

    if args.output is not None:
        out: Output = OmegaConf.structured(Output)
        out.measurement = 1000.0
        out.user_input = "User Input"
        with open(args.output, 'w+') as ofd:
            ofd.write(OmegaConf.to_yaml(out))
Example #19
def main():
    args_base = OmegaConf.create(default_config.default_config)

    args_cli = OmegaConf.from_cli()

    args = OmegaConf.merge(args_base, args_cli)
    original_distance = pd.read_csv(os.path.join(args.outdir, "depp.csv"),
                                    sep='\t')
    a_for_seq_name = pd.read_csv(os.path.join(args.outdir, "depp.csv"),
                                 sep='\t',
                                 dtype=str)
    s = list(original_distance.keys())[1:]
    tree = treeswift.read_tree(args.backbone_tree, 'newick')
    true_max = tree.diameter()
    # print(true_max)
    data = {}
    s_set = set(s)
    for i in range(len(original_distance)):
        line = list(a_for_seq_name.iloc[i])
        seq_name = line[0]
        with open(f"{args.outdir}/depp_tmp/{seq_name}_leaves.txt", "r") as f:
            method = set(f.read().split("\n"))
            method.remove('')
            method = method.intersection(s_set)
        if method:
            query_median = np.median(
                original_distance[np.array(method)].iloc[i])
            ratio = true_max / (query_median + 1e-7)
            # print(ratio)
            b = original_distance.iloc[i].values[1:] * ratio
        else:
            b = original_distance.iloc[i].values[1:]
        seq_dict = dict(zip(s, b))
        data[seq_name] = seq_dict
    data = pd.DataFrame.from_dict(data, orient='index', columns=s)
    data.to_csv(os.path.join(args.outdir, f'depp_correction.csv'), sep='\t')
Example #20
def test_create_from_cli():
    sys.argv = ['program.py', 'a=1', 'b.c=2']
    c = OmegaConf.from_cli()
    assert {'a': 1, 'b': {'c': 2}} == c
Example #21
def main():
    args_base = OmegaConf.create(default_config.default_config)

    args_cli = OmegaConf.from_cli()

    # if args_cli.config_file is not None:
    #     args_cfg = OmegaConf.load(args_cli.config_file)
    #     args_base = OmegaConf.merge(args_base, args_cfg)
    #     args_base.exp_name = os.path.splitext(os.path.basename(args_cli.config_file))[0]
    # elif args_cli.exp_name is None:
    #     raise ValueError('exp_name cannot be empty without specifying a config file')
    # del args_cli['config_file']
    args = OmegaConf.merge(args_base, args_cli)

    model_dir = args.model_dir
    os.makedirs(model_dir, exist_ok=True)

    model = Model_pl.model(args=args)

    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=args.patience,
        verbose=False,
        mode='min'
    )

    checkpoint_callback = ModelCheckpoint(
        filepath=model_dir,
        save_top_k=1,
        verbose=True,
        monitor='val_loss',
        mode='min',
        prefix=''
    )
    print(model_dir)
    if args.gpus == 0:
        trainer = pl.Trainer(
            logger=False,
            gpus=args.gpus,
            progress_bar_refresh_rate=args.bar_update_freq,
            check_val_every_n_epoch=args.val_freq,
            max_epochs=args.epoch,
            gradient_clip_val=args.cp,
            benchmark=True,
            callbacks=[early_stop_callback],
            checkpoint_callback=checkpoint_callback,
            # reload_dataloaders_every_epoch=True
        )
    else:
        trainer = pl.Trainer(
            logger=False,
            gpus=args.gpus,
            progress_bar_refresh_rate=args.bar_update_freq,
            distributed_backend='ddp',
            check_val_every_n_epoch=args.val_freq,
            max_epochs=args.epoch,
            gradient_clip_val=args.cp,
            benchmark=True,
            callbacks=[early_stop_callback],
            checkpoint_callback=checkpoint_callback,
            # reload_dataloaders_every_epoch=True
        )

    trainer.fit(model)
Example #22
def main(argv):
    # automatically parses arguments of export_scenes
    cfg = OmegaConf.create(get_default_args(export_scenes))
    # merge CLI overrides into the parsed defaults (later values win)
    cfg = OmegaConf.merge(cfg, OmegaConf.from_cli())
    with torch.no_grad():
        export_scenes(**cfg)
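get_default_args comes from the surrounding codebase; a plausible sketch, offered as a hypothetical reimplementation, reads the keyword defaults off the function signature:

import inspect

def get_default_args(func):
    # map each keyword parameter to its declared default value
    return {
        name: p.default
        for name, p in inspect.signature(func).parameters.items()
        if p.default is not inspect.Parameter.empty
    }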
Example #23
def main(param):
    train, test, submit = utils.read_data("./data")
    utils.info('read wave data...')

    test_wave = utils.read_wave("./data/ecg/" + test["Id"] + ".npy")

    train["sex"] = train["sex"].replace({"male": 0, "female": 1})
    test["sex"] = test["sex"].replace({"male": 0, "female": 1})

    test_preds = np.zeros(
        [
            N_FOLD,
            test_wave.shape[0]
        ]
    )

    for fold in range(N_FOLD):
        utils.info('predict', fold)

        model = MODEL_NAMES_DICT[param.model_name](param)

        test_preds[fold] = model.predict([test_wave, test[["sex", "age"]]], fold)

    submit["target"] = test_preds.mean(axis=0)
    submit.to_csv("./logs/{}/submission.csv".format(param.model_name), index=False)


if __name__ == '__main__':
    param = OmegaConf.from_cli()
    utils.info('params:', param)
    main(param)
Example #24
def test_create_from_cli() -> None:
    sys.argv = ["program.py", "a=1", "b.c=2"]
    c = OmegaConf.from_cli()
    assert {"a": 1, "b": {"c": 2}} == c
Example #25
def test_cli_passing() -> None:
    args_list = ["a=1", "b.c=2"]
    c = OmegaConf.from_cli(args_list)
    assert {"a": 1, "b": {"c": 2}} == c
Example #26
    wandb.watch(net)

    # Log number of parameters
    CONFIG.NUM_PARAMETERS = count_parameters(net)
    trainer = pl.Trainer(
        # TODO: Add CONFIG parameters for devices
        gpus=1,
        # Don't show progress bar
        progress_bar_refresh_rate=0,
        check_val_every_n_epoch=1,
        # TODO: Try early stopping
        max_epochs=CONFIG.NUM_EPOCH,
    )
    trainer.fit(net)

    # Close wandb
    wandb.join()

    # Used in `tune_hyperparameters.py`
    # TODO: How to get numbers?
    # return test_mse_loss


if __name__ == "__main__":
    # Load Configuration
    YAML_CONFIG = OmegaConf.load("lstm.yaml")
    CLI_CONFIG = OmegaConf.from_cli()
    CONFIG = OmegaConf.merge(YAML_CONFIG, CLI_CONFIG)

    main(CONFIG)
Example #27
    DEFAULT_DIR = "~/ray_results/coord_game"
    tune.run(
        SIMSiNFSPTrainer,
        config=train_config,
        local_dir=DEFAULT_DIR,
        stop={"timesteps_total": 3e6},
        checkpoint_at_end=True,
        num_samples=cl_args.num_samples,
        loggers=DEFAULT_LOGGERS + (WandbLogger, ),
    )

    ray.shutdown()


if __name__ == "__main__":
    cl_args = OmegaConf.from_cli()
    input_args = OmegaConf.create({**DEFAULT_CONFIG, **dict(cl_args)})

    dict_args = {
        'env_config': {
            'horizon': 2,
            'relaxed': True,
            'actions_payoff': [(0, 100.), (1, 50.)],
        },
        'discrete_env_config': {
            'horizon': 2,
            'actions_payoff': [(0, 100.), (1, 50.)],
            'n_signals': 2
        },
        'train_batch_size': 128,
    }
Example #28
def test_cli_passing():
    args_list = ['a=1', 'b.c=2']
    c = OmegaConf.from_cli(args_list)
    assert {'a': 1, 'b': {'c': 2}} == c
Example #29
def main():
    # Load Configuration
    YAML_CONFIG = OmegaConf.load("fc.yaml")
    CLI_CONFIG = OmegaConf.from_cli()
    CONFIG = OmegaConf.merge(YAML_CONFIG, CLI_CONFIG)

    # Reproducibility
    random.seed(CONFIG.SEED)
    np.random.seed(CONFIG.SEED)
    torch.manual_seed(CONFIG.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup CPU or GPU
    if CONFIG.USE_GPU and not torch.cuda.is_available():
        raise ValueError("GPU not detected but CONFIG.USE_GPU is set to True.")
    device = torch.device("cuda" if CONFIG.USE_GPU else "cpu")

    # Setup dataset and dataloader
    # NOTE(seungjaeryanlee): Load saved dataset for speed
    # dataset = get_dataset()
    dataset = load_dataset()
    train_size = int(0.6 * len(dataset))
    valid_size = int(0.2 * len(dataset))
    test_size = len(dataset) - train_size - valid_size
    train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, valid_size, test_size])
    kwargs = {'num_workers': 1, 'pin_memory': True} if CONFIG.USE_GPU else {}
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=CONFIG.BATCH_SIZE,
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=CONFIG.BATCH_SIZE,
                                               shuffle=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=CONFIG.BATCH_SIZE,
                                              shuffle=False,
                                              **kwargs)

    # Setup neural network and optimizer
    net = Net().double().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=CONFIG.LR)
    # Log number of parameters
    CONFIG.NUM_PARAMETERS = count_parameters(net)

    # Setup wandb
    wandb.init(project="MagNet", config=CONFIG)
    wandb.watch(net)

    # Training
    for epoch_i in range(1, CONFIG.NUM_EPOCH + 1):
        # Train for one epoch
        epoch_train_loss = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = net(inputs.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()

        # Compute Validation Loss
        with torch.no_grad():
            epoch_valid_loss = 0
            for inputs, labels in valid_loader:
                outputs = net(inputs.to(device))
                loss = criterion(outputs, labels.to(device))

                epoch_valid_loss += loss.item()

        print(f"Epoch {epoch_i:2d} "
              f"Train {epoch_train_loss / len(train_dataset):.5f} "
              f"Valid {epoch_valid_loss / len(valid_dataset):.5f}")
        wandb.log({
            "train/loss": epoch_train_loss / len(train_dataset),
            "valid/loss": epoch_valid_loss / len(valid_dataset),
        })

    # Evaluation
    net.eval()
    y_meas = []
    y_pred = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            y_pred.append(net(inputs.to(device)))
            y_meas.append(labels.to(device))

    y_meas = torch.cat(y_meas, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    print(
        f"Test Loss: {F.mse_loss(y_meas, y_pred).item() / len(test_dataset):.8f}"
    )
    wandb.log(
        {"test/loss": F.mse_loss(y_meas, y_pred).item() / len(test_dataset)})

    # Analysis
    wandb.log({
        "test/prediction_vs_target":
        wandb.Image(analysis.get_scatter_plot(y_pred, y_meas)),
        "test/prediction_vs_target_histogram":
        wandb.Image(analysis.get_two_histograms(y_pred, y_meas)),
        "test/error_histogram":
        wandb.Image(analysis.get_error_histogram(y_pred, y_meas)),
    })
Example #30
    N_MELS: int = 64
    MEL_FMIN: int = 50
    MEL_FMAX: int = 14000
    # frontend
    TOP_K: int = 6
    TITLE_FONTSIZE: int = 28
    TABLE_FONTSIZE: int = 22


# ##############################################################################
# # MAIN ROUTINE
# ##############################################################################
if __name__ == '__main__':

    CONF = OmegaConf.structured(ConfDef())
    cli_conf = OmegaConf.from_cli()
    CONF = OmegaConf.merge(CONF, cli_conf)
    print("\n\nCONFIGURATION:")
    print(OmegaConf.to_yaml(CONF), end="\n\n\n")

    _, _, all_labels = load_csv_labels(CONF.ALL_LABELS_PATH)
    if CONF.SUBSET_LABELS_PATH is None:
        subset_labels = None
    else:
        _, _, subset_labels = load_csv_labels(CONF.SUBSET_LABELS_PATH)
    logo_paths = [SURREY_LOGO_PATH, CVSSP_LOGO_PATH, EPSRC_LOGO_PATH]

    demo = DemoApp(AI4S_BANNER_PATH, logo_paths, CONF.MODEL_PATH, all_labels,
                   subset_labels, CONF.SAMPLERATE, CONF.AUDIO_CHUNK_LENGTH,
                   CONF.RINGBUFFER_LENGTH, CONF.MODEL_WINSIZE,
                   CONF.STFT_HOPSIZE, CONF.STFT_WINDOW, CONF.N_MELS,