Example #1
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Set random seed
    torch.manual_seed(777)
    print(args)
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    distributed = args.world_size > 1

    raw_wav = configs['raw_wav']

    train_collate_func = CollateFunc(**configs['collate_conf'],
                                     raw_wav=raw_wav)

    cv_collate_conf = copy.deepcopy(configs['collate_conf'])
    # no augmentation on cv set
    cv_collate_conf['spec_aug'] = False
    cv_collate_conf['spec_sub'] = False
    if raw_wav:
        cv_collate_conf['feature_dither'] = 0.0
        cv_collate_conf['speed_perturb'] = False
        cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
    cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)

    dataset_conf = configs.get('dataset_conf', {})
    train_dataset = AudioDataset(args.train_data,
                                 **dataset_conf,
                                 raw_wav=raw_wav)
Example #2
                args.mode))
        sys.exit(1)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    raw_wav = configs['raw_wav']
    # Init dataset and data loader
    test_collate_conf = copy.deepcopy(configs['collate_conf'])
    test_collate_conf['spec_aug'] = False
    test_collate_conf['spec_sub'] = False
    test_collate_conf['feature_dither'] = False
    test_collate_conf['speed_perturb'] = False
    test_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
    test_collate_func = CollateFunc(**test_collate_conf, raw_wav=raw_wav)
    dataset_conf = configs.get('dataset_conf', {})
    dataset_conf['batch_size'] = args.batch_size
    dataset_conf['batch_type'] = 'static'
    dataset_conf['sort'] = False
    test_dataset = AudioDataset(args.test_data,
                                **dataset_conf,
                                raw_wav=raw_wav)
    test_data_loader = DataLoader(test_dataset,
                                  collate_fn=test_collate_func,
                                  shuffle=False,
                                  batch_size=1,
                                  num_workers=0)

    # Init asr model from configs
    model = init_asr_model(configs)
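Note that dataset_conf['batch_size'] is taken from the command line while the DataLoader itself is built with batch_size=1. The most natural reading of this excerpt is that AudioDataset groups utterances into batches itself (the 'static' batch_type), so the loader hands one pre-assembled batch to the collate function at a time. A hypothetical continuation of this example (reusing model and test_data_loader from the snippet; the exact batch layout produced by CollateFunc is not shown here) would iterate it like this:

    # Hypothetical continuation of Example #2, not part of the original file.
    model.eval()
    with torch.no_grad():
        for batch in test_data_loader:
            ...  # run the decoding mode selected by args.mode on this batch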
Example #3
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring'
                     ] and args.batch_size > 1:
        logging.fatal(
            'decoding mode {} must be running with batch_size == 1'.format(
                args.mode))
        sys.exit(1)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Init dataset and data loader
    test_collate_conf = copy.copy(configs['collate_conf'])
    test_collate_conf['spec_aug'] = False
    test_collate_func = CollateFunc(**test_collate_conf, cmvn=args.cmvn)
    dataset_conf = configs.get('dataset_conf', {})
    dataset_conf['batch_size'] = args.batch_size
    dataset_conf['sort'] = False
    test_dataset = AudioDataset(args.test_data, **dataset_conf)
    test_data_loader = DataLoader(test_dataset,
                                  collate_fn=test_collate_func,
                                  shuffle=False,
                                  batch_size=1,
                                  num_workers=0)

    # Init transformer model
    input_dim = test_dataset.input_dim
    vocab_size = test_dataset.output_dim
    encoder_type = configs.get('encoder', 'conformer')
    if encoder_type == 'conformer':
Example #4
            'alignment mode must be running with batch_size == 1')
        sys.exit(1)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    raw_wav = configs['raw_wav']
    # Init dataset and data loader
    ali_collate_conf = copy.deepcopy(configs['collate_conf'])
    ali_collate_conf['spec_aug'] = False
    ali_collate_conf['spec_sub'] = False
    ali_collate_conf['feature_dither'] = False
    ali_collate_conf['speed_perturb'] = False
    if raw_wav:
        ali_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
    ali_collate_func = CollateFunc(**ali_collate_conf,
                                   raw_wav=raw_wav)
    dataset_conf = configs.get('dataset_conf', {})
    dataset_conf['batch_size'] = args.batch_size
    dataset_conf['batch_type'] = 'static'
    dataset_conf['sort'] = False
    ali_dataset = AudioDataset(args.input_file, **dataset_conf, raw_wav=raw_wav)
    ali_data_loader = DataLoader(ali_dataset,
                                 collate_fn=ali_collate_func,
                                 shuffle=False,
                                 batch_size=1,
                                 num_workers=0)

    # Init asr model from configs
    model = init_asr_model(configs)

    load_checkpoint(model, args.checkpoint)
Example #5
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Set random seed
    torch.manual_seed(777)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    distributed = args.world_size > 1

    raw_wav = configs['raw_wav']

    train_collate_func = CollateFunc(**configs['collate_conf'],
                                     raw_wav=raw_wav,
                                     cmvn=args.cmvn)

    cv_collate_conf = copy.deepcopy(configs['collate_conf'])
    # no augmentation on cv set
    cv_collate_conf['spec_aug'] = False
    if raw_wav:
        cv_collate_conf['feature_dither'] = 0.0
        cv_collate_conf['speed_perturb'] = False
        cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
    cv_collate_func = CollateFunc(**cv_collate_conf,
                                  raw_wav=raw_wav,
                                  cmvn=args.cmvn)

    dataset_conf = configs.get('dataset_conf', {})
    train_dataset = AudioDataset(args.train_data,
Example #6
                args.mode))
        sys.exit(1)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    raw_wav = configs['raw_wav']
    # Init dataset and data loader
    test_collate_conf = copy.deepcopy(configs['collate_conf'])
    test_collate_conf['spec_aug'] = False
    test_collate_conf['feature_dither'] = False
    test_collate_conf['speed_perturb'] = False
    test_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
    test_collate_func = CollateFunc(**test_collate_conf,
                                    raw_wav=raw_wav,
                                    cmvn=args.cmvn)
    dataset_conf = configs.get('dataset_conf', {})
    dataset_conf['batch_size'] = args.batch_size
    dataset_conf['sort'] = False
    test_dataset = AudioDataset(args.test_data, **dataset_conf, raw_wav=raw_wav)
    test_data_loader = DataLoader(test_dataset,
                                  collate_fn=test_collate_func,
                                  shuffle=False,
                                  batch_size=1,
                                  num_workers=0)

    # Init transformer model
    if raw_wav:
        input_dim = configs['collate_conf']['feature_extraction_conf']['mel_bins']
    else:
Example #7
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Set random seed
    torch.manual_seed(777)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    distributed = args.world_size > 1

    # Init dataset and data loader
    collate_func = CollateFunc(**configs['collate_conf'],
                               **configs['spec_aug_conf'],
                               cmvn=args.cmvn)
    cv_collate_conf = copy.copy(configs['collate_conf'])
    cv_collate_conf['spec_aug'] = False
    cv_collate_func = CollateFunc(**cv_collate_conf, cmvn=args.cmvn)
    dataset_conf = configs.get('dataset_conf', {})
    train_dataset = AudioDataset(args.train_data, **dataset_conf)
    cv_dataset = AudioDataset(args.cv_data, **dataset_conf)

    if distributed:
        logging.info('training on multiple gpu, this gpu {}'.format(args.gpu))
        dist.init_process_group(args.dist_backend,
                                init_method=args.init_method,
                                world_size=args.world_size,
                                rank=args.rank)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
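The excerpt ends in the middle of the DistributedSampler call. For orientation only, the generic PyTorch pattern such a call is usually part of is sketched below; the helper name and arguments are assumptions for illustration, not the arguments the original file passes.

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_distributed_loader(dataset, collate_fn, batch_size):
    # Assumes dist.init_process_group() has already run, as in the snippet
    # above; the sampler shards the dataset across ranks, so the DataLoader
    # must not shuffle on its own.
    sampler = DistributedSampler(dataset)
    return DataLoader(dataset,
                      collate_fn=collate_fn,
                      sampler=sampler,
                      shuffle=False,
                      batch_size=batch_size)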