def summary(input_path, lang):
    """Print the average CER/WER of a transcript and save CER histograms.

    Every transcript entry must carry ``hyp`` and ``ref``. Entries that do
    not already have ``cer`` / ``wer`` get them computed from the normalized
    hypothesis and reference. The plot is saved to ``input_path + '.png'``.

    Args:
        input_path: path to a JSON transcript (a list of dicts).
        lang: language identifier passed to ``datasets.Language``.
    """
    lang = datasets.Language(lang)
    with open(input_path) as f:
        transcript = json.load(f)

    for t in transcript:
        hyp, ref = map(lang.normalize_text, [t['hyp'], t['ref']])
        # Only compute a metric when it is missing: `dict.get(key, default)`
        # would evaluate the (expensive) default even if the key is present.
        if 'cer' not in t:
            t['cer'] = metrics.cer(hyp, ref)
        if 'wer' not in t:
            t['wer'] = metrics.wer(hyp, ref)

    # Force a floating dtype so that integer-valued metrics (e.g. all zeros
    # stored as `0`) do not make `.mean()` fail on an integer tensor.
    cer_, wer_ = [
        torch.tensor([t[k] for t in transcript], dtype=torch.float32)
        for k in ['cer', 'wer']
    ]
    cer_avg, wer_avg = float(cer_.mean()), float(wer_.mean())
    print(f'CER: {cer_avg:.02f} | WER: {wer_avg:.02f}')

    figure = plt.figure(figsize=(8, 4))
    plt.suptitle(os.path.basename(input_path))

    plt.subplot(211)
    plt.title('cer PDF')
    seaborn.distplot(cer_, bins=20, hist=True)
    plt.xlim(0, 1)

    plt.subplot(212)
    plt.title('cer CDF')
    plt.hist(cer_, bins=20, density=True, cumulative=True)
    plt.xlim(0, 1)
    plt.xticks(torch.arange(0, 1.01, 0.1))
    plt.grid(True)

    plt.subplots_adjust(hspace=0.4)
    plt.savefig(input_path + '.png', dpi=150)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(figure)
def tabulate(
    experiments_dir, experiment_id, entropy, loss, cer10, cer15, cer20, cer30, cer40, cer50, per, wer, json_, bpe,
    der, lang
):
    """Print a tab-separated table of one metric per iteration and dataset.

    Scans ``<experiments_dir>/<experiment_id>/transcripts_*.json``. The file
    name is expected to look like ``transcripts_<dataset>epoch<...>.json``:
    the part before ``epoch`` is the validation dataset name, the part from
    ``epoch`` on (without ``.json``) is the iteration key.

    Args:
        experiments_dir: root directory of all experiments.
        experiment_id: sub-directory of the experiment to tabulate.
        entropy, loss, per, wer, der: flags selecting the metric key; the
            first truthy one in the order wer, entropy, loss, per, der wins,
            otherwise ``'cer'`` is used.
        cer10, cer20, cer30, cer40, cer50: when any is set, report the
            fraction of examples below the corresponding threshold
            (0.1 ... 0.5) instead of the mean metric value.
        cer15: when set, report the fraction of examples below 0.15
            (applied after the thresholds above).
        json_: when truthy, the last column is the transcript path itself
            instead of the matching ``checkpoint_*.pt`` path.
        bpe: unused; kept for interface compatibility.
        lang: language identifier passed to ``datasets.Language``.
    """
    # TODO: bring back custom name to the filtration process, or remove filtration by labels_name entirely.
    labels = datasets.Labels(lang=datasets.Language(lang), name='char')
    # The metric does not depend on the file, so it is selected once.
    metric = 'wer' if wer else 'entropy' if entropy else 'loss' if loss else 'per' if per else 'der' if der else 'cer'

    res = collections.defaultdict(list)
    experiment_dir = os.path.join(experiments_dir, experiment_id)
    for transcript_path in sorted(glob.glob(os.path.join(experiment_dir, 'transcripts_*.json'))):
        # Parse the file name only: searching the whole path breaks as soon as
        # a directory name contains 'epoch' or 'transcripts_'.
        file_name = os.path.basename(transcript_path)
        eidx = file_name.find('epoch')
        iteration = file_name[eidx:].replace('.json', '')
        val_dataset_name = file_name[len('transcripts_'):eidx]
        if json_:
            checkpoint = transcript_path
        else:
            checkpoint = os.path.join(experiment_dir, 'checkpoint_' + file_name[eidx:].replace('.json', '.pt'))

        with open(transcript_path) as f:
            transcript = json.load(f)
        # `or [0.0]` keeps the tensor non-empty when no entry matches the labels name.
        val = torch.tensor([j[metric] for j in transcript if j['labels_name'] == labels.name] or [0.0])
        val = val[~(torch.isnan(val) | torch.isinf(val))]

        if cer10 or cer20 or cer30 or cer40 or cer50:
            # The position of the first set flag encodes the threshold: 1 -> 0.1, ..., 5 -> 0.5.
            val = (val < 0.1 * [False, cer10, cer20, cer30, cer40, cer50].index(True)).float()
        if cer15:
            val = (val < 0.15).float()

        res[iteration].append((val_dataset_name, float(val.mean()), checkpoint))

    val_dataset_names = sorted(set(name for rows in res.values() for name, _, _ in rows))
    print('iteration\t' + '\t'.join(val_dataset_names))
    for iteration, rows in res.items():
        cers = {name: f'{value:.04f}' for name, value, _ in rows}
        # The trailing column is the checkpoint (or transcript) of the last dataset row.
        print(f'{iteration}\t' + '\t'.join(cers.get(name, '') for name in val_dataset_names) + f'\t{rows[-1][-1]}')
def normalize(input_path, lang, dry = True):
    """Normalize ``ref`` / ``hyp`` texts of transcripts and fill in CER/WER.

    Args:
        input_path: iterable of paths to JSON transcripts (lists of dicts).
        lang: language identifier passed to ``datasets.Language``.
        dry: when true (default), nothing is written and the normalized
            transcript of the FIRST path is returned immediately; when false,
            every transcript file is rewritten in place.

    Returns:
        The first normalized transcript in dry mode, otherwise ``None``
        (also ``None`` when ``input_path`` is empty).
    """
    lang = datasets.Language(lang)
    labels = datasets.Labels(lang)

    for transcript_path in input_path:
        with open(transcript_path) as f:
            transcript = json.load(f)

        for t in transcript:
            if 'ref' in t:
                t['ref'] = labels.postprocess_transcript(lang.normalize_text(t['ref']))
            if 'hyp' in t:
                t['hyp'] = labels.postprocess_transcript(lang.normalize_text(t['hyp']))
            if 'ref' in t and 'hyp' in t:
                # Metrics already stored in the transcript are kept as they are.
                if 'cer' not in t:
                    t['cer'] = metrics.cer(t['hyp'], t['ref'])
                if 'wer' not in t:
                    t['wer'] = metrics.wer(t['hyp'], t['ref'])

        if dry:
            return transcript

        # Close the file explicitly so the dump is flushed before moving on.
        with open(transcript_path, 'w') as f:
            json.dump(transcript, f, ensure_ascii = False, indent = 2, sort_keys = True)
def setup(args):
    """Build labels, frontend, model and decoder for inference from a checkpoint.

    Gradients are disabled globally. The audio/feature settings stored in the
    checkpoint overwrite the corresponding attributes of ``args``.

    Args:
        args: namespace providing ``checkpoint``, ``model``, ``device``,
            ``fp16``, ``decoder`` and the beam-search decoder options.

    Returns:
        Tuple ``(labels, frontend, model, decoder)``.
    """
    torch.set_grad_enabled(False)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    saved_args = checkpoint['args']

    # Read every setting first, then assign, so `args` is only touched once all lookups are done.
    restored_names = ['sample_rate', 'window_size', 'window_stride', 'window', 'num_input_features']
    restored_values = [saved_args.get(name) for name in restored_names]
    for name, value in zip(restored_names, restored_values):
        setattr(args, name, value)

    frontend = models.LogFilterBankFrontend(
        args.num_input_features,
        args.sample_rate,
        args.window_size,
        args.window_stride,
        args.window,
        eps=1e-6,
    )
    labels = datasets.Labels(datasets.Language(saved_args['lang']), name='char')

    def first_head(logits, log_probs, olen, **kwargs):
        # The model output is reduced to the logits and lengths of the first head.
        return (logits[0], olen[0])

    model_factory = getattr(models, args.model or saved_args['model'])
    model = model_factory(args.num_input_features, [len(labels)], frontend=frontend, dict=first_head)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model = model.to(args.device)
    model.eval()
    model.fuse_conv_bn_eval()
    if args.device != 'cpu':
        model, *_ = models.data_parallel_and_autocast(model, opt_level=args.fp16)

    if args.decoder == 'GreedyDecoder':
        decoder = decoders.GreedyDecoder()
    else:
        decoder = decoders.BeamSearchDecoder(
            labels,
            lm_path=args.lm,
            beam_width=args.beam_width,
            beam_alpha=args.beam_alpha,
            beam_beta=args.beam_beta,
            num_workers=args.num_workers,
            topk=args.decoder_topk,
        )

    return labels, frontend, model, decoder
def lserrorwords(input_path, output_path, comment_path, freq_path, sortdesc, sortasc, comment_filter, lang):
    """List reference words that were recognized with errors.

    Transcript entries with more than two ``missing_ref`` words are ignored.
    Every remaining erroneous reference word longer than one character
    becomes a row ``(word, diff, err, ok, freq, audio_name, usage, comment)``
    where ``err`` / ``ok`` are counted per ``(stem, word length)`` key and
    ``diff = err - ok``.

    Args:
        input_path: JSON transcript; each entry has ``words``, ``audio_name``
            and ``ref``, each word has ``ref`` and ``type`` or ``error_tag``.
        output_path: ``.csv`` or ``.json`` destination; any other extension
            produces an empty file.
        comment_path: optional CSV-like file ``word,...,comment``; lines
            containing ``#`` are skipped.
        freq_path: optional whitespace-separated file ``word ... count``.
        sortdesc, sortasc: sort key name; ``'diff'`` sorts by the ``diff``
            column, anything else by ``(-(err + ok), audio_name)``. A truthy
            ``sortdesc`` reverses the order.
        comment_filter: keep only rows whose comment contains this substring;
            ``None`` or ``''`` keeps everything.
        lang: language identifier passed to ``datasets.Language``.
    """
    lang = datasets.Language(lang)

    def tag(word):
        return word.get('type') or word.get('error_tag')

    def clean_ref(word):
        return word['ref'].replace(metrics.placeholder, '')

    def stem(word):
        return (lang.stem(word), len(word))

    freq = {}
    if freq_path:
        with open(freq_path) as f:
            for line in f:
                # Glue hyphenated words back together before splitting on whitespace.
                splitted = re.sub(r'[ ]+-[ ]*', '-', line).split()
                if splitted:  # tolerate blank lines
                    freq[splitted[0]] = int(splitted[-1])

    comment = {}
    if comment_path:
        with open(comment_path) as f:
            for line in f:
                splitted = line.split(',')
                if '#' not in line and len(splitted) > 1:
                    comment[splitted[0]] = splitted[-1].strip()

    with open(input_path) as f:
        transcript = json.load(f)
    transcript = [t for t in transcript if [tag(w) for w in t['words']].count('missing_ref') <= 2]

    words_ok = [clean_ref(w) for t in transcript for w in t['words'] if tag(w) == 'ok']
    words_error = set(
        ref
        for ref in (clean_ref(w) for t in transcript for w in t['words'] if tag(w) not in ['ok', 'missing_ref'])
        if len(ref) > 1
    )

    # Remember the first transcript entry in which every reference word occurs.
    first_usage = {}
    for t in transcript:
        for w in t['words']:
            first_usage.setdefault(clean_ref(w), t)

    words_ok_counter = collections.Counter(map(stem, words_ok))
    words_error_counter = collections.Counter(map(stem, words_error))

    rows = []
    for ref in words_error:
        key = stem(ref)
        example = first_usage[ref]
        rows.append((
            ref,
            words_error_counter[key] - words_ok_counter[key],
            words_error_counter[key],
            words_ok_counter[key],
            freq.get(ref, 0),
            example['audio_name'],
            example['ref'],
            comment.get(ref, ''),
        ))

    sort_by = sortdesc or sortasc

    def sort_key(row):
        # Index 1 is the `diff` column. The previous `row[-5]` pointed at the
        # `ok` column of the 8-field row, so sorting by 'diff' used the wrong field.
        primary = row[1] if sort_by == 'diff' else (-row[2] - row[3], row[5])
        return (primary, row[0])

    rows.sort(key = sort_key, reverse = bool(sortdesc))
    comment_filter = comment_filter or ''
    rows = [row for row in rows if comment_filter in row[-1]]

    with open(output_path, 'w') as f:
        if output_path.endswith('.csv'):
            f.write(
                '#word,diff,err,ok,freq,audioname,usage,comment\n' + '\n'.join(','.join(map(str, row)) for row in rows)
            )
        elif output_path.endswith('.json'):
            json.dump([dict(audio_name = row[5], before = row[0], after = '') for row in rows],
                      f,
                      ensure_ascii = False,
                      indent = 2,
                      sort_keys = True)
    print(output_path)
def logits(lang, logits, audio_name = None, MAX_ENTROPY = 1.0):
    """Render saved logits as an HTML report with one plot per utterance.

    For every entry loaded from ``logits`` (via ``torch.load``) a figure with
    the raw logits, top-1 / top-2 probabilities and entropy (with and without
    the last class) is embedded into ``logits + '.html'`` together with the
    alignment and audio widgets.

    Args:
        lang: language identifier passed to ``datasets.Language``.
        logits: path to the saved list of dicts with at least ``audio_path``
            and ``logits``; ``words``, ``y``, ``begin``, ``end`` and ``cer``
            are optional.
        audio_name: optional list restricting the report to the given audio
            names; when its first element is an existing file, the names are
            read from that file, one per line.
        MAX_ENTROPY: frames whose entropy exceeds this value are highlighted.

    Returns:
        Path of the written HTML file.
    """
    if not audio_name:
        # None (or an empty list) means "no filtering".
        good_audio_name = set()
    elif os.path.exists(audio_name[0]):
        with open(audio_name[0]) as f:
            good_audio_name = set(map(str.strip, f))
    else:
        good_audio_name = set(audio_name)

    labels = datasets.Labels(datasets.Language(lang))
    decoder = decoders.GreedyDecoder()

    def tick_params(ax, labelsize = 2.5, length = 0, **kwargs):
        # Tiny tick labels, no tick marks and invisible spines.
        ax.tick_params(axis = 'both', which = 'both', labelsize = labelsize, length = length, **kwargs)
        for spine in ax.spines.values():
            spine.set_linewidth(0)

    logits_path = logits + '.html'
    with open(logits_path, 'w') as html:
        html.write('<html><head>' + meta_charset + f'</head><body><script>{play_script}{onclick_img_script}</script>')

        for i, t in enumerate(torch.load(logits)):
            audio_path, frame_logits = t['audio_path'], t['logits']
            name = transcripts.audio_name(audio_path)
            if good_audio_name and name not in good_audio_name:
                continue

            words = t.get('words', [t])
            y = t.get('y', torch.zeros(1, 0, dtype = torch.long))
            begin = t.get('begin', '')
            end = t.get('end', '')
            extra_metrics = dict(cer = t['cer']) if 'cer' in t else {}

            log_probs = F.log_softmax(frame_logits, dim = 0)
            entropy = models.entropy(log_probs, dim = 0, sum = False)
            # Same statistics with the last class dropped (plotted as 'entropy, no blank').
            log_probs_ = F.log_softmax(frame_logits[:-1], dim = 0)
            entropy_ = models.entropy(log_probs_, dim = 0, sum = False)

            plt.figure(figsize = (6, 2))
            ax = plt.subplot(211)
            plt.imshow(frame_logits, aspect = 'auto')
            plt.xlim(0, frame_logits.shape[-1] - 1)
            plt.axis('off')
            tick_params(plt.gca())

            plt.subplot(212, sharex = ax)
            prob_top1, prob_top2 = log_probs.exp().topk(2, dim = 0).values
            plt.hlines(1.0, 0, entropy.shape[-1] - 1, linewidth = 0.2)
            artist_prob_top1, = plt.plot(prob_top1, 'b', linewidth = 0.3)
            artist_prob_top2, = plt.plot(prob_top2, 'g', linewidth = 0.3)
            artist_entropy, = plt.plot(entropy, 'r', linewidth = 0.3)
            artist_entropy_, = plt.plot(entropy_, 'yellow', linewidth = 0.3)
            plt.legend([artist_entropy, artist_entropy_, artist_prob_top1, artist_prob_top2],
                       ['entropy', 'entropy, no blank', 'top1 prob', 'top2 prob'],
                       loc = 1,
                       fontsize = 'xx-small',
                       frameon = False)

            # Highlight runs of frames whose entropy is above the threshold.
            for b, e, v in zip(*models.rle1d(entropy > MAX_ENTROPY)):
                if bool(v):
                    plt.axvspan(int(b), int(e), color = 'red', alpha = 0.2)
            plt.ylim(0, 3.0)
            plt.xlim(0, entropy.shape[-1] - 1)

            # One tick label per frame: the decoded hypotheses stacked vertically.
            decoded = decoder.decode(log_probs.unsqueeze(0), K = 5)[0]
            xlabels = list(
                map(
                    '\n'.join,
                    zip(
                        *[
                            labels.decode(d, replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False)
                            for d in decoded
                        ]
                    )
                )
            )
            plt.xticks(torch.arange(entropy.shape[-1]), xlabels, fontfamily = 'monospace')
            tick_params(plt.gca())

            if y.numel() > 0:
                # Show the reference characters on a secondary axis at their aligned frames.
                alignment = ctc.alignment(
                    log_probs.unsqueeze(0).permute(2, 0, 1),
                    y.unsqueeze(0).long(),
                    torch.LongTensor([log_probs.shape[-1]]),
                    torch.LongTensor([len(y)]),
                    blank = len(log_probs) - 1
                ).squeeze(0)
                ax = plt.gca().secondary_xaxis('top')
                ref = labels.decode(y.tolist(), replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False)
                ref_ = alignment
                ax.set_xticklabels(ref)
                ax.set_xticks(ref_)
                tick_params(ax, colors = 'red')

            plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)
            buf = io.BytesIO()
            plt.savefig(buf, format = 'jpg', dpi = 600)
            plt.close()

            # The metrics used to be written through a plain (non-f) string, which emitted
            # the literal text '{k}: {v:.02f}'; they are now formatted and separated from the name.
            heading = ' | '.join([str(name)] + [f'{k}: {v:.02f}' for k, v in extra_metrics.items()])
            html.write(f'<h4>{heading}</h4>')
            html.write(fmt_alignment(words))
            html.write(
                '<img data-begin="{begin}" data-end="{end}" data-channel="{channel}" onclick="onclick_img(event)" style="width:100%" src="data:image/jpeg;base64,{encoded}"></img>\n'
                .format(channel = i, begin = begin, end = end, encoded = base64.b64encode(buf.getvalue()).decode())
            )
            html.write(fmt_audio(audio_path = audio_path, channel = i))
            html.write('<hr/>')

        html.write('</body></html>')
    return logits_path
def main(args):
    """Entry point: ONNX export, evaluation only, or training, depending on ``args``.

    Steps, in order:
      1. Load the checkpoint(s) in ``args.checkpoint``; several checkpoints are averaged weight by weight.
      2. Format the experiment id / directory and set up the loggers; return here when ``args.dry`` is set.
      3. Build labels, frontend and model, and load the checkpoint weights.
      4. If ``args.onnx`` is set: export the model, compare it with onnxruntime and return.
      5. Build the validation data loaders and decoders; if ``args.train_data_path`` is empty, evaluate and return.
      6. Otherwise build the training pipeline and run the training loop with periodic evaluation.

    NOTE(review): the block nesting below was reconstructed from a copy of the file that had lost its
    indentation; places where the nesting could not be derived from syntax alone are marked individually.
    """
    checkpoints = [
        torch.load(checkpoint_path, map_location='cpu')
        for checkpoint_path in args.checkpoint
    ]
    # The first checkpoint provides args / optimizer / sampler state; {} means "no checkpoint".
    checkpoint = (checkpoints + [{}])[0]
    if len(checkpoints) > 1:
        # Average the model weights over all given checkpoints.
        checkpoint['model_state_dict'] = {
            k: sum(c['model_state_dict'][k] for c in checkpoints) / len(checkpoints)
            for k in checkpoint['model_state_dict']
        }

    if args.frontend_checkpoint:
        frontend_checkpoint = torch.load(args.frontend_checkpoint, map_location='cpu')
        frontend_extra_args = frontend_checkpoint['args']
        frontend_checkpoint = frontend_checkpoint['model']
    else:
        frontend_extra_args = None
        frontend_checkpoint = None

    # args.experiment_id is a format template; 'e-0' -> 'e-' shortens exponents such as 1e-05.
    args.experiment_id = args.experiment_id.format(
        model=args.model,
        frontend=args.frontend,
        train_batch_size=args.train_batch_size,
        optimizer=args.optimizer,
        lr=args.lr,
        weight_decay=args.weight_decay,
        time=time.strftime('%Y-%m-%d_%H-%M-%S'),
        experiment_name=args.experiment_name,
        bpe='bpe' if args.bpe else '',
        train_waveform_transform=
        f'aug{args.train_waveform_transform[0]}{args.train_waveform_transform_prob or ""}'
        if args.train_waveform_transform else '',
        train_feature_transform=
        f'aug{args.train_feature_transform[0]}{args.train_feature_transform_prob or ""}'
        if args.train_feature_transform else '').replace('e-0', 'e-').rstrip('_')
    # Without an explicit experiment name, reuse the experiment id stored in the checkpoint.
    if checkpoint and 'experiment_id' in checkpoint['args'] and not args.experiment_name:
        args.experiment_id = checkpoint['args']['experiment_id']
    args.experiment_dir = args.experiment_dir.format(
        experiments_dir=args.experiments_dir,
        experiment_id=args.experiment_id)

    os.makedirs(args.experiment_dir, exist_ok=True)

    if args.log_json:
        args.log_json = os.path.join(args.experiment_dir, 'log.json')

    if checkpoint:
        # Model / audio settings come from the checkpoint; logs are appended to.
        args.lang, args.model, args.num_input_features, args.sample_rate, args.window, args.window_size, args.window_stride = map(
            checkpoint['args'].get, [
                'lang', 'model', 'num_input_features', 'sample_rate',
                'window', 'window_size', 'window_stride'
            ])
        utils.set_up_root_logger(os.path.join(args.experiment_dir, 'log.txt'), mode='a')
        logfile_sink = JsonlistSink(args.log_json, mode='a')
    else:
        # Fresh run: logs are overwritten.
        utils.set_up_root_logger(os.path.join(args.experiment_dir, 'log.txt'), mode='w')
        logfile_sink = JsonlistSink(args.log_json, mode='w')

    _print = utils.get_root_logger_print()
    _print('\n', 'Arguments:', args)
    _print(
        f'"CUDA_VISIBLE_DEVICES={os.environ.get("CUDA_VISIBLE_DEVICES", default = "")}"'
    )
    _print(
        f'"CUDA_LAUNCH_BLOCKING={os.environ.get("CUDA_LAUNCH_BLOCKING", default="")}"'
    )
    _print('Experiment id:', args.experiment_id, '\n')
    if args.dry:
        # Dry run: only the experiment directory, loggers and the printout above are produced.
        return
    utils.set_random_seed(args.seed)
    if args.cudnn == 'benchmark':
        torch.backends.cudnn.benchmark = True

    lang = datasets.Language(args.lang)
    #TODO: , candidate_sep = datasets.Labels.candidate_sep
    # NOTE(review): the file opened here is never closed explicitly.
    normalize_text_config = json.load(open(
        args.normalize_text_config)) if os.path.exists(
            args.normalize_text_config) else {}
    # labels[0] is the character-level label set, followed by one label set per entry of args.bpe.
    labels = [
        datasets.Labels(
            lang, name='char', normalize_text_config=normalize_text_config)
    ] + [
        datasets.Labels(lang,
                        bpe=bpe,
                        name=f'bpe{i}',
                        normalize_text_config=normalize_text_config)
        for i, bpe in enumerate(args.bpe)
    ]
    frontend = getattr(models, args.frontend)(out_channels=args.num_input_features,
                                              sample_rate=args.sample_rate,
                                              window_size=args.window_size,
                                              window_stride=args.window_stride,
                                              window=args.window,
                                              dither=args.dither,
                                              dither0=args.dither0,
                                              stft_mode='conv' if args.onnx else None,
                                              extra_args=frontend_extra_args)
    # The frontend is embedded into the model only for ONNX export or when args.frontend_in_model is set;
    # for ONNX the model additionally returns just the logits of the first head.
    model = getattr(models, args.model)(
        num_input_features=args.num_input_features,
        num_classes=list(map(len, labels)),
        dropout=args.dropout,
        decoder_type='bpe' if args.bpe else None,
        frontend=frontend if args.onnx or args.frontend_in_model else None,
        **(dict(inplace=False,
                dict=lambda logits, log_probs, olen, **kwargs: logits[0])
           if args.onnx else {}))

    _print('Model capacity:',
           int(models.compute_capacity(model, scale=1e6)),
           'million parameters\n')

    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if frontend_checkpoint:
        frontend_checkpoint = {
            'model.' + name: weight
            for name, weight in frontend_checkpoint.items()
        }  ##TODO remove after save checkpoint naming fix
        frontend.load_state_dict(frontend_checkpoint)

    if args.onnx:
        # ONNX export mode: export, run the exported graph on the same random input and compare outputs.
        torch.set_grad_enabled(False)
        model.eval()
        model.to(args.device)
        model.fuse_conv_bn_eval()

        if args.fp16:
            model = models.InputOutputTypeCast(model.to(torch.float16), dtype=torch.float16)

        waveform_input = torch.rand(args.onnx_sample_batch_size,
                                    args.onnx_sample_time,
                                    device=args.device)
        logits = model(waveform_input)

        torch.onnx.export(model, (waveform_input, ),
                          args.onnx,
                          opset_version=args.onnx_opset,
                          export_params=args.onnx_export_params,
                          do_constant_folding=True,
                          input_names=['x'],
                          output_names=['logits'],
                          dynamic_axes=dict(x={
                              0: 'B',
                              1: 'T'
                          },
                                            logits={
                                                0: 'B',
                                                2: 't'
                                            }))
        onnxruntime_session = onnxruntime.InferenceSession(args.onnx)
        if args.verbose:
            onnxruntime.set_default_logger_severity(0)
        (logits_, ) = onnxruntime_session.run(
            None, dict(x=waveform_input.cpu().numpy()))
        assert torch.allclose(logits.cpu(),
                              torch.from_numpy(logits_),
                              rtol=1e-02,
                              atol=1e-03)

        #model_def = onnx.load(args.onnx)
        #import onnx.tools.net_drawer # import GetPydotGraph, GetOpNodeProducer
        #pydot_graph = GetPydotGraph(model_def.graph, name=model_def.graph.name, rankdir="TB", node_producer=GetOpNodeProducer("docstring", color="yellow", fillcolor="yellow", style="filled"))
        #pydot_graph.write_dot("pipeline_transpose2x.dot")
        #os.system('dot -O -Gdpi=300 -Tpng pipeline_transpose2x.dot')
        # add metadata to model
        return

    perf.init_default(loss=dict(K=50, max=1000),
                      memory_cuda_allocated=dict(K=50),
                      entropy=dict(K=4),
                      time_ms_iteration=dict(K=50, max=10_000),
                      lr=dict(K=50, max=1))

    # Optional JSON configs; a missing file yields an empty config.
    val_config = json.load(open(args.val_config)) if os.path.exists(
        args.val_config) else {}
    word_tags = json.load(open(args.word_tags)) if os.path.exists(
        args.word_tags) else {}
    # Word tags from the validation config extend those from the word-tags file.
    for word_tag, words in val_config.get('word_tags', {}).items():
        word_tags[word_tag] = word_tags.get(word_tag, []) + words
    vocab = set(map(str.strip, open(args.vocab))) if os.path.exists(
        args.vocab) else set()
    error_analyzer = metrics.ErrorAnalyzer(
        metrics.WordTagger(lang, vocab=vocab, word_tags=word_tags),
        metrics.ErrorTagger(), val_config.get('error_analyzer', {}))

    # name_args is [transform_name, *ctor_args]. No name -> None; prob None -> built without a probability;
    # prob > 0 -> prob is passed as the first constructor argument; otherwise None.
    make_transform = lambda name_args, prob: None if not name_args else getattr(
        transforms, name_args[0])(*name_args[1:]) if prob is None else getattr(
            transforms, name_args[0])(prob, *name_args[1:]
                                      ) if prob > 0 else None
    val_frontend = models.AugmentationFrontend(
        frontend,
        waveform_transform=make_transform(args.val_waveform_transform,
                                          args.val_waveform_transform_prob),
        feature_transform=make_transform(args.val_feature_transform,
                                         args.val_feature_transform_prob))

    if args.val_waveform_transform_debug_dir:
        # The debug directory gets a sub-directory named after the waveform transform.
        args.val_waveform_transform_debug_dir = os.path.join(
            args.val_waveform_transform_debug_dir,
            str(val_frontend.waveform_transform) if isinstance(
                val_frontend.waveform_transform, transforms.RandomCompose)
            else val_frontend.waveform_transform.__class__.__name__)
        os.makedirs(args.val_waveform_transform_debug_dir, exist_ok=True)

    # One data loader per validation set, keyed by the base name of its path.
    val_data_loaders = {
        os.path.basename(val_data_path): torch.utils.data.DataLoader(
            val_dataset,
            num_workers=args.num_workers,
            collate_fn=val_dataset.collate_fn,
            pin_memory=True,
            shuffle=False,
            batch_size=args.val_batch_size,
            worker_init_fn=datasets.worker_init_fn,
            timeout=args.timeout if args.num_workers > 0 else 0)
        for val_data_path in args.val_data_path for val_dataset in [
            datasets.AudioTextDataset(
                val_data_path,
                labels,
                args.sample_rate,
                frontend=val_frontend if not args.frontend_in_model else None,
                waveform_transform_debug_dir=args.val_waveform_transform_debug_dir,
                min_duration=args.min_duration,
                time_padding_multiple=args.batch_time_padding_multiple,
                pop_meta=True,
                _print=_print)
        ]
    }
    # One decoder per label set: the configured decoder for characters, greedy decoders for the BPE heads.
    decoder = [
        decoders.GreedyDecoder() if args.decoder == 'GreedyDecoder' else
        decoders.BeamSearchDecoder(labels[0],
                                   lm_path=args.lm,
                                   beam_width=args.beam_width,
                                   beam_alpha=args.beam_alpha,
                                   beam_beta=args.beam_beta,
                                   num_workers=args.num_workers,
                                   topk=args.decoder_topk)
    ] + [decoders.GreedyDecoder() for bpe in args.bpe]

    model.to(args.device)

    if not args.train_data_path:
        # Evaluation-only mode.
        model.eval()
        if not args.adapt_bn:
            model.fuse_conv_bn_eval()
        if args.device != 'cpu':
            model, *_ = models.data_parallel_and_autocast(
                model,
                opt_level=args.fp16,
                keep_batchnorm_fp32=args.fp16_keep_batchnorm_fp32)
        evaluate_model(args, val_data_loaders, model, labels, decoder, error_analyzer)
        return

    model.freeze(backbone=args.freeze_backbone,
                 decoder0=args.freeze_decoder,
                 frontend=args.freeze_frontend)

    train_frontend = models.AugmentationFrontend(
        frontend,
        waveform_transform=make_transform(args.train_waveform_transform,
                                          args.train_waveform_transform_prob),
        feature_transform=make_transform(args.train_feature_transform,
                                         args.train_feature_transform_prob))
    tic = time.time()
    # The bucket key is derived from the example duration (end - begin), the window stride and the
    # time padding multiple.
    train_dataset = datasets.AudioTextDataset(
        args.train_data_path,
        labels,
        args.sample_rate,
        frontend=train_frontend if not args.frontend_in_model else None,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        time_padding_multiple=args.batch_time_padding_multiple,
        bucket=lambda example: int(
            math.ceil(((example[0]['end'] - example[0]['begin']) / args.window_stride + 1) / args.batch_time_padding_multiple)),
        pop_meta=True,
        _print=_print)

    _print('Time train dataset created:', time.time() - tic, 'sec')
    train_dataset_name = '_'.join(map(os.path.basename, args.train_data_path))
    tic = time.time()
    sampler = datasets.BucketingBatchSampler(
        train_dataset,
        batch_size=args.train_batch_size,
    )
    _print('Time train sampler created:', time.time() - tic, 'sec')

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=args.num_workers,
        collate_fn=train_dataset.collate_fn,
        pin_memory=True,
        batch_sampler=sampler,
        worker_init_fn=datasets.worker_init_fn,
        timeout=args.timeout if args.num_workers > 0 else 0)
    # Optimizer selected by name; an unknown name yields None.
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov
    ) if args.optimizer == 'SGD' else torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'AdamW' else optimizers.NovoGrad(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'NovoGrad' else apex.optimizers.FusedNovoGrad(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'FusedNovoGrad' else None

    if checkpoint and checkpoint['optimizer_state_dict'] is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # NOTE(review): nesting of this `if` under the state-dict load is reconstructed — confirm.
        if not args.skip_optimizer_reset:
            optimizers.reset_options(optimizer)

    scheduler = optimizers.MultiStepLR(
        optimizer, gamma=args.decay_gamma, milestones=args.decay_milestones
    ) if args.scheduler == 'MultiStepLR' else optimizers.PolynomialDecayLR(
        optimizer,
        power=args.decay_power,
        decay_steps=len(train_data_loader) * args.decay_epochs,
        end_lr=args.decay_lr
    ) if args.scheduler == 'PolynomialDecayLR' else optimizers.NoopLR(
        optimizer)
    epoch, iteration = 0, 0
    if checkpoint:
        epoch, iteration = checkpoint['epoch'], checkpoint['iteration']
        # NOTE(review): the if / if / else nesting below is reconstructed. Reading used here: same training
        # data -> restore the sampler position (and advance the epoch only when the checkpoint fell on an
        # iterations_per_epoch boundary); different training data -> start the next epoch. Confirm.
        if args.train_data_path == checkpoint['args']['train_data_path']:
            sampler.load_state_dict(checkpoint['sampler_state_dict'])
            if args.iterations_per_epoch and iteration and iteration % args.iterations_per_epoch == 0:
                sampler.batch_idx = 0
                epoch += 1
        else:
            epoch += 1
    if args.iterations_per_epoch:
        epoch_skip_fraction = 1 - args.iterations_per_epoch / len(
            train_data_loader)
        assert epoch_skip_fraction < args.max_epoch_skip_fraction, \
            f'args.iterations_per_epoch must not skip more than {args.max_epoch_skip_fraction:.1%} of each epoch'

    if args.device != 'cpu':
        model, optimizer = models.data_parallel_and_autocast(
            model,
            optimizer,
            opt_level=args.fp16,
            keep_batchnorm_fp32=args.fp16_keep_batchnorm_fp32)
    if checkpoint and args.fp16 and checkpoint['amp_state_dict'] is not None:
        apex.amp.load_state_dict(checkpoint['amp_state_dict'])

    model.train()

    tensorboard_dir = os.path.join(args.experiment_dir, 'tensorboard')
    if checkpoint and args.experiment_name:
        # A renamed experiment continued from a checkpoint: copy the tensorboard directory that sits
        # next to the checkpoint, unless the new experiment already has one.
        tensorboard_dir_checkpoint = os.path.join(
            os.path.dirname(args.checkpoint[0]), 'tensorboard')
        if os.path.exists(tensorboard_dir_checkpoint
                          ) and not os.path.exists(tensorboard_dir):
            shutil.copytree(tensorboard_dir_checkpoint, tensorboard_dir)
    tensorboard = torch.utils.tensorboard.SummaryWriter(tensorboard_dir)
    tensorboard_sink = TensorboardSink(tensorboard)

    # Persist the effective arguments and the model configuration in the experiment directory.
    with open(os.path.join(args.experiment_dir, args.args), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, ensure_ascii=False, indent=2)

    with open(os.path.join(args.experiment_dir, args.dump_model_config), 'w') as f:
        model_config = dict(
            init_params=models.master_module(model).init_params,
            model=repr(models.master_module(model)))
        json.dump(model_config, f, sort_keys=True, ensure_ascii=False, indent=2)

    tic, toc_fwd, toc_bwd = time.time(), time.time(), time.time()

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for epoch in range(epoch, args.epochs):
        sampler.shuffle(epoch + args.seed_sampler)
        time_epoch_start = time.time()
        # batch_idx starts from the sampler position, which is non-zero when resuming mid-epoch.
        for batch_idx, (meta, s, x, xlen, y, ylen) in enumerate(train_data_loader, start=sampler.batch_idx):
            toc_data = time.time()
            if batch_idx == 0:
                time_ms_launch_data_loader = (toc_data - tic) * 1000
                _print('Time data loader launch @ ', epoch, ':',
                       time_ms_launch_data_loader / 1000, 'sec')

            lr = optimizer.param_groups[0]['lr']
            perf.update(dict(lr=lr))

            x, xlen, y, ylen = x.to(args.device, non_blocking=True), xlen.to(
                args.device, non_blocking=True), y.to(
                    args.device,
                    non_blocking=True), ylen.to(args.device,
                                                non_blocking=True)
            try:
                #TODO check nan values in tensors, they can break running_stats in bn
                log_probs, olen, loss = map(
                    model(x, xlen, y=y, ylen=ylen).get,
                    ['log_probs', 'olen', 'loss'])
                oom_handler.reset()
            # NOTE(review): bare except — every exception, including KeyboardInterrupt, is offered to the
            # OOM handler; it is re-raised only when try_recover returns a falsy value.
            except:
                if oom_handler.try_recover(model.parameters(), _print=_print):
                    continue
                else:
                    raise
            example_weights = ylen[:, 0]
            # `loss` becomes the weighted mean scaled for gradient accumulation;
            # `loss_cur` is the plain mean of the per-example loss, used for logging.
            loss, loss_cur = (loss * example_weights).mean(
            ) / args.train_batch_accumulate_iterations, float(loss.mean())

            perf.update(dict(loss_BT_normalized=loss_cur))

            entropy = float(
                models.entropy(log_probs[0], olen[0], dim=1).mean())
            toc_fwd = time.time()
            #TODO: inf/nan still corrupts BN stats
            # Batches with an inf/nan loss are skipped: no backward pass and no optimizer step.
            if not (torch.isinf(loss) or torch.isnan(loss)):
                if args.fp16:
                    with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # NOTE(review): the nesting of the following blocks (optimizer step inside the
                # accumulation check, logging inside the step block, entropy update inside the
                # finite-loss block) is reconstructed — confirm against the original file.
                if iteration % args.train_batch_accumulate_iterations == 0:
                    torch.nn.utils.clip_grad_norm_(
                        apex.amp.master_params(optimizer)
                        if args.fp16 else model.parameters(), args.max_norm)
                    optimizer.step()

                    if iteration > 0 and iteration % args.log_iteration_interval == 0:
                        perf.update(utils.compute_memory_stats(), prefix='performance')
                        tensorboard_sink.perf(perf.default(), iteration, train_dataset_name)
                        tensorboard_sink.weight_stats(
                            iteration, model, args.log_weight_distribution)
                        logfile_sink.perf(perf.default(), iteration, train_dataset_name)

                    optimizer.zero_grad()
                    scheduler.step(iteration)
                perf.update(dict(entropy=entropy))
            toc_bwd = time.time()

            # Timings in milliseconds: data loading, forward, backward, and forward + backward.
            time_ms_data, time_ms_fwd, time_ms_bwd, time_ms_model = map(
                lambda sec: sec * 1000, [
                    toc_data - tic, toc_fwd - toc_data, toc_bwd - toc_fwd,
                    toc_bwd - toc_data
                ])
            perf.update(dict(time_ms_data=time_ms_data,
                             time_ms_fwd=time_ms_fwd,
                             time_ms_bwd=time_ms_bwd,
                             time_ms_iteration=time_ms_data + time_ms_model),
                        prefix='performance')
            perf.update(dict(input_B=x.shape[0], input_T=x.shape[-1]),
                        prefix='performance')
            print_left = f'{args.experiment_id} | epoch: {epoch:02d} iter: [{batch_idx: >6d} / {len(train_data_loader)} {iteration: >6d}] {"x".join(map(str, x.shape))}'
            print_right = 'ent: <{avg_entropy:.2f}> loss: {cur_loss_BT_normalized:.2f} <{avg_loss_BT_normalized:.2f}> time: {performance_cur_time_ms_data:.2f}+{performance_cur_time_ms_fwd:4.0f}+{performance_cur_time_ms_bwd:4.0f} <{performance_avg_time_ms_iteration:.0f}> | lr: {cur_lr:.5f}'.format(
                **perf.default())
            _print(print_left, print_right)
            iteration += 1
            sampler.batch_idx += 1

            # Periodic evaluation, and evaluation at the final iteration.
            if iteration > 0 and (iteration % args.val_iteration_interval == 0
                                  or iteration == args.iterations):
                evaluate_model(args, val_data_loaders, model, labels, decoder,
                               error_analyzer, optimizer, sampler,
                               tensorboard_sink, logfile_sink, epoch, iteration)

            # Stop training once the iteration budget is reached.
            if iteration and args.iterations and iteration >= args.iterations:
                return

            # End the epoch early after iterations_per_epoch iterations.
            if args.iterations_per_epoch and iteration > 0 and iteration % args.iterations_per_epoch == 0:
                break

            tic = time.time()

        sampler.batch_idx = 0
        _print('Epoch time', (time.time() - time_epoch_start) / 60, 'minutes')
        if not args.skip_on_epoch_end_evaluation:
            evaluate_model(args, val_data_loaders, model, labels, decoder,
                           error_analyzer, optimizer, sampler,
                           tensorboard_sink, logfile_sink, epoch + 1, iteration)
parser.add_argument('-B', type = int, default = 256) parser.add_argument('-T', type = int, default = 5.12) parser.add_argument('--profile-cuda', action = 'store_true') parser.add_argument('--profile-pyprof', action = 'store_true') parser.add_argument('--profile-autograd') parser.add_argument('--data-parallel', action = 'store_true') parser.add_argument('--backward', action = 'store_true') args = parser.parse_args() checkpoint = torch.load(args.checkpoint, map_location = 'cpu') if args.checkpoint else None if checkpoint: args.model, args.lang, args.sample_rate, args.window_size, args.window_stride, args.window, args.num_input_features = map(checkpoint['args'].get, ['model', 'lang', 'sample_rate', 'window_size', 'window_stride', 'window', 'num_input_features']) use_cuda = 'cuda' in args.device labels = datasets.Labels(datasets.Language(args.lang)) if args.onnx: onnxruntime_session = onnxruntime.InferenceSession(args.onnx) model = lambda x: onnxruntime_session.run(None, dict(x = x)) load_batch = lambda x: x.numpy() else: frontend = models.LogFilterBankFrontend( args.num_input_features, args.sample_rate, args.window_size, args.window_stride, args.window, stft_mode = args.stft_mode ) if args.frontend else None