def tune_from_args(args):
    """Grid-search decoder ``alpha``/``beta`` values over cached model logits.

    The experiment configuration is loaded from ``args.serialization_dir``.
    If no logits file for ``args.input_file`` exists yet under
    ``<serialization_dir>/logits``, the model is run over the dataset and the
    per-utterance ``(logits, reference)`` pairs are saved there. Every
    ``(alpha, beta)`` pair of the grid is then passed to ``tune_step`` through
    a pool of worker processes, and the collected scores are written to
    ``<serialization_dir>/tune/<basename>.csv`` with the columns ``alpha``,
    ``beta``, ``wer`` and ``cer``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``serialization_dir``, ``overrides``, ``weights_file``,
            ``input_file``, ``batch_size``, ``device``, ``alpha_from``,
            ``alpha_to``, ``alpha_steps``, ``beta_from``, ``beta_to``,
            ``beta_steps``, ``num_workers``, ``lm_workers``, ``lm_path``,
            ``cutoff_top_n``, ``cutoff_prob`` and ``beam_width``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)
    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)
    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, "alphabet", "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    # The cache file is named after the input manifest (extension stripped).
    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet, params=params.pop('model'))
        # map_location keeps every storage where it was deserialized (CPU);
        # the model is moved to ``args.device`` explicitly below.
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop("val_data_loader", params.get("data_loader"))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)
                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)

                # Keep only the first ``length`` steps of each output, paired
                # with its reference transcript.
                logits.extend(
                    (out[:length, ...], ref)
                    for out, length, ref in zip(output, output_lengths, references))

                del sample, sample_lengths, output

            torch.save(logits, logits_file)

        del model

    tune_dir = os.path.join(args.serialization_dir, 'tune')
    os.makedirs(tune_dir, exist_ok=True)

    params_grid = list(
        product(
            torch.linspace(args.alpha_from, args.alpha_to, args.alpha_steps),
            torch.linspace(args.beta_from, args.beta_to, args.beta_steps)))

    print(
        'Scheduling {} jobs for alphas=linspace({}, {}, {}), betas=linspace({}, {}, {})'
        .format(len(params_grid), args.alpha_from, args.alpha_to,
                args.alpha_steps, args.beta_from, args.beta_to, args.beta_steps))

    # start worker processes
    logger.info(
        f"Using {args.num_workers} processes and {args.lm_workers} for each CTCDecoder."
    )
    extract_start = default_timer()

    pool_init_args = [
        logits_file, alphabet, args.lm_path, args.cutoff_top_n,
        args.cutoff_prob, args.beam_width, args.lm_workers
    ]

    scores = []
    best_wer = float('inf')
    # The context manager guarantees the workers are shut down even when a
    # grid step raises; previously the pool was never closed or joined.
    with Pool(args.num_workers, init, pool_init_args) as pool:
        with tqdm.tqdm(pool.imap(tune_step, params_grid),
                       total=len(params_grid),
                       desc='Grid search') as pbar:
            for alpha, beta, wer, cer in pbar:
                scores.append([alpha, beta, wer, cer])

                if wer < best_wer:
                    best_wer = wer
                    # The progress bar always shows the best setting so far.
                    pbar.set_postfix(alpha=alpha, beta=beta, wer=wer, cer=cer)

        elapsed = default_timer() - extract_start

    logger.info(f"Finished {len(params_grid)} processes in {elapsed:.1f}s")

    df = pd.DataFrame(scores, columns=['alpha', 'beta', 'wer', 'cer'])
    df.to_csv(os.path.join(tune_dir, basename + '.csv'), index=False)
def evaluate_from_args(args):
    """Run greedy-decoding evaluation and cache the model logits.

    The experiment configuration is loaded from ``args.serialization_dir``.
    If no logits file for ``args.input_file`` exists under
    ``<serialization_dir>/logits``, the model is run over the dataset, the
    per-utterance ``(logits, reference)`` pairs are saved to that file, and
    the corpus-level WER and CER of the greedy transcripts are printed. If the
    logits file already exists, nothing is recomputed and only a log message
    is emitted.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``serialization_dir``, ``overrides``, ``weights_file``,
            ``input_file``, ``batch_size`` and ``device``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)
    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)
    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, "alphabet", "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    # The cache file is named after the input manifest (extension stripped).
    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    print(basename)
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet, params=params.pop('model'))
        # map_location keeps every storage where it was deserialized (CPU);
        # the model is moved to ``args.device`` explicitly below.
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop("val_data_loader", params.get("data_loader"))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)
                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)
                transcripts = decoder.decode(output)[0]

                # Keep only the first ``length`` steps of each output, paired
                # with its reference transcript.
                logits.extend(
                    (out[:length, ...], ref)
                    for out, length, ref in zip(output, output_lengths, references))

                del sample, sample_lengths, output

                for reference, transcript in zip(references, transcripts):
                    total_wer += decoder.wer(transcript, reference)
                    total_cer += decoder.cer(transcript, reference)
                    num_tokens += float(len(reference.split()))
                    num_chars += float(len(reference))

            torch.save(logits, logits_file)

            # An empty manifest (or one with only empty references) leaves the
            # denominators at zero; report NaN instead of crashing after the
            # logits have already been written.
            if not num_tokens or not num_chars:
                logger.warning(
                    'No reference tokens/characters found in %s; '
                    'WER/CER are undefined.', args.input_file)
            wer = float(total_wer) / num_tokens if num_tokens else float('nan')
            cer = float(total_cer) / num_chars if num_chars else float('nan')

            print(f'WER: {wer:.02%}\nCER: {cer:.02%}')

        del model
    else:
        logger.info(f'Logits file `{logits_file}` already generated.')
def train_model_from_args(args):
    """Build an experiment from a parameter file and run the trainer.

    Steps, in order: optionally copy a previous output directory into
    ``args.serialization_dir`` (local rank 0 only), load the parameters,
    create the serialization directory and save the configuration (local
    rank 0 only), build the data loaders, alphabet, loss and model, optionally
    initialize the model from a fine-tuning archive, freeze the parameters
    selected by the trainer's ``no_grad`` setting, and finally run the
    ``Trainer``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``local_rank``, ``prev_output_dir``, ``serialization_dir``,
            ``param_path``, ``overrides``, ``reset``, ``distributed``,
            ``world_size``, ``first_epoch``, ``verbosity``, ``fine_tune``,
            ``sync_bn``, ``opt_level``, ``keep_batchnorm_fp32`` and
            ``loss_scale``.

    Raises:
        ConfigurationError: If ``args.param_path`` is not an existing file.
        KeyboardInterrupt: Re-raised after logging when training is
            interrupted by the user.
    """
    if args.local_rank == 0 and args.prev_output_dir is not None:
        logger.info('Copying results from {} to {}...'.format(
            args.prev_output_dir, args.serialization_dir))
        copy_tree(args.prev_output_dir,
                  args.serialization_dir,
                  update=True,
                  verbose=True)

    if not os.path.isfile(args.param_path):
        raise ConfigurationError(f'Parameters file {args.param_path} not found.')

    logger.info(
        f'Loading experiment from {args.param_path} with overrides `{args.overrides}`.')
    params = Params.load(args.param_path, args.overrides)

    prepare_environment(params)

    logger.info(args.local_rank)
    if args.local_rank == 0:
        create_serialization_dir(params, args.serialization_dir, args.reset)

    if args.distributed:
        logger.info(f'World size: {dist.get_world_size()} | Rank {dist.get_rank()} | '
                    f'Local Rank {args.local_rank}')
        # Wait for every process before continuing.
        dist.barrier()

    prepare_global_logging(args.serialization_dir,
                           local_rank=args.local_rank,
                           verbosity=args.verbosity)

    if args.local_rank == 0:
        params.save(os.path.join(args.serialization_dir, CONFIG_NAME))

    # Named ``data_loaders`` so the local does not shadow the ``loaders``
    # name that other functions in this file use as a module.
    data_loaders = loaders_from_params(params,
                                       distributed=args.distributed,
                                       world_size=args.world_size,
                                       first_epoch=args.first_epoch)

    alphabet_dir = os.path.join(args.serialization_dir, "alphabet")
    if os.path.exists(alphabet_dir):
        # NOTE(review): the tune/evaluate commands in this file read the
        # alphabet from "<serialization_dir>/alphabet/tokens", while this call
        # passes the directory itself - confirm which path `Alphabet.from_file`
        # expects.
        alphabet = Alphabet.from_file(alphabet_dir)
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    alphabet.save_to_files(alphabet_dir)

    loss = losses.from_params(params.pop('loss'))
    model = models.from_params(alphabet=alphabet, params=params.pop('model'))

    trainer_params = params.pop("trainer")
    if args.fine_tune:
        _, archive_weight_file = models.load_archive(args.fine_tune)
        archive_weights = torch.load(
            archive_weight_file,
            map_location=lambda storage, loc: storage)['model']

        # Avoiding initializing from archive some weights
        no_ft_regex = trainer_params.pop("no_ft", ())

        finetune_weights = {}
        random_weights = []
        for name, parameter in archive_weights.items():
            # Weights whose name matches any `no_ft` pattern are not loaded
            # from the archive.
            if any(re.search(regex, name) for regex in no_ft_regex):
                random_weights.append(name)
                continue
            finetune_weights[name] = parameter

        logger.info(f'Loading the following weights from archive {args.fine_tune}:')
        logger.info(','.join(finetune_weights.keys()))
        logger.info('The following weights are at random:')
        logger.info(','.join(random_weights))
        # strict=False because the excluded weights are missing from the dict.
        model.load_state_dict(finetune_weights, strict=False)

    # Freezing some parameters
    freeze_params(model, trainer_params.pop('no_grad', ()))

    trainer = Trainer(args.serialization_dir,
                      trainer_params,
                      model,
                      loss,
                      alphabet,
                      local_rank=args.local_rank,
                      world_size=args.world_size,
                      sync_bn=args.sync_bn,
                      opt_level=args.opt_level,
                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                      loss_scale=args.loss_scale)

    try:
        trainer.run(data_loaders['train'],
                    val_loader=data_loaders.get('val'),
                    num_epochs=trainer_params['num_epochs'])
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(args.serialization_dir, models.DEFAULT_WEIGHTS)):
            # NOTE(review): the message announces an archive attempt, but no
            # archiving call is made here - confirm whether one is missing.
            logger.info("Training interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
        # Always propagate the interrupt to the caller.
        raise