Пример #1
0
def summary(input_path, lang):
    """Print mean CER/WER of a transcript file and save a CER distribution plot.

    Entries that lack a 'cer' or 'wer' value get it computed from their
    'hyp'/'ref' texts after ``lang.normalize_text``; the computed values are
    stored on the in-memory entries only (the JSON file is not rewritten).

    Args:
        input_path: path to a JSON list of transcript entries.
        lang: language name accepted by ``datasets.Language``.

    Side effects:
        Prints ``CER: x.xx | WER: y.yy`` and writes a two-panel figure
        (CER PDF and CER CDF) to ``input_path + '.png'``.
    """
    lang = datasets.Language(lang)
    with open(input_path) as f:
        transcript = json.load(f)

    for t in transcript:
        # Normalize and score only when a metric is actually missing;
        # dict.get(key, default) would evaluate the (costly) default eagerly.
        if 'cer' in t and 'wer' in t:
            continue
        hyp, ref = map(lang.normalize_text, [t['hyp'], t['ref']])
        if 'cer' not in t:
            t['cer'] = metrics.cer(hyp, ref)
        if 'wer' not in t:
            t['wer'] = metrics.wer(hyp, ref)

    cer_, wer_ = (torch.tensor([t[k] for t in transcript]) for k in ['cer', 'wer'])
    cer_avg, wer_avg = float(cer_.mean()), float(wer_.mean())
    print(f'CER: {cer_avg:.02f} | WER: {wer_avg:.02f}')

    plt.figure(figsize=(8, 4))
    plt.suptitle(os.path.basename(input_path))

    plt.subplot(211)
    plt.title('cer PDF')
    # NOTE: seaborn.distplot is deprecated since seaborn 0.11 (histplot/displot replace it).
    seaborn.distplot(cer_, bins=20, hist=True)
    plt.xlim(0, 1)

    plt.subplot(212)
    plt.title('cer CDF')
    plt.hist(cer_, bins=20, density=True, cumulative=True)
    plt.xlim(0, 1)
    plt.xticks(torch.arange(0, 1.01, 0.1))
    plt.grid(True)

    plt.subplots_adjust(hspace=0.4)
    plt.savefig(input_path + '.png', dpi=150)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()
Пример #2
0
def tabulate(
	experiments_dir,
	experiment_id,
	entropy,
	loss,
	cer10,
	cer15,
	cer20,
	cer30,
	cer40,
	cer50,
	per,
	wer,
	json_,
	bpe,
	der,
	lang
):
	"""Print a tab-separated table of validation metrics per checkpoint iteration.

	Reads every ``transcripts_*.json`` file in ``<experiments_dir>/<experiment_id>``.
	The row key ("iteration") is the file name from 'epoch' onward without
	'.json'; the column key is the text between 'transcripts_' and 'epoch'.
	Each cell is the mean of the selected metric over entries whose
	'labels_name' equals the char label set's name, with NaN/Inf dropped.

	Metric selection (first truthy flag wins): wer, entropy, loss, per, der,
	otherwise cer. If any of cer10/cer20/cer30/cer40/cer50 is set, the cell is
	instead the fraction of values below 0.1/0.2/0.3/0.4/0.5 (first set flag
	wins); cer15 selects the 0.15 threshold when none of those is set.

	Args:
		json_: if true, the last column shows the transcript JSON path instead
			of ``checkpoint_<iteration>.pt``.
		bpe: not used by this function.
	"""
	# TODO: bring back custom name to the filtration process, or remove filtration by labels_name entirely.
	labels = datasets.Labels(lang=datasets.Language(lang), name='char')
	metric = 'wer' if wer else 'entropy' if entropy else 'loss' if loss else 'per' if per else 'der' if der else 'cer'

	# Apply at most one threshold: thresholding the already-binarized 0/1
	# values a second time (e.g. cer10 together with cer15) would invert them.
	if cer10 or cer20 or cer30 or cer40 or cer50:
		threshold = 0.1 * [False, cer10, cer20, cer30, cer40, cer50].index(True)
	elif cer15:
		threshold = 0.15
	else:
		threshold = None

	res = collections.defaultdict(list)
	experiment_dir = os.path.join(experiments_dir, experiment_id)
	for f in sorted(glob.glob(os.path.join(experiment_dir, 'transcripts_*.json'))):
		eidx = f.find('epoch')
		iteration = f[eidx:].replace('.json', '')
		val_dataset_name = f[f.find('transcripts_') + len('transcripts_'):eidx]
		checkpoint = os.path.join(experiment_dir, 'checkpoint_' + f[eidx:].replace('.json', '.pt')) if not json_ else f
		with open(f) as fp:
			entries = json.load(fp)
		# [0.0] keeps the tensor non-empty when no entry matches the label set.
		val = torch.tensor([j[metric] for j in entries if j['labels_name'] == labels.name] or [0.0])
		val = val[~(torch.isnan(val) | torch.isinf(val))]
		if threshold is not None:
			val = (val < threshold).float()
		res[iteration].append((val_dataset_name, float(val.mean()), checkpoint))

	val_dataset_names = sorted({name for r in res.values() for name, _, _ in r})
	print('iteration\t' + '\t'.join(val_dataset_names))
	for iteration, r in res.items():
		cers = {name: f'{value:.04f}' for name, value, _ in r}
		# Missing (iteration, dataset) cells print as empty; last column is the last file's checkpoint.
		print(f'{iteration}\t' + '\t'.join(cers.get(name, '') for name in val_dataset_names) + f'\t{r[-1][-1]}')
Пример #3
0
def normalize(input_path, lang, dry = True):
	"""Normalize 'ref'/'hyp' texts of transcript JSON files and fill missing CER/WER.

	Each present 'ref'/'hyp' is passed through ``lang.normalize_text`` and then
	``labels.postprocess_transcript``. Entries having both texts get 'cer' and
	'wer' computed unless those keys already exist.

	Args:
		input_path: iterable of transcript JSON file paths.
		lang: language name accepted by ``datasets.Language``.
		dry: if true, nothing is written and the normalized transcript of the
			FIRST file is returned immediately (remaining files are not
			processed). If false, every file is rewritten in place and None is
			returned.
	"""
	lang = datasets.Language(lang)
	labels = datasets.Labels(lang)
	for transcript_path in input_path:
		with open(transcript_path) as f:
			transcript = json.load(f)

		for t in transcript:
			for key in ['ref', 'hyp']:
				if key in t:
					t[key] = labels.postprocess_transcript(lang.normalize_text(t[key]))

			if 'ref' in t and 'hyp' in t:
				if 'cer' not in t:
					t['cer'] = metrics.cer(t['hyp'], t['ref'])
				if 'wer' not in t:
					t['wer'] = metrics.wer(t['hyp'], t['ref'])

		if dry:
			# NOTE(review): dry mode previews only the first file by design of the original code.
			return transcript

		# Use a context manager so the rewritten file is flushed and closed deterministically.
		with open(transcript_path, 'w') as f:
			json.dump(transcript, f, ensure_ascii = False, indent = 2, sort_keys = True)
Пример #4
0
def setup(args):
    """Assemble the inference pipeline stored in ``args.checkpoint``.

    Globally disables autograd, copies the frontend hyper-parameters saved in
    the checkpoint onto ``args``, then builds the filter-bank frontend, the
    character label set, the model (weights loaded non-strictly, moved to
    ``args.device``, eval mode, conv/bn fused, and wrapped by
    ``models.data_parallel_and_autocast`` on non-CPU devices) and a decoder.

    Returns:
        Tuple ``(labels, frontend, model, decoder)``.
    """
    torch.set_grad_enabled(False)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    saved_args = checkpoint['args']

    # Feature extraction must use the settings the checkpoint was saved with.
    for key in ['sample_rate', 'window_size', 'window_stride', 'window', 'num_input_features']:
        setattr(args, key, saved_args.get(key))

    frontend = models.LogFilterBankFrontend(
        args.num_input_features,
        args.sample_rate,
        args.window_size,
        args.window_stride,
        args.window,
        eps=1e-6,
    )
    language = datasets.Language(saved_args['lang'])
    labels = datasets.Labels(language, name='char')

    def select_outputs(logits, log_probs, olen, **kwargs):
        # Keep only the first head's logits and output lengths.
        return (logits[0], olen[0])

    model_class = getattr(models, args.model or saved_args['model'])
    model = model_class(
        args.num_input_features,
        [len(labels)],
        frontend=frontend,
        dict=select_outputs,
    )
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model = model.to(args.device)
    model.eval()
    model.fuse_conv_bn_eval()
    if args.device != 'cpu':
        model, *_ = models.data_parallel_and_autocast(model, opt_level=args.fp16)

    if args.decoder == 'GreedyDecoder':
        decoder = decoders.GreedyDecoder()
    else:
        decoder = decoders.BeamSearchDecoder(
            labels,
            lm_path=args.lm,
            beam_width=args.beam_width,
            beam_alpha=args.beam_alpha,
            beam_beta=args.beam_beta,
            num_workers=args.num_workers,
            topk=args.decoder_topk,
        )
    return labels, frontend, model, decoder
Пример #5
0
def lserrorwords(input_path, output_path, comment_path, freq_path, sortdesc, sortasc, comment_filter, lang):
	"""Export reference words that were recognized with errors, with statistics.

	Each output row is (word, diff, err, ok, freq, audio_name, usage, comment):
	err/ok are counts grouped by the (stem, length) key of the word, diff is
	err - ok, freq comes from ``freq_path``, audio_name/usage are the
	'audio_name'/'ref' of the first utterance containing the word, and comment
	comes from ``comment_path``.

	Args:
		input_path: JSON transcript list; each utterance has a 'words' list whose
			items carry 'ref' and a tag in 'type' or 'error_tag'.
		output_path: '.csv' writes the rows as a table, '.json' writes
			{audio_name, before, after=''} stubs; any other extension leaves an
			empty file.
		comment_path: optional comma-separated file; first field is the word,
			last field the comment; lines containing '#' are skipped.
		freq_path: optional file of whitespace-separated 'word ... count' lines.
		sortdesc, sortasc: sort key name; 'diff' sorts by the diff column,
			anything else by total err+ok count. sortdesc wins and sorts
			descending.
		comment_filter: keep only rows whose comment contains this substring;
			None or '' keeps every row.
		lang: language name accepted by ``datasets.Language``.
	"""
	lang = datasets.Language(lang)
	# Collapse a hyphen preceded by one or more spaces (and optional trailing spaces) into '-'.
	dash_pattern, dash_repl = r'[ ]+-[ ]*', '-'

	freq = {}
	if freq_path:
		with open(freq_path) as fp:
			for line in fp:
				splitted = re.sub(dash_pattern, dash_repl, line).split()
				if splitted:
					freq[splitted[0]] = int(splitted[-1])

	comment = {}
	if comment_path:
		with open(comment_path) as fp:
			for line in fp:
				splitted = line.split(',')
				if '#' not in line and len(splitted) > 1:
					comment[splitted[0]] = splitted[-1].strip()

	with open(input_path) as fp:
		transcript = json.load(fp)

	def word_tag(w):
		return w.get('type') or w.get('error_tag')

	def clean(w):
		return w['ref'].replace(metrics.placeholder, '')

	def stem(word):
		return (lang.stem(word), len(word))

	# Drop utterances with more than two words tagged 'missing_ref'.
	transcript = [t for t in transcript if [word_tag(w) for w in t['words']].count('missing_ref') <= 2]

	words_ok = [clean(w) for t in transcript for w in t['words'] if word_tag(w) == 'ok']
	# Distinct error word forms longer than one character.
	words_error = {
		ref
		for ref in (clean(w) for t in transcript for w in t['words'] if word_tag(w) not in ['ok', 'missing_ref'])
		if len(ref) > 1
	}

	# First utterance (in transcript order) in which each cleaned word occurs.
	first_usage = {}
	for t in transcript:
		for w in t['words']:
			first_usage.setdefault(clean(w), t)

	# NOTE: ok counts every 'ok' occurrence, while err counts distinct error word forms.
	words_ok_counter = collections.Counter(map(stem, words_ok))
	words_error_counter = collections.Counter(map(stem, words_error))

	rows = []
	for ref in words_error:
		stem_key = stem(ref)
		err, ok = words_error_counter[stem_key], words_ok_counter[stem_key]
		example = first_usage.get(ref, {})
		rows.append((
			ref,
			err - ok,
			err,
			ok,
			freq.get(ref, 0),
			example.get('audio_name', ''),
			example.get('ref', ''),
			comment.get(ref, '')
		))

	key = sortdesc or sortasc
	if key == 'diff':
		# Column 1 is diff (the old t[-5] pointed at the 'ok' column of this 8-tuple).
		sort_key = lambda row: (row[1], row[0])
	else:
		sort_key = lambda row: ((-row[2] - row[3], row[5]), row[0])
	rows.sort(key = sort_key, reverse = bool(sortdesc))

	if comment_filter:
		rows = [row for row in rows if comment_filter in row[-1]]

	with open(output_path, 'w') as f:
		if output_path.endswith('.csv'):
			# NOTE(review): fields are not CSV-quoted; commas inside words/comments will shift columns.
			f.write('#word,diff,err,ok,freq,audioname,usage,comment\n' + '\n'.join(','.join(map(str, row)) for row in rows))
		elif output_path.endswith('.json'):
			json.dump(
				[dict(audio_name = row[5], before = row[0], after = '') for row in rows],
				f,
				ensure_ascii = False,
				indent = 2,
				sort_keys = True
			)

	print(output_path)
Пример #6
0
def logits(lang, logits, audio_name = None, MAX_ENTROPY = 1.0):
	"""Render per-utterance logit plots into a self-contained HTML report.

	Args:
		lang: language name accepted by ``datasets.Language``.
		logits: path of a ``torch.load``-able list of dicts with 'audio_path'
			and 'logits', optionally 'words', 'y', 'begin', 'end', 'cer'.
		audio_name: optional filter. If its first element is an existing file,
			names are read from it (one per line); otherwise the list itself
			holds the names to keep. None or an empty list keeps everything.
		MAX_ENTROPY: frames whose entropy exceeds this value are shaded red.

	Returns:
		Path of the written HTML file, ``logits + '.html'``.
	"""
	good_audio_name = []
	if audio_name:
		if os.path.exists(audio_name[0]):
			with open(audio_name[0]) as f:
				good_audio_name = set(map(str.strip, f))
		else:
			good_audio_name = set(audio_name)

	labels = datasets.Labels(datasets.Language(lang))
	decoder = decoders.GreedyDecoder()

	def tick_params(ax, labelsize = 2.5, length = 0, **kwargs):
		# Tiny tick labels, no tick marks, invisible spines.
		ax.tick_params(axis = 'both', which = 'both', labelsize = labelsize, length = length, **kwargs)
		for spine in ax.spines.values():
			spine.set_linewidth(0)

	logits_path = logits + '.html'
	with open(logits_path, 'w') as html:
		html.write('<html><head>' + meta_charset + f'</head><body><script>{play_script}{onclick_img_script}</script>')

		for i, t in enumerate(torch.load(logits)):
			audio_path = t['audio_path']
			name = transcripts.audio_name(audio_path)
			if good_audio_name and name not in good_audio_name:
				continue

			utt_logits = t['logits']
			words = t.get('words', [t])
			y = t.get('y', torch.zeros(1, 0, dtype = torch.long))
			begin = t.get('begin', '')
			end = t.get('end', '')
			extra_metrics = dict(cer = t['cer']) if 'cer' in t else {}

			# Softmax over dim 0 (classes); the last class is treated as the CTC blank below.
			log_probs = F.log_softmax(utt_logits, dim = 0)
			entropy = models.entropy(log_probs, dim = 0, sum = False)
			log_probs_no_blank = F.log_softmax(utt_logits[:-1], dim = 0)
			entropy_no_blank = models.entropy(log_probs_no_blank, dim = 0, sum = False)

			plt.figure(figsize = (6, 2))
			ax = plt.subplot(211)
			plt.imshow(utt_logits, aspect = 'auto')
			plt.xlim(0, utt_logits.shape[-1] - 1)
			plt.axis('off')
			tick_params(plt.gca())

			plt.subplot(212, sharex = ax)
			prob_top1, prob_top2 = log_probs.exp().topk(2, dim = 0).values
			plt.hlines(1.0, 0, entropy.shape[-1] - 1, linewidth = 0.2)
			artist_prob_top1, = plt.plot(prob_top1, 'b', linewidth = 0.3)
			artist_prob_top2, = plt.plot(prob_top2, 'g', linewidth = 0.3)
			artist_entropy, = plt.plot(entropy, 'r', linewidth = 0.3)
			artist_entropy_no_blank, = plt.plot(entropy_no_blank, 'yellow', linewidth = 0.3)
			plt.legend(
				[artist_entropy, artist_entropy_no_blank, artist_prob_top1, artist_prob_top2],
				['entropy', 'entropy, no blank', 'top1 prob', 'top2 prob'],
				loc = 1,
				fontsize = 'xx-small',
				frameon = False
			)

			# Shade runs of frames whose entropy exceeds MAX_ENTROPY.
			for b, e, v in zip(*models.rle1d(entropy > MAX_ENTROPY)):
				if bool(v):
					plt.axvspan(int(b), int(e), color = 'red', alpha = 0.2)

			plt.ylim(0, 3.0)
			plt.xlim(0, entropy.shape[-1] - 1)

			# Each x tick label stacks, one per line, the characters the K = 5
			# decoder outputs have at that position.
			decoded = decoder.decode(log_probs.unsqueeze(0), K = 5)[0]
			hypotheses = [
				labels.decode(d, replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False)
				for d in decoded
			]
			xlabels = ['\n'.join(chars) for chars in zip(*hypotheses)]
			plt.xticks(torch.arange(entropy.shape[-1]), xlabels, fontfamily = 'monospace')
			tick_params(plt.gca())

			if y.numel() > 0:
				# Mark where each reference token aligns, on a secondary top axis.
				alignment = ctc.alignment(
					log_probs.unsqueeze(0).permute(2, 0, 1),
					y.unsqueeze(0).long(),
					torch.LongTensor([log_probs.shape[-1]]),
					torch.LongTensor([len(y)]),
					blank = len(log_probs) - 1
				).squeeze(0)

				ax = plt.gca().secondary_xaxis('top')
				ref = labels.decode(y.tolist(), replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False)
				ax.set_xticklabels(ref)
				ax.set_xticks(alignment)
				tick_params(ax, colors = 'red')

			plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)

			buf = io.BytesIO()
			plt.savefig(buf, format = 'jpg', dpi = 600)
			plt.close()

			html.write(f'<h4>{name}')
			html.write(''.join(f' | {k}: {v:.02f}' for k, v in extra_metrics.items()))
			html.write('</h4>')
			html.write(fmt_alignment(words))
			html.write('<img data-begin="{begin}" data-end="{end}" data-channel="{channel}" onclick="onclick_img(event)" style="width:100%" src="data:image/jpeg;base64,{encoded}"></img>\n'.format(channel = i, begin = begin, end = end, encoded = base64.b64encode(buf.getvalue()).decode()))
			html.write(fmt_audio(audio_path = audio_path, channel = i))
			html.write('<hr/>')
		html.write('</body></html>')
	return logits_path
Пример #7
0
def main(args):
    checkpoints = [
        torch.load(checkpoint_path, map_location='cpu')
        for checkpoint_path in args.checkpoint
    ]
    checkpoint = (checkpoints + [{}])[0]
    if len(checkpoints) > 1:
        checkpoint['model_state_dict'] = {
            k: sum(c['model_state_dict'][k]
                   for c in checkpoints) / len(checkpoints)
            for k in checkpoint['model_state_dict']
        }

    if args.frontend_checkpoint:
        frontend_checkpoint = torch.load(args.frontend_checkpoint,
                                         map_location='cpu')
        frontend_extra_args = frontend_checkpoint['args']
        frontend_checkpoint = frontend_checkpoint['model']
    else:
        frontend_extra_args = None
        frontend_checkpoint = None

    args.experiment_id = args.experiment_id.format(
        model=args.model,
        frontend=args.frontend,
        train_batch_size=args.train_batch_size,
        optimizer=args.optimizer,
        lr=args.lr,
        weight_decay=args.weight_decay,
        time=time.strftime('%Y-%m-%d_%H-%M-%S'),
        experiment_name=args.experiment_name,
        bpe='bpe' if args.bpe else '',
        train_waveform_transform=
        f'aug{args.train_waveform_transform[0]}{args.train_waveform_transform_prob or ""}'
        if args.train_waveform_transform else '',
        train_feature_transform=
        f'aug{args.train_feature_transform[0]}{args.train_feature_transform_prob or ""}'
        if args.train_feature_transform else '').replace('e-0',
                                                         'e-').rstrip('_')
    if checkpoint and 'experiment_id' in checkpoint[
            'args'] and not args.experiment_name:
        args.experiment_id = checkpoint['args']['experiment_id']
    args.experiment_dir = args.experiment_dir.format(
        experiments_dir=args.experiments_dir, experiment_id=args.experiment_id)

    os.makedirs(args.experiment_dir, exist_ok=True)

    if args.log_json:
        args.log_json = os.path.join(args.experiment_dir, 'log.json')

    if checkpoint:
        args.lang, args.model, args.num_input_features, args.sample_rate, args.window, args.window_size, args.window_stride = map(
            checkpoint['args'].get, [
                'lang', 'model', 'num_input_features', 'sample_rate', 'window',
                'window_size', 'window_stride'
            ])
        utils.set_up_root_logger(os.path.join(args.experiment_dir, 'log.txt'),
                                 mode='a')
        logfile_sink = JsonlistSink(args.log_json, mode='a')
    else:
        utils.set_up_root_logger(os.path.join(args.experiment_dir, 'log.txt'),
                                 mode='w')
        logfile_sink = JsonlistSink(args.log_json, mode='w')

    _print = utils.get_root_logger_print()
    _print('\n', 'Arguments:', args)
    _print(
        f'"CUDA_VISIBLE_DEVICES={os.environ.get("CUDA_VISIBLE_DEVICES", default = "")}"'
    )
    _print(
        f'"CUDA_LAUNCH_BLOCKING={os.environ.get("CUDA_LAUNCH_BLOCKING", default="")}"'
    )
    _print('Experiment id:', args.experiment_id, '\n')
    if args.dry:
        return
    utils.set_random_seed(args.seed)
    if args.cudnn == 'benchmark':
        torch.backends.cudnn.benchmark = True

    lang = datasets.Language(args.lang)
    #TODO: , candidate_sep = datasets.Labels.candidate_sep
    normalize_text_config = json.load(open(
        args.normalize_text_config)) if os.path.exists(
            args.normalize_text_config) else {}
    labels = [
        datasets.Labels(
            lang, name='char', normalize_text_config=normalize_text_config)
    ] + [
        datasets.Labels(lang,
                        bpe=bpe,
                        name=f'bpe{i}',
                        normalize_text_config=normalize_text_config)
        for i, bpe in enumerate(args.bpe)
    ]
    frontend = getattr(models,
                       args.frontend)(out_channels=args.num_input_features,
                                      sample_rate=args.sample_rate,
                                      window_size=args.window_size,
                                      window_stride=args.window_stride,
                                      window=args.window,
                                      dither=args.dither,
                                      dither0=args.dither0,
                                      stft_mode='conv' if args.onnx else None,
                                      extra_args=frontend_extra_args)
    model = getattr(models, args.model)(
        num_input_features=args.num_input_features,
        num_classes=list(map(len, labels)),
        dropout=args.dropout,
        decoder_type='bpe' if args.bpe else None,
        frontend=frontend if args.onnx or args.frontend_in_model else None,
        **(dict(inplace=False,
                dict=lambda logits, log_probs, olen, **kwargs: logits[0])
           if args.onnx else {}))

    _print('Model capacity:', int(models.compute_capacity(model, scale=1e6)),
           'million parameters\n')

    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if frontend_checkpoint:
        frontend_checkpoint = {
            'model.' + name: weight
            for name, weight in frontend_checkpoint.items()
        }  ##TODO remove after save checkpoint naming fix
        frontend.load_state_dict(frontend_checkpoint)

    if args.onnx:
        torch.set_grad_enabled(False)
        model.eval()
        model.to(args.device)
        model.fuse_conv_bn_eval()

        if args.fp16:
            model = models.InputOutputTypeCast(model.to(torch.float16),
                                               dtype=torch.float16)

        waveform_input = torch.rand(args.onnx_sample_batch_size,
                                    args.onnx_sample_time,
                                    device=args.device)
        logits = model(waveform_input)

        torch.onnx.export(model, (waveform_input, ),
                          args.onnx,
                          opset_version=args.onnx_opset,
                          export_params=args.onnx_export_params,
                          do_constant_folding=True,
                          input_names=['x'],
                          output_names=['logits'],
                          dynamic_axes=dict(x={
                              0: 'B',
                              1: 'T'
                          },
                                            logits={
                                                0: 'B',
                                                2: 't'
                                            }))
        onnxruntime_session = onnxruntime.InferenceSession(args.onnx)
        if args.verbose:
            onnxruntime.set_default_logger_severity(0)
        (logits_, ) = onnxruntime_session.run(
            None, dict(x=waveform_input.cpu().numpy()))
        assert torch.allclose(logits.cpu(),
                              torch.from_numpy(logits_),
                              rtol=1e-02,
                              atol=1e-03)

        #model_def = onnx.load(args.onnx)
        #import onnx.tools.net_drawer # import GetPydotGraph, GetOpNodeProducer
        #pydot_graph = GetPydotGraph(model_def.graph, name=model_def.graph.name, rankdir="TB", node_producer=GetOpNodeProducer("docstring", color="yellow", fillcolor="yellow", style="filled"))
        #pydot_graph.write_dot("pipeline_transpose2x.dot")
        #os.system('dot -O -Gdpi=300 -Tpng pipeline_transpose2x.dot')
        # add metadata to model
        return

    perf.init_default(loss=dict(K=50, max=1000),
                      memory_cuda_allocated=dict(K=50),
                      entropy=dict(K=4),
                      time_ms_iteration=dict(K=50, max=10_000),
                      lr=dict(K=50, max=1))

    val_config = json.load(open(args.val_config)) if os.path.exists(
        args.val_config) else {}
    word_tags = json.load(open(args.word_tags)) if os.path.exists(
        args.word_tags) else {}
    for word_tag, words in val_config.get('word_tags', {}).items():
        word_tags[word_tag] = word_tags.get(word_tag, []) + words
    vocab = set(map(str.strip, open(args.vocab))) if os.path.exists(
        args.vocab) else set()
    error_analyzer = metrics.ErrorAnalyzer(
        metrics.WordTagger(lang, vocab=vocab, word_tags=word_tags),
        metrics.ErrorTagger(), val_config.get('error_analyzer', {}))

    make_transform = lambda name_args, prob: None if not name_args else getattr(
        transforms, name_args[0])(*name_args[1:]) if prob is None else getattr(
            transforms, name_args[0])(prob, *name_args[1:]
                                      ) if prob > 0 else None
    val_frontend = models.AugmentationFrontend(
        frontend,
        waveform_transform=make_transform(args.val_waveform_transform,
                                          args.val_waveform_transform_prob),
        feature_transform=make_transform(args.val_feature_transform,
                                         args.val_feature_transform_prob))

    if args.val_waveform_transform_debug_dir:
        args.val_waveform_transform_debug_dir = os.path.join(
            args.val_waveform_transform_debug_dir,
            str(val_frontend.waveform_transform) if isinstance(
                val_frontend.waveform_transform, transforms.RandomCompose) else
            val_frontend.waveform_transform.__class__.__name__)
        os.makedirs(args.val_waveform_transform_debug_dir, exist_ok=True)

    val_data_loaders = {
        os.path.basename(val_data_path): torch.utils.data.DataLoader(
            val_dataset,
            num_workers=args.num_workers,
            collate_fn=val_dataset.collate_fn,
            pin_memory=True,
            shuffle=False,
            batch_size=args.val_batch_size,
            worker_init_fn=datasets.worker_init_fn,
            timeout=args.timeout if args.num_workers > 0 else 0)
        for val_data_path in args.val_data_path for val_dataset in [
            datasets.AudioTextDataset(
                val_data_path,
                labels,
                args.sample_rate,
                frontend=val_frontend if not args.frontend_in_model else None,
                waveform_transform_debug_dir=args.
                val_waveform_transform_debug_dir,
                min_duration=args.min_duration,
                time_padding_multiple=args.batch_time_padding_multiple,
                pop_meta=True,
                _print=_print)
        ]
    }
    decoder = [
        decoders.GreedyDecoder() if args.decoder == 'GreedyDecoder' else
        decoders.BeamSearchDecoder(labels[0],
                                   lm_path=args.lm,
                                   beam_width=args.beam_width,
                                   beam_alpha=args.beam_alpha,
                                   beam_beta=args.beam_beta,
                                   num_workers=args.num_workers,
                                   topk=args.decoder_topk)
    ] + [decoders.GreedyDecoder() for bpe in args.bpe]

    model.to(args.device)

    if not args.train_data_path:
        model.eval()
        if not args.adapt_bn:
            model.fuse_conv_bn_eval()
        if args.device != 'cpu':
            model, *_ = models.data_parallel_and_autocast(
                model,
                opt_level=args.fp16,
                keep_batchnorm_fp32=args.fp16_keep_batchnorm_fp32)
        evaluate_model(args, val_data_loaders, model, labels, decoder,
                       error_analyzer)
        return

    model.freeze(backbone=args.freeze_backbone,
                 decoder0=args.freeze_decoder,
                 frontend=args.freeze_frontend)

    train_frontend = models.AugmentationFrontend(
        frontend,
        waveform_transform=make_transform(args.train_waveform_transform,
                                          args.train_waveform_transform_prob),
        feature_transform=make_transform(args.train_feature_transform,
                                         args.train_feature_transform_prob))
    tic = time.time()
    train_dataset = datasets.AudioTextDataset(
        args.train_data_path,
        labels,
        args.sample_rate,
        frontend=train_frontend if not args.frontend_in_model else None,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        time_padding_multiple=args.batch_time_padding_multiple,
        bucket=lambda example: int(
            math.ceil(((example[0]['end'] - example[0]['begin']) / args.
                       window_stride + 1) / args.batch_time_padding_multiple)),
        pop_meta=True,
        _print=_print)

    _print('Time train dataset created:', time.time() - tic, 'sec')
    train_dataset_name = '_'.join(map(os.path.basename, args.train_data_path))
    tic = time.time()
    sampler = datasets.BucketingBatchSampler(
        train_dataset,
        batch_size=args.train_batch_size,
    )
    _print('Time train sampler created:', time.time() - tic, 'sec')

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=args.num_workers,
        collate_fn=train_dataset.collate_fn,
        pin_memory=True,
        batch_sampler=sampler,
        worker_init_fn=datasets.worker_init_fn,
        timeout=args.timeout if args.num_workers > 0 else 0)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov
    ) if args.optimizer == 'SGD' else torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'AdamW' else optimizers.NovoGrad(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'NovoGrad' else apex.optimizers.FusedNovoGrad(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'FusedNovoGrad' else None

    if checkpoint and checkpoint['optimizer_state_dict'] is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if not args.skip_optimizer_reset:
            optimizers.reset_options(optimizer)

    scheduler = optimizers.MultiStepLR(
        optimizer, gamma=args.decay_gamma, milestones=args.decay_milestones
    ) if args.scheduler == 'MultiStepLR' else optimizers.PolynomialDecayLR(
        optimizer,
        power=args.decay_power,
        decay_steps=len(train_data_loader) * args.decay_epochs,
        end_lr=args.decay_lr
    ) if args.scheduler == 'PolynomialDecayLR' else optimizers.NoopLR(
        optimizer)
    epoch, iteration = 0, 0
    if checkpoint:
        epoch, iteration = checkpoint['epoch'], checkpoint['iteration']
        if args.train_data_path == checkpoint['args']['train_data_path']:
            sampler.load_state_dict(checkpoint['sampler_state_dict'])
            if args.iterations_per_epoch and iteration and iteration % args.iterations_per_epoch == 0:
                sampler.batch_idx = 0
                epoch += 1
        else:
            epoch += 1
    if args.iterations_per_epoch:
        epoch_skip_fraction = 1 - args.iterations_per_epoch / len(
            train_data_loader)
        assert epoch_skip_fraction < args.max_epoch_skip_fraction, \
         f'args.iterations_per_epoch must not skip more than {args.max_epoch_skip_fraction:.1%} of each epoch'

    if args.device != 'cpu':
        model, optimizer = models.data_parallel_and_autocast(
            model,
            optimizer,
            opt_level=args.fp16,
            keep_batchnorm_fp32=args.fp16_keep_batchnorm_fp32)
    if checkpoint and args.fp16 and checkpoint['amp_state_dict'] is not None:
        apex.amp.load_state_dict(checkpoint['amp_state_dict'])

    model.train()

    tensorboard_dir = os.path.join(args.experiment_dir, 'tensorboard')
    if checkpoint and args.experiment_name:
        tensorboard_dir_checkpoint = os.path.join(
            os.path.dirname(args.checkpoint[0]), 'tensorboard')
        if os.path.exists(tensorboard_dir_checkpoint
                          ) and not os.path.exists(tensorboard_dir):
            shutil.copytree(tensorboard_dir_checkpoint, tensorboard_dir)
    tensorboard = torch.utils.tensorboard.SummaryWriter(tensorboard_dir)
    tensorboard_sink = TensorboardSink(tensorboard)

    with open(os.path.join(args.experiment_dir, args.args), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, ensure_ascii=False, indent=2)

    with open(os.path.join(args.experiment_dir, args.dump_model_config),
              'w') as f:
        model_config = dict(
            init_params=models.master_module(model).init_params,
            model=repr(models.master_module(model)))
        json.dump(model_config,
                  f,
                  sort_keys=True,
                  ensure_ascii=False,
                  indent=2)

    tic, toc_fwd, toc_bwd = time.time(), time.time(), time.time()

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for epoch in range(epoch, args.epochs):
        sampler.shuffle(epoch + args.seed_sampler)
        time_epoch_start = time.time()
        for batch_idx, (meta, s, x, xlen, y,
                        ylen) in enumerate(train_data_loader,
                                           start=sampler.batch_idx):
            toc_data = time.time()
            if batch_idx == 0:
                time_ms_launch_data_loader = (toc_data - tic) * 1000
                _print('Time data loader launch @ ', epoch, ':',
                       time_ms_launch_data_loader / 1000, 'sec')

            lr = optimizer.param_groups[0]['lr']
            perf.update(dict(lr=lr))

            x, xlen, y, ylen = x.to(args.device, non_blocking=True), xlen.to(
                args.device, non_blocking=True), y.to(
                    args.device, non_blocking=True), ylen.to(args.device,
                                                             non_blocking=True)
            try:
                #TODO check nan values in tensors, they can break running_stats in bn
                log_probs, olen, loss = map(
                    model(x, xlen, y=y, ylen=ylen).get,
                    ['log_probs', 'olen', 'loss'])
                oom_handler.reset()
            except:
                if oom_handler.try_recover(model.parameters(), _print=_print):
                    continue
                else:
                    raise
            example_weights = ylen[:, 0]
            loss, loss_cur = (loss * example_weights).mean(
            ) / args.train_batch_accumulate_iterations, float(loss.mean())

            perf.update(dict(loss_BT_normalized=loss_cur))

            entropy = float(
                models.entropy(log_probs[0], olen[0], dim=1).mean())
            toc_fwd = time.time()
            #TODO: inf/nan still corrupts BN stats
            if not (torch.isinf(loss) or torch.isnan(loss)):
                if args.fp16:
                    with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if iteration % args.train_batch_accumulate_iterations == 0:
                    torch.nn.utils.clip_grad_norm_(
                        apex.amp.master_params(optimizer)
                        if args.fp16 else model.parameters(), args.max_norm)
                    optimizer.step()

                    if iteration > 0 and iteration % args.log_iteration_interval == 0:
                        perf.update(utils.compute_memory_stats(),
                                    prefix='performance')
                        tensorboard_sink.perf(perf.default(), iteration,
                                              train_dataset_name)
                        tensorboard_sink.weight_stats(
                            iteration, model, args.log_weight_distribution)
                        logfile_sink.perf(perf.default(), iteration,
                                          train_dataset_name)

                    optimizer.zero_grad()
                    scheduler.step(iteration)
                perf.update(dict(entropy=entropy))
            toc_bwd = time.time()

            time_ms_data, time_ms_fwd, time_ms_bwd, time_ms_model = map(
                lambda sec: sec * 1000, [
                    toc_data - tic, toc_fwd - toc_data, toc_bwd - toc_fwd,
                    toc_bwd - toc_data
                ])
            perf.update(dict(time_ms_data=time_ms_data,
                             time_ms_fwd=time_ms_fwd,
                             time_ms_bwd=time_ms_bwd,
                             time_ms_iteration=time_ms_data + time_ms_model),
                        prefix='performance')
            perf.update(dict(input_B=x.shape[0], input_T=x.shape[-1]),
                        prefix='performance')
            print_left = f'{args.experiment_id} | epoch: {epoch:02d} iter: [{batch_idx: >6d} / {len(train_data_loader)} {iteration: >6d}] {"x".join(map(str, x.shape))}'
            print_right = 'ent: <{avg_entropy:.2f}> loss: {cur_loss_BT_normalized:.2f} <{avg_loss_BT_normalized:.2f}> time: {performance_cur_time_ms_data:.2f}+{performance_cur_time_ms_fwd:4.0f}+{performance_cur_time_ms_bwd:4.0f} <{performance_avg_time_ms_iteration:.0f}> | lr: {cur_lr:.5f}'.format(
                **perf.default())
            _print(print_left, print_right)
            iteration += 1
            sampler.batch_idx += 1

            if iteration > 0 and (iteration % args.val_iteration_interval == 0
                                  or iteration == args.iterations):
                evaluate_model(args, val_data_loaders, model, labels, decoder,
                               error_analyzer, optimizer, sampler,
                               tensorboard_sink, logfile_sink, epoch,
                               iteration)

            if iteration and args.iterations and iteration >= args.iterations:
                return

            if args.iterations_per_epoch and iteration > 0 and iteration % args.iterations_per_epoch == 0:
                break

            tic = time.time()

        sampler.batch_idx = 0
        _print('Epoch time', (time.time() - time_epoch_start) / 60, 'minutes')
        if not args.skip_on_epoch_end_evaluation:
            evaluate_model(args, val_data_loaders, model, labels, decoder,
                           error_analyzer, optimizer, sampler,
                           tensorboard_sink, logfile_sink, epoch + 1,
                           iteration)
Пример #8
0
# Benchmark-specific command-line options; `parser` is created earlier in the script.
# -B: presumably the benchmark batch size — TODO confirm against the code that consumes args.B.
parser.add_argument('-B', type = int, default = 256)
# -T: presumably the input duration (seconds?) per example — TODO confirm units.
# Must be float: the default is 5.12, and argparse only applies `type` to string
# values, so with `type = int` the default worked but `-T 5.12` on the CLI was rejected.
parser.add_argument('-T', type = float, default = 5.12)
parser.add_argument('--profile-cuda', action = 'store_true')
parser.add_argument('--profile-pyprof', action = 'store_true')
parser.add_argument('--profile-autograd')
parser.add_argument('--data-parallel', action = 'store_true')
parser.add_argument('--backward', action = 'store_true')
args = parser.parse_args()

# Optionally load a checkpoint; it is always mapped onto the CPU at load time.
checkpoint = None
if args.checkpoint:
	checkpoint = torch.load(args.checkpoint, map_location = 'cpu')

if checkpoint:
	# Settings saved in the checkpoint's 'args' override the command-line values.
	# They are read with `.get`, so (assuming 'args' is a dict) absent keys become None.
	(
		args.model,
		args.lang,
		args.sample_rate,
		args.window_size,
		args.window_stride,
		args.window,
		args.num_input_features,
	) = (checkpoint['args'].get(name) for name in ('model', 'lang', 'sample_rate', 'window_size', 'window_stride', 'window', 'num_input_features'))

# Substring test on the device string, e.g. 'cuda' or 'cuda:0'.
use_cuda = 'cuda' in args.device

labels = datasets.Labels(datasets.Language(args.lang))

if args.onnx:
	onnxruntime_session = onnxruntime.InferenceSession(args.onnx)
	model = lambda x: onnxruntime_session.run(None, dict(x = x))
	load_batch = lambda x: x.numpy()

else:
	frontend = models.LogFilterBankFrontend(
		args.num_input_features,
		args.sample_rate,
		args.window_size,
		args.window_stride,
		args.window,
		stft_mode = args.stft_mode
	) if args.frontend else None