# Example #1
def main(args, ext_json=('.json', '.json.gz')):
    """Transcribe audio files with a trained ASR model and write the results.

    Collects input audio / transcript-JSON paths from ``args.input_path``,
    runs the model over every batch produced by the dataset, optionally
    force-aligns the reference transcript, and writes the transcripts in any
    combination of JSON / HTML / TXT / CSV (plus raw logits) into
    ``args.output_path``.

    Args:
        args: parsed CLI namespace (input/output paths, model and frontend
            settings, output-format flags, pruning options, ...).
        ext_json: recognized transcript-file suffixes; a JSON input is used
            directly unless a same-named audio file was also provided.
            (Tuple default avoids the shared-mutable-default pitfall; it is
            only iterated / passed to ``endswith``-style checks.)
    """
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
     'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)

    # Expand directories one level deep; keep only files with a known audio extension.
    audio_data_paths = set(
        p for f in args.input_path
        for p in ([os.path.join(f, g)
                   for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext)))
    # Transcript JSONs are taken as inputs only when the matching audio file
    # was not already picked up above.
    json_data_paths = set(
        p for p in args.input_path if any(map(p.endswith, ext_json))
        and not utils.strip_suffixes(p, ext_json) in audio_data_paths)

    data_paths = list(audio_data_paths | json_data_paths)

    # Extension-less names of already-produced output JSONs, for resuming runs.
    exclude = {
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    } if args.skip_processed else None

    # NOTE(review): `exclude` holds extension-less names while
    # os.path.basename(path) keeps the extension, so this path-level filter
    # may never match; the dataset-level `exclude=` below might be what
    # actually implements skipping -- verify.
    data_paths = [
        path for path in data_paths
        if exclude is None or os.path.basename(path) not in exclude
    ]

    text_pipeline, frontend, model, generator = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [text_pipeline],
        args.sample_rate,
        frontend=frontend if not args.frontend_in_model else None,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        mode='batched_channels'
        if args.join_transcript else 'batched_transcript',
        string_array_encoding=args.dataset_string_array_encoding,
        debug_short_long_records_features_from_whole_normalized_signal=args.
        debug_short_long_records_features_from_whole_normalized_signal)
    print('Examples count: ', len(val_dataset))
    val_meta = val_dataset.pop_meta()
    # batch_size=None: the dataset already yields pre-collated batches.
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    csv_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{len(val_dataset)}')
        # Re-attach full metadata that was popped off the dataset earlier.
        meta = [val_meta[t['example_id']] for t in meta]

        audio_path = meta[0]['audio_path']
        audio_name = transcripts.audio_name(audio_path)
        begin_end = [dict(begin=t['begin'], end=t['end']) for t in meta]
        begin = torch.tensor([t['begin'] for t in begin_end],
                             dtype=torch.float)
        end = torch.tensor([t['end'] for t in begin_end], dtype=torch.float)
        #TODO WARNING assumes frontend not in dataset
        if not args.frontend_in_model:
            print('\n' * 10 + 'WARNING\n' * 5)
            print(
                'transcribe.py assumes frontend in model, in other case time alignment was incorrect'
            )
            print('WARNING\n' * 5 + '\n')

        duration = x.shape[-1] / args.sample_rate
        channel = [t['channel'] for t in meta]
        speaker = [t['speaker'] for t in meta]
        speaker_name = [t['speaker_name'] for t in meta]

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, logits, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(map(transcripts.compute_duration, meta)),
                        processing=time.time() - tic))

            # Wall-clock timestamp of every output frame, broadcast per example.
            ts: shaping.Bt = duration * torch.linspace(
                0, 1, steps=log_probs.shape[-1],
                device=log_probs.device).unsqueeze(0).expand(x.shape[0], -1)

            ref_segments = [[
                dict(channel=channel[i],
                     begin=begin_end[i]['begin'],
                     end=begin_end[i]['end'],
                     ref=text_pipeline.postprocess(
                         text_pipeline.preprocess(meta[i]['ref'])))
            ] for i in range(len(meta))]
            # Keep only the top decoding alternative per example.
            hyp_segments = [
                alternatives[0] for alternatives in generator.generate(
                    tokenizer=text_pipeline.tokenizer,
                    log_probs=log_probs,
                    begin=begin,
                    end=end,
                    output_lengths=olen,
                    time_stamps=ts,
                    segment_text_key='hyp',
                    segment_extra_info=[
                        dict(speaker=s, speaker_name=sn, channel=c)
                        for s, sn, c in zip(speaker, speaker_name, channel)
                    ])
            ]
            hyp_segments = [
                transcripts.map_text(text_pipeline.postprocess, hyp=hyp)
                for hyp in hyp_segments
            ]
            hyp, ref = '\n'.join(
                transcripts.join(hyp=h)
                for h in hyp_segments).strip(), '\n'.join(
                    transcripts.join(ref=r) for r in ref_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                # Force-align reference labels against model output to obtain
                # per-character time stamps.
                alignment: shaping.BY = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y[:, 0, :],  # assumed that 0 channel is char labels
                    olen,
                    ylen[:, 0],
                    blank=text_pipeline.tokenizer.eps_id,
                    pack_backpointers=args.pack_backpointers)
                aligned_ts: shaping.Bt = ts.gather(1, alignment)

                # Re-generate the reference segments through the same decoder
                # path, feeding one-hot "log-probs" built from the labels so
                # the aligned time stamps get attached.
                ref_segments = [
                    alternatives[0] for alternatives in generator.generate(
                        tokenizer=text_pipeline.tokenizer,
                        log_probs=torch.nn.functional.one_hot(
                            y[:,
                              0, :], num_classes=log_probs.shape[1]).permute(
                                  0, 2, 1),
                        begin=begin,
                        end=end,
                        output_lengths=ylen,
                        time_stamps=aligned_ts,
                        segment_text_key='ref',
                        segment_extra_info=[
                            dict(speaker=s, speaker_name=sn, channel=c)
                            for s, sn, c in zip(speaker, speaker_name, channel)
                        ])
                ]
                # NOTE(review): postprocessing the *ref* segments via hyp=
                # looks suspicious (segments were generated with
                # segment_text_key='ref'); kept as-is -- confirm against
                # transcripts.map_text.
                ref_segments = [
                    transcripts.map_text(text_pipeline.postprocess, hyp=ref)
                    for ref in ref_segments
                ]
            oom_handler.reset()
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer routed through the OOM-recovery path.
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {len(val_dataset)}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() -
                                                   tic_alignment))

        ref_transcript, hyp_transcript = [
            sorted(transcripts.flatten(segments), key=transcripts.sort_key)
            for segments in [ref_segments, hyp_segments]
        ]

        if args.max_segment_duration:
            # Re-segment into fixed-duration chunks; hyp follows ref when a
            # reference exists, otherwise hyp is chunked and ref left empty.
            if ref:
                ref_segments = list(
                    transcripts.segment_by_time(ref_transcript,
                                                args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment_by_ref(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment_by_time(hyp_transcript,
                                                args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        #### HACK for diarization
        elif args.ref_transcript_path and args.join_transcript:
            audio_name_hack = audio_name.split('.')[0]
            #TODO: normalize ref field
            ref_segments = [[t] for t in sorted(transcripts.load(
                os.path.join(args.ref_transcript_path, audio_name_hack +
                             '.json')),
                                                key=transcripts.sort_key)]
            hyp_segments = list(
                transcripts.segment_by_ref(hyp_transcript,
                                           ref_segments,
                                           set_speaker=True,
                                           soft=False))
        #### END OF HACK

        has_ref = bool(transcripts.join(ref=transcripts.flatten(ref_segments)))

        transcript = []
        # NOTE: hyp_transcript/ref_transcript are deliberately rebound here as
        # per-segment loop variables, shadowing the flat lists above.
        for hyp_transcript, ref_transcript in zip(hyp_segments, ref_segments):
            hyp, ref = transcripts.join(hyp=hyp_transcript), transcripts.join(
                ref=ref_transcript)

            transcript.append(
                dict(audio_path=audio_path,
                     ref=ref,
                     hyp=hyp,
                     speaker_name=transcripts.speaker_name(ref=ref_transcript,
                                                           hyp=hyp_transcript),
                     words=metrics.align_words(
                         *metrics.align_strings(hyp=hyp, ref=ref))
                     if args.align_words else [],
                     words_ref=ref_transcript,
                     words_hyp=hyp_transcript,
                     **transcripts.summary(hyp_transcript),
                     **(dict(cer=metrics.cer(hyp=hyp, ref=ref))
                        if has_ref else {})))

        transcripts.collect_speaker_names(transcript,
                                          set_speaker_data=True,
                                          num_speakers=2)

        # Drop segments that fail the configured quality/length filters.
        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.prune_cer,
                              duration=args.prune_duration,
                              gap=args.prune_gap,
                              allowed_unk_count=args.prune_unk,
                              num_speakers=args.prune_num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcripts.save(transcript_path, filtered_transcript))

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(
                vis.transcript(transcript_path, args.sample_rate, args.mono,
                               transcript, filtered_transcript))

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            with open(transcript_path, 'w') as f:
                f.write(' '.join(t['hyp'].strip()
                                 for t in filtered_transcript))
            print(transcript_path)

        if args.output_csv:
            # One CSV row per audio file: full hypothesis plus overall span.
            assert len({t['audio_path'] for t in filtered_transcript}) == 1
            audio_path = filtered_transcript[0]['audio_path']
            hyp = ' '.join(t['hyp'].strip() for t in filtered_transcript)
            begin = min(t['begin'] for t in filtered_transcript)
            end = max(t['end'] for t in filtered_transcript)
            csv_lines.append(
                csv_sep.join([audio_path, hyp,
                              str(begin),
                              str(end)]))

        if args.logits:
            logits_file_path = os.path.join(args.output_path,
                                            audio_name + '.pt')
            if args.logits_crop:
                # Rescale begin/end to the cropped frame window and crop the
                # logits to the same [start, stop) frame slice.
                begin_end = [
                    dict(
                        zip(['begin', 'end'], [
                            t['begin'] + c / float(o) * (t['end'] - t['begin'])
                            for c in args.logits_crop
                        ])) for o, t in zip(olen, begin_end)
                ]
                logits_crop = [slice(*args.logits_crop) for o in olen]
            else:
                logits_crop = [slice(int(o)) for o in olen]

            # TODO: filter ref / hyp by channel?
            torch.save([
                dict(audio_path=audio_path,
                     logits=l[..., logits_crop[i]],
                     **begin_end[i],
                     ref=ref,
                     hyp=hyp) for i, l in enumerate(logits.cpu())
            ], logits_file_path)
            print(logits_file_path)

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        transcript_path = os.path.join(args.output_path, 'transcripts.csv')
        with open(transcript_path, 'w') as f:
            f.write('\n'.join(csv_lines))
        print(transcript_path)
# Example #2
def main(args):
    checkpoints = [
        torch.load(checkpoint_path, map_location='cpu')
        for checkpoint_path in args.checkpoint
    ]
    checkpoint = (checkpoints + [{}])[0]
    if len(checkpoints) > 1:
        checkpoint['model_state_dict'] = {
            k: sum(c['model_state_dict'][k]
                   for c in checkpoints) / len(checkpoints)
            for k in checkpoint['model_state_dict']
        }

    if args.frontend_checkpoint:
        frontend_checkpoint = torch.load(args.frontend_checkpoint,
                                         map_location='cpu')
        frontend_extra_args = frontend_checkpoint['args']
        frontend_checkpoint = frontend_checkpoint['model']
    else:
        frontend_extra_args = None
        frontend_checkpoint = None

    args.experiment_id = args.experiment_id.format(
        model=args.model,
        frontend=args.frontend,
        train_batch_size=args.train_batch_size,
        optimizer=args.optimizer,
        lr=args.lr,
        weight_decay=args.weight_decay,
        time=time.strftime('%Y-%m-%d_%H-%M-%S'),
        experiment_name=args.experiment_name,
        bpe='bpe' if args.bpe else '',
        train_waveform_transform=
        f'aug{args.train_waveform_transform[0]}{args.train_waveform_transform_prob or ""}'
        if args.train_waveform_transform else '',
        train_feature_transform=
        f'aug{args.train_feature_transform[0]}{args.train_feature_transform_prob or ""}'
        if args.train_feature_transform else '').replace('e-0',
                                                         'e-').rstrip('_')
    if checkpoint and 'experiment_id' in checkpoint[
            'args'] and not args.experiment_name:
        args.experiment_id = checkpoint['args']['experiment_id']
    args.experiment_dir = args.experiment_dir.format(
        experiments_dir=args.experiments_dir, experiment_id=args.experiment_id)

    os.makedirs(args.experiment_dir, exist_ok=True)

    if args.log_json:
        args.log_json = os.path.join(args.experiment_dir, 'log.json')

    if checkpoint:
        args.lang, args.model, args.num_input_features, args.sample_rate, args.window, args.window_size, args.window_stride = map(
            checkpoint['args'].get, [
                'lang', 'model', 'num_input_features', 'sample_rate', 'window',
                'window_size', 'window_stride'
            ])
        utils.set_up_root_logger(os.path.join(args.experiment_dir, 'log.txt'),
                                 mode='a')
        logfile_sink = JsonlistSink(args.log_json, mode='a')
    else:
        utils.set_up_root_logger(os.path.join(args.experiment_dir, 'log.txt'),
                                 mode='w')
        logfile_sink = JsonlistSink(args.log_json, mode='w')

    _print = utils.get_root_logger_print()
    _print('\n', 'Arguments:', args)
    _print(
        f'"CUDA_VISIBLE_DEVICES={os.environ.get("CUDA_VISIBLE_DEVICES", default = "")}"'
    )
    _print(
        f'"CUDA_LAUNCH_BLOCKING={os.environ.get("CUDA_LAUNCH_BLOCKING", default="")}"'
    )
    _print('Experiment id:', args.experiment_id, '\n')
    if args.dry:
        return
    utils.set_random_seed(args.seed)
    if args.cudnn == 'benchmark':
        torch.backends.cudnn.benchmark = True

    lang = datasets.Language(args.lang)
    #TODO: , candidate_sep = datasets.Labels.candidate_sep
    normalize_text_config = json.load(open(
        args.normalize_text_config)) if os.path.exists(
            args.normalize_text_config) else {}
    labels = [
        datasets.Labels(
            lang, name='char', normalize_text_config=normalize_text_config)
    ] + [
        datasets.Labels(lang,
                        bpe=bpe,
                        name=f'bpe{i}',
                        normalize_text_config=normalize_text_config)
        for i, bpe in enumerate(args.bpe)
    ]
    frontend = getattr(models,
                       args.frontend)(out_channels=args.num_input_features,
                                      sample_rate=args.sample_rate,
                                      window_size=args.window_size,
                                      window_stride=args.window_stride,
                                      window=args.window,
                                      dither=args.dither,
                                      dither0=args.dither0,
                                      stft_mode='conv' if args.onnx else None,
                                      extra_args=frontend_extra_args)
    model = getattr(models, args.model)(
        num_input_features=args.num_input_features,
        num_classes=list(map(len, labels)),
        dropout=args.dropout,
        decoder_type='bpe' if args.bpe else None,
        frontend=frontend if args.onnx or args.frontend_in_model else None,
        **(dict(inplace=False,
                dict=lambda logits, log_probs, olen, **kwargs: logits[0])
           if args.onnx else {}))

    _print('Model capacity:', int(models.compute_capacity(model, scale=1e6)),
           'million parameters\n')

    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if frontend_checkpoint:
        frontend_checkpoint = {
            'model.' + name: weight
            for name, weight in frontend_checkpoint.items()
        }  ##TODO remove after save checkpoint naming fix
        frontend.load_state_dict(frontend_checkpoint)

    if args.onnx:
        torch.set_grad_enabled(False)
        model.eval()
        model.to(args.device)
        model.fuse_conv_bn_eval()

        if args.fp16:
            model = models.InputOutputTypeCast(model.to(torch.float16),
                                               dtype=torch.float16)

        waveform_input = torch.rand(args.onnx_sample_batch_size,
                                    args.onnx_sample_time,
                                    device=args.device)
        logits = model(waveform_input)

        torch.onnx.export(model, (waveform_input, ),
                          args.onnx,
                          opset_version=args.onnx_opset,
                          export_params=args.onnx_export_params,
                          do_constant_folding=True,
                          input_names=['x'],
                          output_names=['logits'],
                          dynamic_axes=dict(x={
                              0: 'B',
                              1: 'T'
                          },
                                            logits={
                                                0: 'B',
                                                2: 't'
                                            }))
        onnxruntime_session = onnxruntime.InferenceSession(args.onnx)
        if args.verbose:
            onnxruntime.set_default_logger_severity(0)
        (logits_, ) = onnxruntime_session.run(
            None, dict(x=waveform_input.cpu().numpy()))
        assert torch.allclose(logits.cpu(),
                              torch.from_numpy(logits_),
                              rtol=1e-02,
                              atol=1e-03)

        #model_def = onnx.load(args.onnx)
        #import onnx.tools.net_drawer # import GetPydotGraph, GetOpNodeProducer
        #pydot_graph = GetPydotGraph(model_def.graph, name=model_def.graph.name, rankdir="TB", node_producer=GetOpNodeProducer("docstring", color="yellow", fillcolor="yellow", style="filled"))
        #pydot_graph.write_dot("pipeline_transpose2x.dot")
        #os.system('dot -O -Gdpi=300 -Tpng pipeline_transpose2x.dot')
        # add metadata to model
        return

    perf.init_default(loss=dict(K=50, max=1000),
                      memory_cuda_allocated=dict(K=50),
                      entropy=dict(K=4),
                      time_ms_iteration=dict(K=50, max=10_000),
                      lr=dict(K=50, max=1))

    val_config = json.load(open(args.val_config)) if os.path.exists(
        args.val_config) else {}
    word_tags = json.load(open(args.word_tags)) if os.path.exists(
        args.word_tags) else {}
    for word_tag, words in val_config.get('word_tags', {}).items():
        word_tags[word_tag] = word_tags.get(word_tag, []) + words
    vocab = set(map(str.strip, open(args.vocab))) if os.path.exists(
        args.vocab) else set()
    error_analyzer = metrics.ErrorAnalyzer(
        metrics.WordTagger(lang, vocab=vocab, word_tags=word_tags),
        metrics.ErrorTagger(), val_config.get('error_analyzer', {}))

    make_transform = lambda name_args, prob: None if not name_args else getattr(
        transforms, name_args[0])(*name_args[1:]) if prob is None else getattr(
            transforms, name_args[0])(prob, *name_args[1:]
                                      ) if prob > 0 else None
    val_frontend = models.AugmentationFrontend(
        frontend,
        waveform_transform=make_transform(args.val_waveform_transform,
                                          args.val_waveform_transform_prob),
        feature_transform=make_transform(args.val_feature_transform,
                                         args.val_feature_transform_prob))

    if args.val_waveform_transform_debug_dir:
        args.val_waveform_transform_debug_dir = os.path.join(
            args.val_waveform_transform_debug_dir,
            str(val_frontend.waveform_transform) if isinstance(
                val_frontend.waveform_transform, transforms.RandomCompose) else
            val_frontend.waveform_transform.__class__.__name__)
        os.makedirs(args.val_waveform_transform_debug_dir, exist_ok=True)

    val_data_loaders = {
        os.path.basename(val_data_path): torch.utils.data.DataLoader(
            val_dataset,
            num_workers=args.num_workers,
            collate_fn=val_dataset.collate_fn,
            pin_memory=True,
            shuffle=False,
            batch_size=args.val_batch_size,
            worker_init_fn=datasets.worker_init_fn,
            timeout=args.timeout if args.num_workers > 0 else 0)
        for val_data_path in args.val_data_path for val_dataset in [
            datasets.AudioTextDataset(
                val_data_path,
                labels,
                args.sample_rate,
                frontend=val_frontend if not args.frontend_in_model else None,
                waveform_transform_debug_dir=args.
                val_waveform_transform_debug_dir,
                min_duration=args.min_duration,
                time_padding_multiple=args.batch_time_padding_multiple,
                pop_meta=True,
                _print=_print)
        ]
    }
    decoder = [
        decoders.GreedyDecoder() if args.decoder == 'GreedyDecoder' else
        decoders.BeamSearchDecoder(labels[0],
                                   lm_path=args.lm,
                                   beam_width=args.beam_width,
                                   beam_alpha=args.beam_alpha,
                                   beam_beta=args.beam_beta,
                                   num_workers=args.num_workers,
                                   topk=args.decoder_topk)
    ] + [decoders.GreedyDecoder() for bpe in args.bpe]

    model.to(args.device)

    if not args.train_data_path:
        model.eval()
        if not args.adapt_bn:
            model.fuse_conv_bn_eval()
        if args.device != 'cpu':
            model, *_ = models.data_parallel_and_autocast(
                model,
                opt_level=args.fp16,
                keep_batchnorm_fp32=args.fp16_keep_batchnorm_fp32)
        evaluate_model(args, val_data_loaders, model, labels, decoder,
                       error_analyzer)
        return

    model.freeze(backbone=args.freeze_backbone,
                 decoder0=args.freeze_decoder,
                 frontend=args.freeze_frontend)

    train_frontend = models.AugmentationFrontend(
        frontend,
        waveform_transform=make_transform(args.train_waveform_transform,
                                          args.train_waveform_transform_prob),
        feature_transform=make_transform(args.train_feature_transform,
                                         args.train_feature_transform_prob))
    tic = time.time()
    train_dataset = datasets.AudioTextDataset(
        args.train_data_path,
        labels,
        args.sample_rate,
        frontend=train_frontend if not args.frontend_in_model else None,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        time_padding_multiple=args.batch_time_padding_multiple,
        bucket=lambda example: int(
            math.ceil(((example[0]['end'] - example[0]['begin']) / args.
                       window_stride + 1) / args.batch_time_padding_multiple)),
        pop_meta=True,
        _print=_print)

    _print('Time train dataset created:', time.time() - tic, 'sec')
    train_dataset_name = '_'.join(map(os.path.basename, args.train_data_path))
    tic = time.time()
    sampler = datasets.BucketingBatchSampler(
        train_dataset,
        batch_size=args.train_batch_size,
    )
    _print('Time train sampler created:', time.time() - tic, 'sec')

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=args.num_workers,
        collate_fn=train_dataset.collate_fn,
        pin_memory=True,
        batch_sampler=sampler,
        worker_init_fn=datasets.worker_init_fn,
        timeout=args.timeout if args.num_workers > 0 else 0)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov
    ) if args.optimizer == 'SGD' else torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'AdamW' else optimizers.NovoGrad(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'NovoGrad' else apex.optimizers.FusedNovoGrad(
        model.parameters(),
        lr=args.lr,
        betas=args.betas,
        weight_decay=args.weight_decay
    ) if args.optimizer == 'FusedNovoGrad' else None

    if checkpoint and checkpoint['optimizer_state_dict'] is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if not args.skip_optimizer_reset:
            optimizers.reset_options(optimizer)

    scheduler = optimizers.MultiStepLR(
        optimizer, gamma=args.decay_gamma, milestones=args.decay_milestones
    ) if args.scheduler == 'MultiStepLR' else optimizers.PolynomialDecayLR(
        optimizer,
        power=args.decay_power,
        decay_steps=len(train_data_loader) * args.decay_epochs,
        end_lr=args.decay_lr
    ) if args.scheduler == 'PolynomialDecayLR' else optimizers.NoopLR(
        optimizer)
    epoch, iteration = 0, 0
    if checkpoint:
        epoch, iteration = checkpoint['epoch'], checkpoint['iteration']
        if args.train_data_path == checkpoint['args']['train_data_path']:
            sampler.load_state_dict(checkpoint['sampler_state_dict'])
            if args.iterations_per_epoch and iteration and iteration % args.iterations_per_epoch == 0:
                sampler.batch_idx = 0
                epoch += 1
        else:
            epoch += 1
    if args.iterations_per_epoch:
        epoch_skip_fraction = 1 - args.iterations_per_epoch / len(
            train_data_loader)
        assert epoch_skip_fraction < args.max_epoch_skip_fraction, \
         f'args.iterations_per_epoch must not skip more than {args.max_epoch_skip_fraction:.1%} of each epoch'

    if args.device != 'cpu':
        model, optimizer = models.data_parallel_and_autocast(
            model,
            optimizer,
            opt_level=args.fp16,
            keep_batchnorm_fp32=args.fp16_keep_batchnorm_fp32)
    if checkpoint and args.fp16 and checkpoint['amp_state_dict'] is not None:
        apex.amp.load_state_dict(checkpoint['amp_state_dict'])

    model.train()

    tensorboard_dir = os.path.join(args.experiment_dir, 'tensorboard')
    if checkpoint and args.experiment_name:
        tensorboard_dir_checkpoint = os.path.join(
            os.path.dirname(args.checkpoint[0]), 'tensorboard')
        if os.path.exists(tensorboard_dir_checkpoint
                          ) and not os.path.exists(tensorboard_dir):
            shutil.copytree(tensorboard_dir_checkpoint, tensorboard_dir)
    tensorboard = torch.utils.tensorboard.SummaryWriter(tensorboard_dir)
    tensorboard_sink = TensorboardSink(tensorboard)

    with open(os.path.join(args.experiment_dir, args.args), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, ensure_ascii=False, indent=2)

    with open(os.path.join(args.experiment_dir, args.dump_model_config),
              'w') as f:
        model_config = dict(
            init_params=models.master_module(model).init_params,
            model=repr(models.master_module(model)))
        json.dump(model_config,
                  f,
                  sort_keys=True,
                  ensure_ascii=False,
                  indent=2)

    tic, toc_fwd, toc_bwd = time.time(), time.time(), time.time()

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for epoch in range(epoch, args.epochs):
        sampler.shuffle(epoch + args.seed_sampler)
        time_epoch_start = time.time()
        for batch_idx, (meta, s, x, xlen, y,
                        ylen) in enumerate(train_data_loader,
                                           start=sampler.batch_idx):
            toc_data = time.time()
            if batch_idx == 0:
                time_ms_launch_data_loader = (toc_data - tic) * 1000
                _print('Time data loader launch @ ', epoch, ':',
                       time_ms_launch_data_loader / 1000, 'sec')

            lr = optimizer.param_groups[0]['lr']
            perf.update(dict(lr=lr))

            x, xlen, y, ylen = x.to(args.device, non_blocking=True), xlen.to(
                args.device, non_blocking=True), y.to(
                    args.device, non_blocking=True), ylen.to(args.device,
                                                             non_blocking=True)
            try:
                #TODO check nan values in tensors, they can break running_stats in bn
                log_probs, olen, loss = map(
                    model(x, xlen, y=y, ylen=ylen).get,
                    ['log_probs', 'olen', 'loss'])
                oom_handler.reset()
            except:
                if oom_handler.try_recover(model.parameters(), _print=_print):
                    continue
                else:
                    raise
            example_weights = ylen[:, 0]
            loss, loss_cur = (loss * example_weights).mean(
            ) / args.train_batch_accumulate_iterations, float(loss.mean())

            perf.update(dict(loss_BT_normalized=loss_cur))

            entropy = float(
                models.entropy(log_probs[0], olen[0], dim=1).mean())
            toc_fwd = time.time()
            #TODO: inf/nan still corrupts BN stats
            if not (torch.isinf(loss) or torch.isnan(loss)):
                if args.fp16:
                    with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if iteration % args.train_batch_accumulate_iterations == 0:
                    torch.nn.utils.clip_grad_norm_(
                        apex.amp.master_params(optimizer)
                        if args.fp16 else model.parameters(), args.max_norm)
                    optimizer.step()

                    if iteration > 0 and iteration % args.log_iteration_interval == 0:
                        perf.update(utils.compute_memory_stats(),
                                    prefix='performance')
                        tensorboard_sink.perf(perf.default(), iteration,
                                              train_dataset_name)
                        tensorboard_sink.weight_stats(
                            iteration, model, args.log_weight_distribution)
                        logfile_sink.perf(perf.default(), iteration,
                                          train_dataset_name)

                    optimizer.zero_grad()
                    scheduler.step(iteration)
                perf.update(dict(entropy=entropy))
            toc_bwd = time.time()

            time_ms_data, time_ms_fwd, time_ms_bwd, time_ms_model = map(
                lambda sec: sec * 1000, [
                    toc_data - tic, toc_fwd - toc_data, toc_bwd - toc_fwd,
                    toc_bwd - toc_data
                ])
            perf.update(dict(time_ms_data=time_ms_data,
                             time_ms_fwd=time_ms_fwd,
                             time_ms_bwd=time_ms_bwd,
                             time_ms_iteration=time_ms_data + time_ms_model),
                        prefix='performance')
            perf.update(dict(input_B=x.shape[0], input_T=x.shape[-1]),
                        prefix='performance')
            print_left = f'{args.experiment_id} | epoch: {epoch:02d} iter: [{batch_idx: >6d} / {len(train_data_loader)} {iteration: >6d}] {"x".join(map(str, x.shape))}'
            print_right = 'ent: <{avg_entropy:.2f}> loss: {cur_loss_BT_normalized:.2f} <{avg_loss_BT_normalized:.2f}> time: {performance_cur_time_ms_data:.2f}+{performance_cur_time_ms_fwd:4.0f}+{performance_cur_time_ms_bwd:4.0f} <{performance_avg_time_ms_iteration:.0f}> | lr: {cur_lr:.5f}'.format(
                **perf.default())
            _print(print_left, print_right)
            iteration += 1
            sampler.batch_idx += 1

            if iteration > 0 and (iteration % args.val_iteration_interval == 0
                                  or iteration == args.iterations):
                evaluate_model(args, val_data_loaders, model, labels, decoder,
                               error_analyzer, optimizer, sampler,
                               tensorboard_sink, logfile_sink, epoch,
                               iteration)

            if iteration and args.iterations and iteration >= args.iterations:
                return

            if args.iterations_per_epoch and iteration > 0 and iteration % args.iterations_per_epoch == 0:
                break

            tic = time.time()

        sampler.batch_idx = 0
        _print('Epoch time', (time.time() - time_epoch_start) / 60, 'minutes')
        if not args.skip_on_epoch_end_evaluation:
            evaluate_model(args, val_data_loaders, model, labels, decoder,
                           error_analyzer, optimizer, sampler,
                           tensorboard_sink, logfile_sink, epoch + 1,
                           iteration)
Beispiel #3
0
def main(args):
    """Transcribe a set of audio / JSON inputs and write the results to disk.

    Inputs are audio files (filtered by ``args.ext``) gathered from the paths
    in ``args.input_path`` (directories are expanded one level deep), plus any
    ``.json`` / ``.json.gz`` datasets passed directly.  For each batch the
    acoustic model is applied, the output decoded, optionally CTC-aligned
    against the reference and re-segmented, pruned, and dumped in any
    combination of JSON / HTML / TXT / CSV — at least one output format must
    be requested.
    """
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
     'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)
    # Expand input directories one level and keep files matching the
    # configured audio extensions; JSON datasets are taken as-is.
    data_paths = [
        p for f in args.input_path
        for p in ([os.path.join(f, g)
                   for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext))
    ] + [
        p
        for p in args.input_path if any(map(p.endswith, ['.json', '.json.gz']))
    ]
    # With --skip-processed, inputs whose .json transcript already exists in
    # the output directory are excluded from this run.
    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ] if args.skip_processed else [])
    data_paths = [
        path for path in data_paths if os.path.basename(path) not in exclude
    ]

    labels, frontend, model, decoder = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [labels],
        args.sample_rate,
        frontend=None,
        segmented=True,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        join_transcript=args.join_transcript,
        string_array_encoding=args.dataset_string_array_encoding)
    num_examples = len(val_dataset)
    print('Examples count: ', num_examples)
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    output_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{num_examples}')

        # Resolve per-example metadata (audio path, begin/end, channel, ...)
        # from the dataset's meta store by example id.
        meta = [val_dataset.meta.get(m['example_id']) for m in meta]
        audio_path = meta[0]['audio_path']

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        begin = meta[0]['begin']
        end = meta[0]['end']
        audio_name = transcripts.audio_name(audio_path)

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            #speech = vad.detect_speech(x.squeeze(1), args.sample_rate, args.window_size, aggressiveness = args.vad, window_size_dilate = args.window_size_dilate)
            #speech = vad.upsample(speech, log_probs)
            #log_probs.masked_fill_(models.silence_space_mask(log_probs, speech, space_idx = labels.space_idx, blank_idx = labels.blank_idx), float('-inf'))

            decoded = decoder.decode(log_probs, olen)

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(
                    transcripts.compute_duration(t) for t in meta),
                        processing=time.time() - tic))

            # Absolute timestamp for every output frame: linear ramp over the
            # input duration, shifted by each example's begin time.
            ts = (x.shape[-1] / args.sample_rate) * torch.linspace(
                0, 1,
                steps=log_probs.shape[-1]).unsqueeze(0) + torch.FloatTensor(
                    [t['begin'] for t in meta]).unsqueeze(1)
            channel = [t['channel'] for t in meta]
            speaker = [t['speaker'] for t in meta]
            ref_segments = [[
                dict(channel=channel[i],
                     begin=meta[i]['begin'],
                     end=meta[i]['end'],
                     ref=labels.decode(y[i, 0, :ylen[i]].tolist()))
            ] for i in range(len(decoded))]
            hyp_segments = [
                labels.decode(decoded[i],
                              ts[i],
                              channel=channel[i],
                              replace_blank=True,
                              replace_blank_series=args.replace_blank_series,
                              replace_repeat=True,
                              replace_space=False,
                              speaker=speaker[i] if isinstance(
                                  speaker[i], str) else None)
                for i in range(len(decoded))
            ]

            ref, hyp = '\n'.join(
                transcripts.join(ref=r)
                for r in ref_segments).strip(), '\n'.join(
                    transcripts.join(hyp=h) for h in hyp_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                #if ref_full:# and not ref:
                #	#assert len(set(t['channel'] for t in meta)) == 1 or all(t['type'] != 'channel' for t in meta)
                #	#TODO: add space at the end
                #	channel = torch.ByteTensor(channel).repeat_interleave(log_probs.shape[-1]).reshape(1, -1)
                #	ts = ts.reshape(1, -1)
                #	log_probs = log_probs.transpose(0, 1).unsqueeze(0).flatten(start_dim = -2)
                #	olen = torch.tensor([log_probs.shape[-1]], device = log_probs.device, dtype = torch.long)
                #	y = y_full[None, None, :].to(y.device)
                #	ylen = torch.tensor([[y.shape[-1]]], device = log_probs.device, dtype = torch.long)
                #	segments = [([], sum([h for r, h in segments], []))]

                # Forced alignment of the reference against the log-probs to
                # recover per-token timestamps for the reference segments.
                alignment = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y.squeeze(1),
                    olen,
                    ylen.squeeze(1),
                    blank=labels.blank_idx,
                    pack_backpointers=args.pack_backpointers)
                ref_segments = [
                    labels.decode(y[i, 0, :ylen[i]].tolist(),
                                  ts[i],
                                  alignment[i],
                                  channel=channel[i],
                                  speaker=speaker[i],
                                  key='ref',
                                  speakers=val_dataset.speakers)
                    for i in range(len(decoded))
                ]
            oom_handler.reset()
        except Exception:
            # Was a bare `except:` which would also swallow KeyboardInterrupt
            # and SystemExit; only genuine runtime errors (e.g. CUDA OOM)
            # should be considered for recovery/retry.
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {num_examples}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() -
                                                   tic_alignment))

        if args.max_segment_duration:
            ref_transcript, hyp_transcript = [
                list(sorted(sum(segments, []), key=transcripts.sort_key))
                for segments in [ref_segments, hyp_segments]
            ]
            if ref:
                # A reference exists: segment it by duration and cut the
                # hypothesis along the same boundaries.
                ref_segments = list(
                    transcripts.segment(ref_transcript,
                                        args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment(hyp_transcript, ref_segments))
            else:
                # No reference: segment the hypothesis by duration alone.
                hyp_segments = list(
                    transcripts.segment(hyp_transcript,
                                        args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        # One output record per (ref, hyp) segment pair; the inner
        # single-element `for ref, hyp in [...]` binds the joined texts.
        transcript = [
            dict(audio_path=audio_path,
                 ref=ref,
                 hyp=hyp,
                 speaker=transcripts.speaker(ref=ref_transcript,
                                             hyp=hyp_transcript),
                 cer=metrics.cer(hyp=hyp, ref=ref),
                 words=metrics.align_words(hyp=hyp, ref=ref)[-1]
                 if args.align_words else [],
                 alignment=dict(ref=ref_transcript, hyp=hyp_transcript),
                 **transcripts.summary(hyp_transcript)) for ref_transcript,
            hyp_transcript in zip(ref_segments, hyp_segments)
            for ref, hyp in [(transcripts.join(ref=ref_transcript),
                              transcripts.join(hyp=hyp_transcript))]
        ]
        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.cer,
                              duration=args.duration,
                              gap=args.gap,
                              unk=args.unk,
                              num_speakers=args.num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcript_path)
            with open(transcript_path, 'w') as f:
                json.dump(filtered_transcript,
                          f,
                          ensure_ascii=False,
                          sort_keys=True,
                          indent=2)

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(transcript_path)
            vis.transcript(transcript_path, args.sample_rate, args.mono,
                           transcript, filtered_transcript)

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            print(transcript_path)
            with open(transcript_path, 'w') as f:
                f.write(hyp)

        if args.output_csv:
            output_lines.append(
                csv_sep.join((audio_path, hyp, str(begin), str(end))) + '\n')

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        with open(os.path.join(args.output_path, 'transcripts.csv'), 'w') as f:
            f.writelines(output_lines)
Beispiel #4
0
def evaluate_model(args,
                   val_data_loaders,
                   model,
                   labels,
                   decoder,
                   error_analyzer,
                   optimizer=None,
                   sampler=None,
                   tensorboard_sink=None,
                   logfile_sink=None,
                   epoch=None,
                   iteration=None):
    """Evaluate ``model`` on every validation loader and report metrics.

    For each dataset: runs ``apply_model``, analyzes the decoded transcripts
    with ``error_analyzer`` (optionally in a multiprocessing pool), dumps
    the per-example transcripts to disk, and logs aggregated WER/CER/loss.
    When called from training (both ``epoch`` and ``iteration`` given) it
    additionally updates the tensorboard sink and saves a checkpoint unless
    ``args.checkpoint_skip`` is set.  The model is left in train mode.
    """
    _print = utils.get_root_logger_print()

    # Training-time evaluation is distinguished by having both epoch and
    # iteration; it enables tensorboard logging and checkpointing.
    training = epoch is not None and iteration is not None
    columns = {}
    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for val_dataset_name, val_data_loader in val_data_loaders.items():
        _print(f'\n{val_dataset_name}@{iteration}')
        transcript, logits_, y_ = [], [], []
        # args.analyze == [] means "analyze every dataset"; otherwise only
        # the datasets explicitly listed.
        analyze = args.analyze == [] or (args.analyze is not None
                                         and val_dataset_name in args.analyze)

        model.eval()
        if args.adapt_bn:
            # Re-estimate batch-norm running statistics on this dataset
            # with a throwaway forward pass before the real evaluation.
            models.reset_bn_running_stats_(model)
            for _ in apply_model(val_data_loader, model, labels, decoder,
                                 args.device, oom_handler):
                pass
        model.eval()
        cpu_list = lambda l: [[t.cpu() for t in t_] for t_ in l]

        tic = time.time()
        ref_, hyp_, audio_path_, loss_, entropy_ = [], [], [], [], []
        for batch_idx, (meta, loss, entropy, hyp, logits, y) in enumerate(
                apply_model(val_data_loader, model, labels, decoder,
                            args.device, oom_handler)):
            loss_.extend(loss.tolist())
            entropy_.extend(entropy.tolist())
            # Logits are only retained for standalone evaluation runs that
            # requested them (they are large).
            logits_.extend(
                zip(*cpu_list(logits)) if not training and args.logits else [])
            y_.extend(cpu_list(y))
            audio_path_.extend(m['audio_path'] for m in meta)
            ref_.extend(labels[0].normalize_text(m['ref']) for m in meta)
            hyp_.extend(zip(*hyp))
        toc_apply_model = time.time()
        time_sec_val_apply_model = toc_apply_model - tic
        perf.update(dict(time_sec_val_apply_model=time_sec_val_apply_model),
                    prefix=f'datasets_{val_dataset_name}')
        _print(f"Apply model {time_sec_val_apply_model:.1f} sec")

        # One analysis job per (example, label pipeline) pair.
        analyze_args_gen = ((
            hyp,
            ref,
            analyze,
            dict(labels_name=label.name,
                 audio_path=audio_path,
                 audio_name=transcripts.audio_name(audio_path),
                 loss=loss,
                 entropy=entropy),
            label.postprocess_transcript,
        ) for ref, hyp_tuple, audio_path, loss, entropy in zip(
            ref_, hyp_, audio_path_, loss_, entropy_)
                            for label, hyp in zip(labels, hyp_tuple))

        if args.analyze_num_workers <= 0:
            # NOTE: the loop variable must not be named `args` — that would
            # shadow the function's `args` namespace inside the comprehension.
            transcript = [
                error_analyzer.analyze(*analyze_args)
                for analyze_args in analyze_args_gen
            ]
        else:
            with multiprocessing.pool.Pool(
                    processes=args.analyze_num_workers) as pool:
                transcript = pool.starmap(error_analyzer.analyze,
                                          analyze_args_gen)

        toc_analyze = time.time()
        time_sec_val_analyze = toc_analyze - toc_apply_model
        time_sec_val_total = toc_analyze - tic
        perf.update(dict(time_sec_val_analyze=time_sec_val_analyze,
                         time_sec_val_total=time_sec_val_total),
                    prefix=f'datasets_{val_dataset_name}')
        _print(
            f"Analyze {time_sec_val_analyze:.1f} sec, Total {time_sec_val_total:.1f} sec"
        )

        # Per-example dump (verbose mode only; transcript is interleaved by
        # label pipeline, hence `i // len(labels)` for the example index).
        for i, t in enumerate(transcript if args.verbose else []):
            _print(
                f'{val_dataset_name}@{iteration}: {i // len(labels)} / {len(audio_path_)} | {args.experiment_id}'
            )
            # TODO: don't forget to fix aligned hyp & ref output!
            # hyp = new_transcript['alignment']['hyp'] if analyze else new_transcript['hyp']
            # ref = new_transcript['alignment']['ref'] if analyze else new_transcript['ref']
            _print('REF: {labels_name} "{ref}"'.format(**t))
            _print('HYP: {labels_name} "{hyp}"'.format(**t))
            _print('WER: {labels_name} {wer:.02%} | CER: {cer:.02%}\n'.format(
                **t))

        transcripts_path = os.path.join(
            args.experiment_dir,
            args.train_transcripts_format.format(
                val_dataset_name=val_dataset_name,
                epoch=epoch,
                iteration=iteration)
        ) if training else args.val_transcripts_format.format(
            val_dataset_name=val_dataset_name, decoder=args.decoder)
        for i, label in enumerate(labels):
            # Transcript entries are interleaved across label pipelines;
            # take every len(labels)-th entry starting at this label.
            transcript_by_label = transcript[i::len(labels)]
            aggregated = error_analyzer.aggregate(transcript_by_label)
            if analyze:
                with open(f'{transcripts_path}.errors.csv', 'w') as f:
                    f.writelines('{hyp},{ref},{error_tag}\n'.format(**w)
                                 for w in aggregated['errors']['words'])

            _print('errors', aggregated['errors']['distribution'])
            cer = torch.FloatTensor([r['cer'] for r in transcript_by_label])
            loss = torch.FloatTensor([r['loss'] for r in transcript_by_label])
            _print('cer', metrics.quantiles(cer))
            _print('loss', metrics.quantiles(loss))
            _print(
                f'{args.experiment_id} {val_dataset_name} {label.name}',
                f'| epoch {epoch} iter {iteration}' if training else '',
                f'| {transcripts_path} |',
                ('Entropy: {entropy:.02f} Loss: {loss:.02f} | WER:  {wer:.02%} CER: {cer:.02%} [{words_easy_errors_easy__cer_pseudo:.02%}],  MER: {mer_wordwise:.02%} DER: {hyp_der:.02%}/{ref_der:.02%}\n'
                 ).format(**aggregated))
            #columns[val_dataset_name + '_' + labels_name] = {'cer' : aggregated['cer_avg'], '.wer' : aggregated['wer_avg'], '.loss' : aggregated['loss_avg'], '.entropy' : aggregated['entropy_avg'], '.cer_easy' : aggregated['cer_easy_avg'], '.cer_hard':  aggregated['cer_hard_avg'], '.cer_missing' : aggregated['cer_missing_avg'], 'E' : dict(value = aggregated['errors_distribution']), 'L' : dict(value = vis.histc_vega(loss, min = 0, max = 3, bins = 20), type = 'vega'), 'C' : dict(value = vis.histc_vega(cer, min = 0, max = 1, bins = 20), type = 'vega'), 'T' : dict(value = [('audio_name', 'cer', 'mer', 'alignment')] + [(r['audio_name'], r['cer'], r['mer'], vis.word_alignment(r['words'])) for r in sorted(r_, key = lambda r: r['mer'], reverse = True)] if analyze else [], type = 'table')}

            if training:
                perf.update(
                    dict(wer=aggregated['wer'],
                         cer=aggregated['cer'],
                         loss=aggregated['loss']),
                    prefix=f'datasets_val_{val_dataset_name}_{label.name}')
                tensorboard_sink.val_stats(iteration, val_dataset_name,
                                           label.name, perf.default())

        with open(transcripts_path, 'w') as f:
            json.dump(transcript,
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)
        if analyze:
            vis.errors([transcripts_path], audio=args.vis_errors_audio)

        # TODO: transcript got flattened make this code work:
        # if args.logits:
        # 	logits_file_path = args.logits.format(val_dataset_name = val_dataset_name)
        # 	torch.save(
        # 		list(
        # 			sorted([
        # 				dict(
        # 					**r_,
        # 					logits = l_ if not args.logits_topk else models.sparse_topk(l_, args.logits_topk, dim = 0),
        # 					y = y_,
        # 					ydecoded = labels_.decode(y_.tolist())
        # 				) for r,
        # 				l,
        # 				y in zip(transcript, logits_, y_) for labels_,
        # 				r_,
        # 				l_,
        # 				y_ in zip(labels, r, l, y)
        # 			],
        # 					key = lambda r: r['cer'],
        # 					reverse = True)
        # 		),
        # 		logits_file_path
        # 	)
        # 	_print('Logits saved:', logits_file_path)

    checkpoint_path = os.path.join(
        args.experiment_dir,
        args.checkpoint_format.format(epoch=epoch, iteration=iteration)
    ) if training and not args.checkpoint_skip else None
    #if args.exphtml:
    #	columns['checkpoint_path'] = checkpoint_path
    #	exphtml.expjson(args.exphtml, args.experiment_id, epoch = epoch, iteration = iteration, meta = vars(args), columns = columns, tag = 'train' if training else 'test', git_http = args.githttp)
    #	exphtml.exphtml(args.exphtml)

    if training and not args.checkpoint_skip:
        torch.save(
            dict(model_state_dict=models.master_module(model).state_dict(),
                 optimizer_state_dict=optimizer.state_dict()
                 if optimizer is not None else None,
                 amp_state_dict=apex.amp.state_dict() if args.fp16 else None,
                 sampler_state_dict=sampler.state_dict()
                 if sampler is not None else None,
                 epoch=epoch,
                 iteration=iteration,
                 args=vars(args),
                 time=time.time(),
                 labels=[(l.name, str(l)) for l in labels]), checkpoint_path)

    model.train()