Exemplo n.º 1
0
def main(_run, _log, trainer, database_json, training_sets, validation_sets,
         audio_reader, stft, max_length_in_sec, batch_size, resume):
    """Build the trainer and datasets from the run config and start training.

    The configuration is printed and stored as ``config.json`` inside the
    trainer's storage directory, which is created if it does not exist yet.
    Training and validation data are fetched from the JSON database and passed
    through ``prepare_dataset`` with the same settings; only the training
    data is shuffled. A test run is executed, the validation hook is
    registered, and training is started.

    Args:
        _run: Run object; its ``config`` attribute is what gets saved.
        _log: Logger handed to ``commands.save_config``.
        trainer: Trainer configuration consumed by ``Trainer.from_config``.
        database_json: Argument used to construct the ``JsonDatabase``.
        training_sets: Dataset selection passed to ``db.get_dataset`` for
            the training data.
        validation_sets: Dataset selection passed to ``db.get_dataset`` for
            the validation data.
        audio_reader: Forwarded unchanged to ``prepare_dataset``.
        stft: Forwarded unchanged to ``prepare_dataset``.
        max_length_in_sec: Forwarded unchanged to ``prepare_dataset``.
        batch_size: Forwarded unchanged to ``prepare_dataset``.
        resume: Forwarded to ``trainer.train`` as ``resume``.
    """
    commands.print_config(_run)

    # From here on ``trainer`` refers to the instantiated trainer, not its
    # configuration.
    trainer = Trainer.from_config(trainer)

    output_dir = Path(trainer.storage_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    config_path = str(output_dir / 'config.json')
    commands.save_config(_run.config, _log, config_filename=config_path)

    database = JsonDatabase(database_json)
    raw_training = database.get_dataset(training_sets)
    raw_validation = database.get_dataset(validation_sets)

    # Both splits share every preparation option except shuffling.
    shared_options = dict(
        audio_reader=audio_reader,
        stft=stft,
        max_length_in_sec=max_length_in_sec,
        batch_size=batch_size,
    )
    training_data = prepare_dataset(raw_training,
                                    **shared_options,
                                    shuffle=True)
    validation_data = prepare_dataset(raw_validation,
                                      **shared_options,
                                      shuffle=False)

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data)
    trainer.train(training_data, resume=resume)
Exemplo n.º 2
0
def main(_run, _log, trainer, database_json, training_set, validation_metric,
         maximize_metric, audio_reader, stft, num_workers, batch_size,
         max_padding_rate, resume):
    """Build the trainer, obtain augmented datasets and start training.

    The configuration is printed and stored as ``config.json`` inside the
    trainer's storage directory, which is created if it does not exist yet.
    Datasets come from ``get_datasets``; only the first two returned values
    (training and validation data) are used. After a test run, a validation
    hook is registered with the given metric and training is started.

    Args:
        _run: Run object; its ``config`` attribute is what gets saved.
        _log: Logger handed to ``commands.save_config``.
        trainer: Trainer configuration consumed by ``Trainer.from_config``.
        database_json: Forwarded to ``get_datasets``.
        training_set: Forwarded to ``get_datasets``.
        validation_metric: Passed to the validation hook as ``metric``.
        maximize_metric: Passed to the validation hook as ``maximize``.
        audio_reader: Forwarded to ``get_datasets``; its
            ``'target_sample_rate'`` entry is also used as
            ``stft_segment_length``.
        stft: Forwarded to ``get_datasets``.
        num_workers: Forwarded to ``get_datasets``.
        batch_size: Forwarded to ``get_datasets``.
        max_padding_rate: Forwarded to ``get_datasets``.
        resume: Forwarded to ``trainer.train`` as ``resume``.
    """
    commands.print_config(_run)

    # From here on ``trainer`` refers to the instantiated trainer, not its
    # configuration.
    trainer = Trainer.from_config(trainer)

    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    config_path = str(storage_dir / 'config.json')
    commands.save_config(_run.config, _log, config_filename=config_path)

    # Fixed augmentation settings for this script.
    stretch_factor_sampler = Uniform(low=0.5, high=1.5)
    segment_length = audio_reader['target_sample_rate']

    training_data, validation_data, _unused = get_datasets(
        database_json=database_json,
        min_signal_length=1.5,
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        training_set=training_set,
        storage_dir=storage_dir,
        stft_stretch_factor_sampling_fn=stretch_factor_sampler,
        stft_segment_length=segment_length,
        stft_segment_shuffle_prob=0.,
        mixup_probs=(1 / 2, 1 / 2),
        max_mixup_length=15.,
        min_mixup_overlap=.8,
    )

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(
        validation_data,
        metric=validation_metric,
        maximize=maximize_metric,
    )
    trainer.train(training_data, resume=resume)
Exemplo n.º 3
0
def main(_run, seed, save_folder, config_filename):
    """Seed the run, log its id, print and save the config, then train.

    Args:
        _run: Run object whose ``_id`` is written to the log.
        seed: Value handed to ``set_global_seeds``.
        save_folder: Prefix of the path the configuration is saved to.
        config_filename: Suffix appended directly to ``save_folder``.
    """
    set_global_seeds(seed)

    run_id_message = "Run id: {}".format(_run._id)
    logger.info(run_id_message)

    print_config(ex.current_run)

    # The target path is a plain string concatenation, so ``save_folder``
    # must already end with a path separator if it names a directory.
    config_path = save_folder + config_filename
    save_config(ex.current_run.config, ex.logger, config_filename=config_path)

    train()
Exemplo n.º 4
0
def main(_run, seed, save_folder, config_filename):
    """Entry point: set seeds, report the run id, store the config, train.

    Args:
        _run: Run object; only its ``_id`` is read here.
        seed: Seed passed on to ``set_global_seeds``.
        save_folder: Leading part of the config output path.
        config_filename: Trailing part of the config output path; it is
            concatenated to ``save_folder`` without inserting a separator.
    """
    set_global_seeds(seed)
    logger.info("Run id: {}".format(_run._id))

    print_config(ex.current_run)

    # Store the configuration of the current run next to the other outputs.
    output_path = save_folder + config_filename
    save_config(
        ex.current_run.config,
        ex.logger,
        config_filename=output_path,
    )

    train()
Exemplo n.º 5
0
def main(_run, _log, trainer, database_json, dataset, batch_size):
    """Build the trainer, load the datasets and train with early stopping.

    The configuration is printed and stored as ``config.json`` inside the
    trainer's storage directory, which is created if it does not exist yet.
    Only the first two values returned by ``get_datasets`` (training and
    validation data) are used.

    Args:
        _run: Run object; its ``config`` attribute is what gets saved.
        _log: Logger handed to ``commands.save_config``.
        trainer: Trainer configuration consumed by ``Trainer.from_config``.
        database_json: Forwarded to ``get_datasets``.
        dataset: Forwarded to ``get_datasets``.
        batch_size: Forwarded to ``get_datasets``.
    """
    commands.print_config(_run)

    # From here on ``trainer`` refers to the instantiated trainer, not its
    # configuration.
    trainer = Trainer.from_config(trainer)

    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    config_path = str(storage_dir / 'config.json')
    commands.save_config(_run.config, _log, config_filename=config_path)

    training_data, validation_data, _unused = get_datasets(
        storage_dir, database_json, dataset, batch_size)

    # Stop early once the loss has not decreased for three consecutive
    # validation runs. According to the original author this typically
    # happens around 20k iterations (13 epochs) with an accuracy >98% on the
    # test set.
    trainer.register_validation_hook(validation_data,
                                     early_stopping_patience=3)
    trainer.test_run(training_data, validation_data)
    trainer.train(training_data)
Exemplo n.º 6
0
def set_up_loging(exp_path, _config, _run, loglevel='INFO'):
    """Snapshot the source files and attach file logging for an experiment.

    Three things are written below ``exp_path``:

    * copies of every ``./*.py`` and ``./*/*.py`` file, under ``scources/``
      with their relative layout kept,
    * ``log.txt``, which receives the records of ``_run.run_logger``,
    * ``config.json``, written by ``save_config``.

    Finally the formatted configuration is logged at INFO level.

    Args:
        exp_path: Directory that receives the snapshot, log and config.
        _config: Not used in this function body.
        _run: Run object providing ``run_logger``, ``config`` and
            ``config_modifications``.
        loglevel: Level applied to ``_run.run_logger``.
    """
    # NOTE(review): 'scources' looks like a typo for 'sources'; it is kept
    # because it is the directory name actually written to disk.
    snapshot_root = os.path.join(exp_path, 'scources')
    log_file = os.path.join(exp_path, 'log.txt')
    config_file = os.path.join(exp_path, 'config.json')

    # Collect all candidates before copying anything, so files written by
    # this loop can never show up in the list being iterated.
    source_files = glob.glob('./*.py')
    source_files += glob.glob('./*/*.py')
    for source_file in source_files:
        # Drop the leading './' to obtain the path relative to the cwd.
        relative_name = source_file[2:]
        target = os.path.join(snapshot_root, relative_name)
        # ``mkdir`` is defined elsewhere and is given the file path here;
        # presumably it creates the parent directory -- TODO confirm.
        mkdir(target)
        shutil.copy(source_file, target)

    mkdir(log_file)
    file_handler = logging.FileHandler(log_file)
    log_format = logging.Formatter(
        fmt='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%m-%d %H:%M:%S',
    )
    file_handler.setFormatter(log_format)

    run_logger = _run.run_logger
    run_logger.setLevel(loglevel)
    run_logger.addHandler(file_handler)

    mkdir(config_file)
    save_config(_run.config, run_logger, config_file)
    config_summary = _format_config(_run.config, _run.config_modifications)
    run_logger.info(config_summary)