示例#1
0
def test_force_aligned_sequential(tmpdir):
    """Check that labels and features of every dataset sample agree in length.

    A ``KaldiDataset`` is built for ``TIMIT_tr`` with the cache rooted in
    ``tmpdir``. For each sample, every feature must have trailing dimensions
    ``[40, 13]``, and every label must have the same first-axis length as the
    feature that was iterated last.
    """
    logger.configure_logger(out_folder=tmpdir)

    phoneme_dict = get_phoneme_dict(
        f"{KALDI_ROOT}/egs/timit/s5/data/lang/phones.txt",
        stress_marks=True,
        word_position_dependency=False)

    dataset = KaldiDataset(
        data_cache_root=tmpdir,
        dataset_name="TIMIT_tr",
        feature_dict=_feature_dict,
        label_dict=_lab_dict_cd_phn,
        device='cpu',
        max_sample_len=1000,
        left_context=10,
        right_context=2,
        normalize_features=True,
        phoneme_dict=phoneme_dict,
        split_files_max_seq_len=100)

    expected_trailing_dims = [40, 13]
    # Kept outside the loop on purpose: a sample without features reuses the
    # length seen in the previous sample, exactly as before.
    n_frames = 0
    for _filename, feats, labels in dataset:
        for feat_name in feats:
            feat = feats[feat_name]
            n_frames = feat.shape[0]
            assert list(feat.shape[1:]) == expected_trailing_dims
        for label_name in labels:
            assert labels[label_name].shape[0] == n_frames
示例#2
0
    def __init__(self, keywords, sensitivity, model_path) -> None:
        """Create a scratch directory, configure logging and build the decoder.

        Args:
            keywords: Must be a ``dict`` (asserted below); stored on the
                instance and forwarded to ``get_decoder``.
            sensitivity: Stored unchanged as ``self.sensitivity``.
            model_path: Stored as ``self.model_checkpoint_path`` and forwarded
                to ``get_decoder``.

        Raises:
            AssertionError: If ``keywords`` is not a ``dict``.
        """
        super().__init__()
        self.tmp_root_dir = '/mnt/data/tmp_kws_eval'
        # exist_ok=True replaces the former exists()-then-makedirs() pair,
        # which could raise FileExistsError when two instances were created
        # at the same time.
        os.makedirs(self.tmp_root_dir, exist_ok=True)
        # Keep the TemporaryDirectory object on the instance so the directory
        # lives as long as this object holds the reference.
        self.tmp_dir = tempfile.TemporaryDirectory(dir=self.tmp_root_dir)

        # TODO debug mode

        logger.configure_logger(self.tmp_dir.name)
        check_environment()

        assert isinstance(keywords, dict), keywords
        self.keywords = keywords
        self.sensitivity = sensitivity

        self.model_checkpoint_path = model_path
        self.decoder = get_decoder(model_path, keywords, self.tmp_dir.name)
示例#3
0
def main(config_path, load_path, restart, overfit_small_batch, warm_start,
         optim_overwrite):
    """Prepare the experiment output folder and run training.

    Args:
        config_path: Path of the JSON experiment config; it is read with
            ``read_json`` and checked with ``check_config``.
        load_path: Must be ``None``; loading is not implemented yet. The value
            is still forwarded to ``Trainer``.
        restart: If truthy and the output folder already exists, the folder is
            removed and recreated (after an optional backup, see below).
        overfit_small_batch: Forwarded unchanged to ``Trainer``.
        warm_start: If not ``None``, passed to the model's
            ``load_warm_start`` method.
        optim_overwrite: If truthy, replaced by the contents of
            ``cfg/optim_overwrite.json``; the result is passed to
            ``setup_run`` and its truthiness sets ``restart_optim``.

    Raises:
        NotImplementedError: If ``load_path`` is not ``None``.
    """
    config = read_json(config_path)
    check_config(config)
    if optim_overwrite:
        optim_overwrite = read_json('cfg/optim_overwrite.json')

    if load_path is not None:
        raise NotImplementedError

    # TODO: resuming a run is not implemented. The sketched plan was to load
    # the config stored in the checkpoint, merge the current config into it,
    # print the diff between both, and continue with the merged config.
    save_time = datetime.datetime.now().strftime('_%Y%m%d_%H%M%S')

    set_seed(config['exp']['seed'])

    config['exp']['save_dir'] = os.path.abspath(config['exp']['save_dir'])

    # Output folder creation
    out_folder = os.path.join(config['exp']['save_dir'], config['exp']['name'])
    if os.path.exists(out_folder):
        checkpoint_dir = os.path.join(out_folder, "checkpoints")
        # Only runs that produced at least one checkpoint are worth keeping.
        # isdir() also guards against a stray file named "checkpoints".
        if os.path.isdir(checkpoint_dir) and os.listdir(checkpoint_dir):
            backup_dir = os.path.join(
                config['exp']['save_dir'] + "_finished_runs_backup/",
                config['exp']['name'] + save_time)
            print(
                f"Experiment under {out_folder} exists, copying it to {backup_dir}"
            )
            shutil.copytree(out_folder, backup_dir)
        else:
            print(
                f"Experiment under {out_folder} exists, "
                "no checkpoints found so no backup is made"
            )

        if restart:
            shutil.rmtree(out_folder)

    # Ensure the sub folder exists in every case: fresh run, restarted run,
    # and a reused folder that may have been created without it.
    os.makedirs(os.path.join(out_folder, 'exp_files'), exist_ok=True)

    logger.configure_logger(out_folder)

    check_environment()

    if nvidia_smi_enabled:  # TODO change criteria or the whole thing
        git_commit = code_versioning()
        if 'versioning' not in config:
            config['versioning'] = {}
        config['versioning']['git_commit'] = git_commit

    logger.info("Experiment name : {}".format(out_folder))
    logger.info("tensorboard : tensorboard --logdir {}".format(
        os.path.abspath(out_folder)))

    model, loss, metrics, optimizers, config, lr_schedulers, seq_len_scheduler = setup_run(
        config, optim_overwrite)

    if warm_start is not None:
        load_warm_start_op = getattr(model, "load_warm_start", None)
        assert callable(load_warm_start_op), \
            "warm_start was given but the model has no callable load_warm_start"
        load_warm_start_op(warm_start)

    # TODO instead of resuming and making a new folder, make a backup and continue in the same folder
    trainer = Trainer(model,
                      loss,
                      metrics,
                      optimizers,
                      lr_schedulers,
                      seq_len_scheduler,
                      load_path,
                      config,
                      restart_optim=bool(optim_overwrite),
                      do_validation=True,
                      overfit_small_batch=overfit_small_batch)
    trainer.train()