Exemplo n.º 1
0
def tune_from_args(args):
    """Grid-search the ``alpha``/``beta`` weights of the LM-backed decoder.

    Loads the trained model stored in ``args.serialization_dir`` and computes
    its outputs for ``args.input_file``. The outputs are cached in
    ``<serialization_dir>/logits/<basename>.pth`` and reused on later runs.
    Every pair of
    ``linspace(alpha_from, alpha_to, alpha_steps) x
    linspace(beta_from, beta_to, beta_steps)`` is then scored in a process
    pool via ``tune_step``. The scores are written to
    ``<serialization_dir>/tune/<basename>.csv`` with columns
    ``alpha, beta, wer, cer``.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``serialization_dir``, ``overrides``, ``weights_file``,
            ``input_file``, ``batch_size``, ``device``, ``alpha_from``,
            ``alpha_to``, ``alpha_steps``, ``beta_from``, ``beta_to``,
            ``beta_steps``, ``num_workers``, ``lm_workers``, ``lm_path``,
            ``cutoff_top_n``, ``cutoff_prob`` and ``beam_width``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)

    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    # Prefer the alphabet saved in the serialization directory; fall back to
    # building it from the configuration.
    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, "alphabet", "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logits_file = os.path.join(logits_dir, basename + '.pth')

    # Model outputs are cached on disk, so the model only has to run the
    # first time a given input file is tuned.
    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet,
                                   params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop("val_data_loader",
                                   params.get("data_loader"))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)

                # Move to CPU once; everything below works on the CPU copy.
                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)

                # Keep the first `l` (reported output length) frames of each
                # utterance, paired with its reference transcript.
                # NOTE(review): assumes `output` is batch-first — confirm.
                logits.extend((o[:l, ...], r) for o, l, r in zip(
                    output, output_lengths, references))

                del sample, sample_lengths, output

            torch.save(logits, logits_file)

        del model

    tune_dir = os.path.join(args.serialization_dir, 'tune')
    os.makedirs(tune_dir, exist_ok=True)

    params_grid = list(
        product(
            torch.linspace(args.alpha_from, args.alpha_to, args.alpha_steps),
            torch.linspace(args.beta_from, args.beta_to, args.beta_steps)))

    print(
        'Scheduling {} jobs for alphas=linspace({}, {}, {}), betas=linspace({}, {}, {})'
        .format(len(params_grid), args.alpha_from, args.alpha_to,
                args.alpha_steps, args.beta_from, args.beta_to,
                args.beta_steps))

    # start worker processes
    logger.info(
        f"Using {args.num_workers} processes and {args.lm_workers} for each CTCDecoder."
    )
    extract_start = default_timer()

    scores = []
    best_wer = float('inf')
    # Every worker runs `init` once with the cached logits path and the
    # decoder settings; `tune_step` then scores a single (alpha, beta) pair.
    # The `with` block shuts the pool down even if a job raises, so worker
    # processes are not leaked.
    with Pool(args.num_workers, init, [
            logits_file, alphabet, args.lm_path, args.cutoff_top_n,
            args.cutoff_prob, args.beam_width, args.lm_workers
    ]) as pool:
        with tqdm.tqdm(pool.imap(tune_step, params_grid),
                       total=len(params_grid),
                       desc='Grid search') as pbar:
            for alpha, beta, wer, cer in pbar:
                scores.append([alpha, beta, wer, cer])

                if wer < best_wer:
                    best_wer = wer
                    pbar.set_postfix(alpha=alpha, beta=beta, wer=wer, cer=cer)

    logger.info(
        f"Finished {len(params_grid)} processes in {default_timer() - extract_start:.1f}s"
    )

    df = pd.DataFrame(scores, columns=['alpha', 'beta', 'wer', 'cer'])
    df.to_csv(os.path.join(tune_dir, basename + '.csv'), index=False)
Exemplo n.º 2
0
def evaluate_from_args(args):
    """Run the trained model on ``args.input_file`` and report WER/CER.

    Loads the model stored in ``args.serialization_dir`` and runs it over the
    dataset built from ``args.input_file``. The per-utterance outputs are
    cached in ``<serialization_dir>/logits/<basename>.pth``, and greedy-decoding
    WER and CER are printed. If the cache file already exists, nothing is
    recomputed: the function only logs that the file exists and returns.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``serialization_dir``, ``overrides``, ``weights_file``,
            ``input_file``, ``batch_size`` and ``device``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)

    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    # Prefer the alphabet saved in the serialization directory; fall back to
    # building it from the configuration.
    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, "alphabet", "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    # Was a bare print() to stdout; keep the information at debug level.
    logger.debug("Input basename: %s", basename)
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet,
                                   params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop("val_data_loader",
                                   params.get("data_loader"))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)

                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)

                # NOTE(review): decode() is called without `output_lengths`,
                # so frames past each utterance's length may also be decoded.
                # Confirm against GreedyCTCDecoder.decode's signature.
                transcripts = decoder.decode(output)[0]

                # Keep the first `l` (reported output length) frames of each
                # utterance, paired with its reference transcript.
                logits.extend(
                    (o[:l, ...], r)
                    for o, l, r in zip(output, output_lengths, references))

                del sample, sample_lengths, output

                # Accumulate edit distances; they are normalised by the total
                # word / character counts of the references below.
                for reference, transcript in zip(references, transcripts):
                    total_wer += decoder.wer(transcript, reference)
                    total_cer += decoder.cer(transcript, reference)
                    num_tokens += float(len(reference.split()))
                    num_chars += float(len(reference))

            torch.save(logits, logits_file)

            # Guard against an empty dataset or blank references, which would
            # otherwise raise ZeroDivisionError.
            if num_tokens and num_chars:
                wer = float(total_wer) / num_tokens
                cer = float(total_cer) / num_chars
            else:
                logger.warning('No reference words/characters found in %s; '
                               'WER/CER are undefined.', args.input_file)
                wer = float('nan')
                cer = float('nan')

            print(f'WER: {wer:.02%}\nCER: {cer:.02%}')

        del model

    else:
        logger.info(f'Logits file `{logits_file}` already generated.')
Exemplo n.º 3
0
def train_model_from_args(args):
    """Train a model from the parameter file and options in ``args``.

    The steps are:

    1. Optionally copy a previous run's outputs into the serialization
       directory (local rank 0 only).
    2. Load the parameters and create the serialization directory.
    3. Build the data loaders, alphabet, loss and model.
    4. Optionally initialise the model from a fine-tuning archive.
    5. Freeze the requested parameters.
    6. Run the ``Trainer``.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``local_rank``, ``prev_output_dir``, ``serialization_dir``,
            ``param_path``, ``overrides``, ``reset``, ``distributed``,
            ``verbosity``, ``world_size``, ``first_epoch``, ``fine_tune``,
            ``sync_bn``, ``opt_level``, ``keep_batchnorm_fp32`` and
            ``loss_scale``.

    Raises:
        ConfigurationError: if ``args.param_path`` is not an existing file.
        KeyboardInterrupt: re-raised after logging if training is interrupted.
    """

    if args.local_rank == 0 and args.prev_output_dir is not None:
        logger.info('Copying results from {} to {}...'.format(args.prev_output_dir, args.serialization_dir))

        copy_tree(args.prev_output_dir, args.serialization_dir, update=True, verbose=True)

    if not os.path.isfile(args.param_path):
        raise ConfigurationError(f'Parameters file {args.param_path} not found.')

    logger.info(f'Loading experiment from {args.param_path} with overrides `{args.overrides}`.')

    params = Params.load(args.param_path, args.overrides)

    prepare_environment(params)

    # Only the local-rank-0 process creates the serialization directory.
    if args.local_rank == 0:
        create_serialization_dir(params, args.serialization_dir, args.reset)

    if args.distributed:
        logger.info(f'World size: {dist.get_world_size()} | Rank {dist.get_rank()} | ' f'Local Rank {args.local_rank}')
        # Synchronise so that no process goes on before the directory has
        # been created.
        dist.barrier()

    prepare_global_logging(args.serialization_dir, local_rank=args.local_rank, verbosity=args.verbosity)

    if args.local_rank == 0:
        params.save(os.path.join(args.serialization_dir, CONFIG_NAME))

    loaders = loaders_from_params(params,
                                  distributed=args.distributed,
                                  world_size=args.world_size,
                                  first_epoch=args.first_epoch)

    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(os.path.join(args.serialization_dir, "alphabet"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    alphabet.save_to_files(os.path.join(args.serialization_dir, "alphabet"))

    loss = losses.from_params(params.pop('loss'))
    model = models.from_params(alphabet=alphabet, params=params.pop('model'))

    trainer_params = params.pop("trainer")

    # Regexes naming archive weights that must not be copied when
    # fine-tuning. The key is popped unconditionally so that it never reaches
    # `Trainer`, even when --fine-tune is not given.
    no_ft_regex = trainer_params.pop("no_ft", ())

    if args.fine_tune:
        _, archive_weight_file = models.load_archive(args.fine_tune)

        archive_weights = torch.load(archive_weight_file, map_location=lambda storage, loc: storage)['model']

        # Compile once instead of re-parsing every pattern for every weight.
        no_ft_patterns = [re.compile(regex) for regex in no_ft_regex]

        # Split the archive weights into those copied into the model and
        # those left at the model's own initialisation.
        finetune_weights = {}
        random_weights = []
        for name, parameter in archive_weights.items():
            if any(pattern.search(name) for pattern in no_ft_patterns):
                random_weights.append(name)
            else:
                finetune_weights[name] = parameter

        logger.info(f'Loading the following weights from archive {args.fine_tune}:')
        logger.info(','.join(finetune_weights.keys()))
        logger.info('The following weights are at random:')
        logger.info(','.join(random_weights))

        # strict=False: the excluded weights are intentionally absent.
        model.load_state_dict(finetune_weights, strict=False)

    # Freezing some parameters
    freeze_params(model, trainer_params.pop('no_grad', ()))

    trainer = Trainer(args.serialization_dir,
                      trainer_params,
                      model,
                      loss,
                      alphabet,
                      local_rank=args.local_rank,
                      world_size=args.world_size,
                      sync_bn=args.sync_bn,
                      opt_level=args.opt_level,
                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                      loss_scale=args.loss_scale)

    try:
        trainer.run(loaders['train'], val_loader=loaders.get('val'), num_epochs=trainer_params['num_epochs'])
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        # NOTE(review): no archive is actually created here; the handler only
        # logs and re-raises.
        if os.path.exists(os.path.join(args.serialization_dir, models.DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
        raise