示例#1
0
def config():
    """Sacred config scope for evaluating a trained model.

    Presumably registered with ``@ex.config`` outside this view: every local
    name assigned here becomes a config entry, and values given on the
    command line take precedence over the defaults below.

    Raises:
        MissingConfigError: if ``model_path`` or ``database_json`` is not set.
        InvalidConfigError: if ``database_json`` does not point to an
            existing path.
    """
    debug = False
    model_path = ''  # Path to the trained model; must be set on the command line
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O`` and ``len()`` would fail on a ``None`` value.
    if not model_path:
        raise MissingConfigError(
            'Set the model path on the command line.', 'model_path')
    checkpoint_name = 'ckpt_best_loss.pth'
    experiment_dir = None  # Defaults to a new directory from get_storage_dir()
    datasets = ["mix_2_spk_min_cv", "mix_2_spk_min_tt"]
    export_audio = False
    sample_rate = 8000
    target = 'speech_source'
    database_json = None
    oracle_num_spk = False  # If true, the model is forced to perform the correct (oracle) number of iterations
    max_iterations = 4  # The number of iterations is limited to this number

    if database_json is None:
        raise MissingConfigError(
            'You have to set the path to the database JSON!', 'database_json')
    if not Path(database_json).exists():
        raise InvalidConfigError('The database JSON does not exist!',
                                 'database_json')

    # Resolve the storage directory only when none was given and the config
    # passed validation; get_storage_dir() presumably creates a new directory
    # on disk (TODO confirm), which would otherwise be wasted.
    if experiment_dir is None:
        experiment_dir = get_storage_dir()

    locals()  # Fix highlighting

    ex.observers.append(
        FileStorageObserver(Path(experiment_dir) / 'sacred'))
示例#2
0
def config():
    """Sacred config scope for evaluating a trained model.

    Presumably registered with ``@ex.config`` outside this view: every local
    name assigned here becomes a config entry, and command-line values take
    precedence over the defaults below. ``database_json`` falls back to
    ``$NT_DATABASE_JSONS_DIR/wsj0_2mix_8k.json`` when that environment
    variable is set.

    Raises:
        MissingConfigError: if no database JSON path could be determined.
        InvalidConfigError: if the database JSON does not exist.
    """
    debug = False

    database_json = None
    if database_json is None and 'NT_DATABASE_JSONS_DIR' in os.environ:
        database_json = Path(
            os.environ['NT_DATABASE_JSONS_DIR']) / 'wsj0_2mix_8k.json'

    # Validate before creating the evaluation directory or registering an
    # observer, so an invalid config leaves no side effects behind. A
    # ``len()``-based check cannot be used: database_json may be None or a Path.
    if database_json is None:
        raise MissingConfigError(
            'You have to set the path to the database JSON! Set it on the '
            'command line or set the environment variable '
            '"NT_DATABASE_JSONS_DIR".', 'database_json')
    if not Path(database_json).exists():
        raise InvalidConfigError('The database JSON does not exist!',
                                 'database_json')

    model_path = ''
    checkpoint_name = 'ckpt_best_loss.pth'
    experiment_dir = None  # Defaults to a new subdir of <model_path>/evaluation
    if experiment_dir is None:
        experiment_dir = pt.io.get_new_subdir(Path(model_path) / 'evaluation',
                                              consider_mpi=True)
    batch_size = 1
    datasets = ["mix_2_spk_min_cv", "mix_2_spk_min_tt"]
    locals()  # Fix highlighting

    ex.observers.append(
        FileStorageObserver(Path(experiment_dir) / 'sacred'))
示例#3
0
def config():
    """Sacred config scope for evaluating a trained model.

    Presumably registered with ``@ex.config`` outside this view: every local
    name assigned here becomes a config entry, and command-line values take
    precedence over the defaults below.

    Raises:
        MissingConfigError: if ``model_path`` or ``database_json`` is not set.
        InvalidConfigError: if the database JSON does not exist.
    """
    debug = False
    dump_audio = False

    # Model config
    model_path = ''  # Path to the trained model; must be set on the command line
    # Explicit raise instead of ``assert`` (asserts are stripped under -O).
    if not model_path:
        raise MissingConfigError(
            'Set the model path on the command line.', 'model_path')
    checkpoint_name = 'ckpt_best_loss.pth'
    experiment_dir = None  # Defaults to a new subdir of <model_path>/evaluation

    # Database config
    database_json = None
    if database_json is None and JSON_BASE:
        database_json = Path(JSON_BASE) / 'wsj0_2mix_8k.json'

    sample_rate = 8000
    datasets = ['mix_2_spk_min_cv', 'mix_2_spk_min_tt']
    target = 'speech_source'

    if database_json is None:
        raise MissingConfigError(
            'You have to set the path to the database JSON!', 'database_json')
    if not Path(database_json).exists():
        raise InvalidConfigError('The database JSON does not exist!',
                                 'database_json')

    # Deferred until validation passed: get_new_subdir presumably creates the
    # directory (TODO confirm), which a failing run would otherwise leave behind.
    if experiment_dir is None:
        experiment_dir = pt.io.get_new_subdir(Path(model_path) / 'evaluation',
                                              consider_mpi=True)

    locals()  # Fix highlighting

    ex.observers.append(
        FileStorageObserver(Path(experiment_dir) / 'sacred'))
示例#4
0
 def _assert_no_missing_args(self, args, kwargs, bound):
     """Raise ``MissingConfigError`` if a free parameter has no value.

     The free parameters for this call are obtained from
     ``self.get_free_parameters(args, kwargs, bound)``. Every one of them
     that is not contained in ``self.kwargs`` (presumably the parameters
     that carry a value/default -- confirm) is reported as missing.

     Raises:
         MissingConfigError: listing the missing parameter names.
     """
     absent = [
         param
         for param in self.get_free_parameters(args, kwargs, bound)
         if param not in self.kwargs
     ]
     if not absent:
         return
     raise MissingConfigError(
         f'{self.name} is missing value(s):',
         missing_configs=absent,
     )
示例#5
0
def config():
    """Sacred config scope for training the TasNet separation example.

    Presumably registered with ``@ex.config`` outside this view, so every
    local name below becomes a config entry and command-line values take
    precedence over these defaults. The scope validates ``database_json``,
    hands the ``trainer`` dict to ``pt.Trainer.get_config``, allocates a
    storage directory via ``pt.io.get_new_storage_dir`` when none is given,
    and attaches a ``FileStorageObserver`` writing to ``<storage_dir>/sacred``.

    Raises:
        MissingConfigError: if no database JSON path is configured.
        InvalidConfigError: if the database JSON does not exist.
    """
    debug = False
    batch_size = 4  # Runs on 4GB GPU mem. Can safely be set to 12 on 12 GB (e.g., GTX1080)
    chunk_size = 32000  # 4s chunks @8kHz

    train_dataset = "mix_2_spk_min_tr"
    validate_dataset = "mix_2_spk_min_cv"
    target = 'speech_source'
    lr_scheduler_step = 2
    lr_scheduler_gamma = 0.98
    load_model_from = None
    database_json = None  # Default: <JSON_BASE>/wsj0_2mix_8k.json if JSON_BASE is set
    if database_json is None and JSON_BASE:
        database_json = Path(JSON_BASE) / 'wsj0_2mix_8k.json'

    # Abort before any storage directory is allocated if no usable
    # database JSON is configured.
    if database_json is None:
        raise MissingConfigError(
            'You have to set the path to the database JSON!', 'database_json')
    if not Path(database_json).exists():
        raise InvalidConfigError('The database JSON does not exist!',
                                 'database_json')

    feat_size = 64  # feature_size of the TasNet encoder and decoder below
    encoder_window_size = 16  # window_length of the TasNet encoder and decoder below
    trainer = {
        "model": {
            "factory": padertorch.contrib.examples.source_separation.tasnet.TasNet,
            'encoder': {
                'factory': padertorch.contrib.examples.source_separation.tasnet.tas_coders.TasEncoder,
                'window_length': encoder_window_size,
                'feature_size': feat_size,
            },
            'decoder': {
                'factory': padertorch.contrib.examples.source_separation.tasnet.tas_coders.TasDecoder,
                'window_length': encoder_window_size,
                'feature_size': feat_size,
            },
        },
        "storage_dir": None,
        "optimizer": {
            "factory": pt.optimizer.Adam,
            "gradient_clipping": 1
        },
        "summary_trigger": (1000, "iteration"),
        "stop_trigger": (100, "epoch"),
        # Per-loss weights; only "si-sdr" has a non-zero weight by default.
        "loss_weights": {
            "si-sdr": 1.0,
            "log-mse": 0.0,
            "log1p-mse": 0.0,
        }
    }
    # The return value is discarded; presumably get_config completes
    # `trainer` with default arguments in place (TODO confirm).
    pt.Trainer.get_config(trainer)
    # NOTE(review): experiment_name is not defined in this scope; presumably
    # a module-level name.
    if trainer['storage_dir'] is None:
        trainer['storage_dir'] = pt.io.get_new_storage_dir(experiment_name)

    ex.observers.append(FileStorageObserver(
        Path(trainer['storage_dir']) / 'sacred')
    )
示例#6
0
def config():
    """Sacred config scope for training the PIT separation example.

    Presumably registered with ``@ex.config`` outside this view, so every
    local name below becomes a config entry and command-line values take
    precedence over these defaults. The model factory is
    ``PermutationInvariantTrainingModel``. The scope validates
    ``database_json``, hands ``trainer`` to ``pt.Trainer.get_config``,
    allocates a storage directory via ``pt.io.get_new_storage_dir`` when none
    is given, and attaches a ``FileStorageObserver`` writing to
    ``<storage_dir>/sacred``.

    Raises:
        MissingConfigError: if no database JSON path is configured.
        InvalidConfigError: if the database JSON does not exist.
    """
    debug = False
    batch_size = 6
    database_json = None  # Path to the WSJ0-2mix database .json (default: <JSON_BASE>/wsj0_2mix_8k.json)
    if database_json is None and JSON_BASE:
        database_json = Path(JSON_BASE) / 'wsj0_2mix_8k.json'

    # Abort before any storage directory is allocated if no usable
    # database JSON is configured.
    if database_json is None:
        raise MissingConfigError(
            'You have to set the path to the database JSON!', 'database_json')
    if not Path(database_json).exists():
        raise InvalidConfigError('The database JSON does not exist!',
                                 'database_json')
    train_dataset = "mix_2_spk_min_tr"
    validate_dataset = "mix_2_spk_min_cv"

    # Dict describing the trainer and model parameters, so that they can be changed from the command line.
    # Configurable automatically inserts default values for parameters not mentioned here into the config.json
    trainer = {
        "model": {
            "factory": pt.contrib.examples.source_separation.pit.model.
            PermutationInvariantTrainingModel,
            "dropout_input": 0.,
            "dropout_hidden": 0.,
            "dropout_linear": 0.
        },
        "storage_dir": None,
        "optimizer": {
            "factory": pt.optimizer.Adam,
            "gradient_clipping": 1
        },
        "summary_trigger": (1000, "iteration"),
        "stop_trigger": (300_000, "iteration"),
        # Per-loss weights; only "pit_ips_loss" has a non-zero weight by default.
        "loss_weights": {
            "pit_ips_loss": 1.0,
            "pit_mse_loss": 0.0,
        }
    }
    # The return value is discarded; presumably get_config completes
    # `trainer` with default arguments in place (TODO confirm).
    pt.Trainer.get_config(trainer)
    # NOTE(review): experiment_name is not defined in this scope; presumably
    # a module-level name.
    if trainer['storage_dir'] is None:
        trainer['storage_dir'] = pt.io.get_new_storage_dir(experiment_name)

    ex.observers.append(
        FileStorageObserver(Path(trainer['storage_dir']) / 'sacred'))
示例#7
0
def config():
    """Sacred config scope for training the TasNet example.

    Presumably registered with ``@ex.config`` outside this view, so every
    local name below becomes a config entry and command-line values take
    precedence over these defaults. ``database_json`` has no fallback here
    and must be supplied. The scope hands ``trainer`` to
    ``pt.Trainer.get_config``, allocates a storage directory via
    ``get_storage_dir()`` when none is given, and attaches a
    ``FileStorageObserver`` writing to ``<storage_dir>/sacred``.

    Raises:
        MissingConfigError: if ``database_json`` is not set.
        InvalidConfigError: if the database JSON does not exist.
    """
    debug = False
    batch_size = 4  # Runs on 4GB GPU mem. Can safely be set to 12 on 12 GB (e.g., GTX1080)
    chunk_size = 32000  # 4s chunks @8kHz

    train_dataset = "mix_2_spk_min_tr"
    validate_dataset = "mix_2_spk_min_cv"
    target = 'speech_source'
    lr_scheduler_step = 2
    lr_scheduler_gamma = 0.98
    load_model_from = None
    database_json = None  # Path to the database JSON; required (no default)

    # Abort before any storage directory is allocated if no usable
    # database JSON is configured.
    if database_json is None:
        raise MissingConfigError(
            'You have to set the path to the database JSON!', 'database_json')
    if not Path(database_json).exists():
        raise InvalidConfigError('The database JSON does not exist!',
                                 'database_json')

    # Trainer description declared inside the config scope so Sacred tracks it
    trainer = {
        "model": {
            # NOTE(review): factory given as an import-path string; confirm it
            # matches the installed padertorch package layout.
            "factory": 'padertorch.contrib.examples.tasnet.tasnet.TasNet',
        },
        "storage_dir": None,
        "optimizer": {
            "factory": pt.optimizer.Adam,
            "gradient_clipping": 1
        },
        "summary_trigger": (1000, "iteration"),
        "stop_trigger": (100_000, "iteration"),
        # Per-loss weights; only "si-sdr" has a non-zero weight by default.
        "loss_weights": {
            "si-sdr": 1.0,
            "log-mse": 0.0,
            "si-sdr-grad-stop": 0.0,
        }
    }
    # The return value is discarded; presumably get_config completes
    # `trainer` with default arguments in place (TODO confirm).
    pt.Trainer.get_config(trainer)
    if trainer['storage_dir'] is None:
        trainer['storage_dir'] = get_storage_dir()

    ex.observers.append(
        FileStorageObserver(Path(trainer['storage_dir']) / 'sacred'))
示例#8
0
def config():
    """Sacred config scope for evaluating a trained model.

    Presumably registered with ``@ex.config`` outside this view: every local
    name assigned here becomes a config entry, and command-line values take
    precedence over the defaults below.

    Raises:
        MissingConfigError: if ``model_path`` or ``database_json`` is not set.
        InvalidConfigError: if the database JSON does not exist.
    """
    debug = False

    # Model config
    model_path = ''     # Required path to the model to evaluate
    # Explicit raise instead of ``assert`` (asserts are stripped under -O).
    if not model_path:
        raise MissingConfigError(
            'Set the model path on the command line.', 'model_path')
    checkpoint_name = 'ckpt_best_loss.pth'
    experiment_dir = None  # Defaults to a new subdir of <model_path>/evaluation

    # Data config
    database_json = None
    if database_json is None and JSON_BASE:
        database_json = Path(JSON_BASE) / 'wsj0_2mix_8k.json'
    datasets = ["mix_2_spk_min_cv", "mix_2_spk_min_tt"]
    target = 'speech_source'
    sample_rate = 8000

    if database_json is None:
        raise MissingConfigError(
            'You have to set the path to the database JSON!', 'database_json')
    if not Path(database_json).exists():
        raise InvalidConfigError('The database JSON does not exist!',
                                 'database_json')

    # Deferred until validation passed: get_new_subdir presumably creates the
    # directory (TODO confirm), which a failing run would otherwise leave behind.
    if experiment_dir is None:
        experiment_dir = pt.io.get_new_subdir(
            Path(model_path) / 'evaluation', consider_mpi=True)

    # Evaluation options
    dump_audio = False    # If true, exports the separated audio files into a sub-directory "audio"
    oracle_num_spk = False  # If true, the model is forced to perform the correct (oracle) number of iterations
    max_iterations = 4  # The number of iterations is limited to this number

    locals()  # Fix highlighting

    ex.observers.append(FileStorageObserver(
        Path(experiment_dir) / 'sacred'
    ))