コード例 #1
0
def get_storage_dir():
    """Return, as a string, the folder picked by ``get_new_folder``.

    The path template is ``$STORAGE/pth_evaluate/<nickname>``. It is built
    inside this function (instead of inside the config) because Sacred should
    not add the template itself to the config.
    """
    storage_root = Path(os.environ["STORAGE"])
    evaluation_template = storage_root / "pth_evaluate" / nickname
    new_folder = get_new_folder(
        evaluation_template,
        mkdir=False,
        consider_mpi=True,
    )
    return str(new_folder)
コード例 #2
0
ファイル: train.py プロジェクト: entn-at/padertorch
def config():
    """Set up the training configuration and attach a file-storage observer.

    NOTE(review): this looks like a Sacred config scope (the comments mention
    Sacred, but the decorator is not visible here) -- confirm. If so, the
    local names below are the config keys and must not be renamed.
    """
    debug = False
    batch_size = 6

    train_dataset = "mix_2_spk_min_tr"
    validate_dataset = "mix_2_spk_min_cv"

    # Trainer settings, kept in a plain dict to allow tracking by Sacred.
    trainer = {
        "model": {
            "factory": pt.models.bss.PermutationInvariantTrainingModel,
            "dropout_input": 0.,
            "dropout_hidden": 0.,
            "dropout_linear": 0.
        },
        # None means: choose a new folder further below.
        "storage_dir": None,
        "optimizer": {
            "factory": pt.optimizer.Adam,
            "gradient_clipping": 1
        },
        "summary_trigger": (1000, "iteration"),
        "stop_trigger": (300_000, "iteration"),
        "loss_weights": {
            "pit_ips_loss": 1.0,
            "pit_mse_loss": 0.0,
        }
    }
    # NOTE(review): presumably completes `trainer` in place with default
    # values (the return value is not used here) -- confirm.
    pt.Trainer.get_config(trainer)
    if trainer['storage_dir'] is None:
        # `path_template` is not defined in this function; it has to come
        # from an outer scope.
        trainer['storage_dir'] = get_new_folder(path_template, mkdir=False)

    # Register an observer created for the 'sacred' subfolder of the
    # storage dir.
    ex.observers.append(
        FileStorageObserver.create(Path(trainer['storage_dir']) / 'sacred'))
コード例 #3
0
def config():
    """Set up the evaluation configuration.

    NOTE(review): this looks like a Sacred config scope (decorator not
    visible here) -- confirm. If so, the local names below are the config
    keys and must not be renamed.
    """
    debug = False
    # Model path; the empty default is rejected by the assert below.
    model_path = ''
    # NOTE(review): read as plain Python this assert would always fail; it
    # presumably relies on a command-line override of `model_path` taking
    # precedence over the '' default -- confirm.
    assert len(model_path) > 0, 'Set the model path on the command line.'
    checkpoint_name = 'ckpt_best_loss.pth'
    # Folder returned by get_new_folder, stored as a string. `path_template`
    # is not defined in this function; it has to come from an outer scope.
    experiment_dir = str(
        get_new_folder(
            path_template,
            mkdir=False,
            consider_mpi=True,
        ))
    batch_size = 1
    datasets = ["mix_2_spk_min_cv", "mix_2_spk_min_tt"]
    locals()  # Fix highlighting (the result is discarded)
コード例 #4
0
def config():
    """Set up the evaluation configuration, including the database JSON path.

    NOTE(review): this looks like a Sacred config scope (decorator not
    visible here) -- confirm. If so, the local names below are the config
    keys and must not be renamed.
    """
    debug = False
    # Path to the database JSON; defaults to $WSJ0_2MIX when that is set.
    database_json = ''
    if "WSJ0_2MIX" in os.environ:
        database_json = os.environ.get("WSJ0_2MIX")
    # Fail early when neither the environment variable nor another source
    # provided a non-empty path.
    assert len(
        database_json
    ) > 0, 'Set path to database Json on the command line or set environment variable WSJ0_2MIX'
    # Model path; the empty default is rejected by the assert below.
    model_path = ''
    # NOTE(review): read as plain Python this assert would always fail; it
    # presumably relies on a command-line override of `model_path` taking
    # precedence over the '' default -- confirm.
    assert len(model_path) > 0, 'Set the model path on the command line.'
    checkpoint_name = 'ckpt_best_loss.pth'
    # Folder returned by get_new_folder, stored as a string. `path_template`
    # is not defined in this function; it has to come from an outer scope.
    experiment_dir = str(
        get_new_folder(
            path_template,
            mkdir=False,
            consider_mpi=True,
        ))
    batch_size = 1
    datasets = ["mix_2_spk_min_cv", "mix_2_spk_min_tt"]
    locals()  # Fix highlighting (the result is discarded)
コード例 #5
0
def config():
    """Set up the training configuration and attach a file-storage observer.

    NOTE(review): this looks like a Sacred config scope (decorator not
    visible here) -- confirm. If so, the local names below are the config
    keys and must not be renamed.
    """
    debug = False
    batch_size = 6
    database_json = ""  # Path to WSJ0_2mix .json
    if "WSJ0_2MIX" in os.environ:
        database_json = os.environ.get("WSJ0_2MIX")
    # Fail early when neither the environment variable nor another source
    # provided a non-empty path.
    assert len(database_json) > 0, 'Set path to database Json on the command line or set environment variable WSJ0_2MIX'
    train_dataset = "mix_2_spk_min_tr"
    validate_dataset = "mix_2_spk_min_cv"

    # Dict describing the model parameters, so that they can be changed from
    # the command line. Configurable automatically inserts default values for
    # parameters that are not mentioned into the config.json.
    trainer = {
        "model": {
            "factory": pt.contrib.examples.pit.model.PermutationInvariantTrainingModel,
            "dropout_input": 0.,
            "dropout_hidden": 0.,
            "dropout_linear": 0.
        },
        # None means: choose a new folder further below.
        "storage_dir": None,
        "optimizer": {
            "factory": pt.optimizer.Adam,
            "gradient_clipping": 1
        },
        "summary_trigger": (1000, "iteration"),
        "stop_trigger": (300_000, "iteration"),
        "loss_weights": {
            "pit_ips_loss": 1.0,
            "pit_mse_loss": 0.0,
        }
    }
    # The return value is not used here; per the note above, the defaults
    # are presumably filled into `trainer` in place -- confirm.
    pt.Trainer.get_config(trainer)
    if trainer['storage_dir'] is None:
        # `path_template` is not defined in this function; it has to come
        # from an outer scope.
        trainer['storage_dir'] = get_new_folder(path_template, mkdir=False)

    # Register an observer created for the 'sacred' subfolder of the
    # storage dir.
    ex.observers.append(FileStorageObserver.create(
        Path(trainer['storage_dir']) / 'sacred')
    )
コード例 #6
0
def get_storage_dir():
    """Return the folder picked by ``get_new_folder`` for this run.

    The path template is ``$STORAGE/pth_models/<nickname>``. It is built
    inside this function (instead of inside the config) because Sacred should
    not add the template itself to the config.
    """
    models_root = Path(os.environ["STORAGE"]) / 'pth_models' / nickname
    # Create the template directory (and any missing parents) first.
    models_root.mkdir(parents=True, exist_ok=True)
    new_folder = get_new_folder(models_root, mkdir=False)
    return new_folder