Ejemplo n.º 1
0
def set_seed(seed=None, use_cuda=None):
    """Seed the random number generators used for training.

    Does nothing when ``seed`` is None.

    Args:
        seed: integer seed, or None to leave all generators untouched.
        use_cuda: whether to also seed CUDA via ``torch.cuda.manual_seed``.
            Defaults to ``torch.cuda.is_available()`` so that no global
            argument namespace is required.
    """
    if seed is None:
        return
    import random

    logger.info(f"set random seed to {seed}")
    # Python's own generator is used by many data pipelines (e.g. shuffling)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_cuda is None:
        use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.manual_seed(seed)
Ejemplo n.º 2
0
def prepare(argv):
    """Import a Kaldi ASpIRE recipe output into ``target_dir``.

    ``argv`` is the argument list (without the program name). Logging goes to
    ``<target_dir>/prepare.log``. The train/dev/test splits are then handled
    by a ``KaldiAspireImporter`` in one of three modes:

    * default: ``importer.process`` (full processing),
    * ``--text-only``: ``importer.process_text_only`` (wavs already stored),
    * ``--rebuild``: ``importer.rebuild`` (manifests only).

    ``--text-only`` and ``--rebuild`` are mutually exclusive; argparse reports
    the conflict as a usage error.
    """
    parser = argparse.ArgumentParser(description="Prepare dataset by importing from Kaldi recipe")
    # The two options select different processing paths and cannot be combined.
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--text-only', default=False, action='store_true', help="if you want to process text only when wavs are already stored")
    group.add_argument('--rebuild', default=False, action='store_true', help="if you want to rebuild manifest only instead of the overall processing")
    parser.add_argument('target_dir', type=str, help="path to store the processed data")
    args = parser.parse_args(argv)

    target_path = Path(args.target_dir).resolve()
    # Log next to the prepared data rather than into the current working directory.
    log_file = target_path.joinpath('prepare.log')
    log_file.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
    init_logger(log_file=str(log_file))

    logger.info(f"target data path : {target_path}")

    importer = KaldiAspireImporter(target_path)

    if args.rebuild:
        run_split = importer.rebuild
    elif args.text_only:
        run_split = importer.process_text_only
    else:
        run_split = importer.process
    for mode in ("train", "dev", "test"):
        run_split(mode)

    logger.info("data preparation finished.")
Ejemplo n.º 3
0
 def get_transcripts(self, mode):
     """Split the Kaldi ``data/<mode>/text`` file into per-utterance .txt files.

     Each line is parsed as ``<uttid> <transcript>``. Lines without a
     transcript are skipped. The transcript is normalised with
     ``self.strip_text``, and utterances whose normalised text is empty are
     dropped. When normalisation changed the text, the original/normalised
     pair is appended to ``<target_path>/<mode>_convert.txt`` for inspection.

     Each transcript is written to ``<target_path>/<mode>/<prefix>/<uttid>.txt``.
     ``<prefix>`` is the part of ``uttid`` before the first ``-``. Without a
     ``-`` the file goes directly under ``<target_path>/<mode>``.

     Returns:
         dict mapping uttid -> (txt file path as str, normalised text).
     """
     texts_file = self.recipe_path.joinpath("data", mode, "text")
     logger.info(f"processing {str(texts_file)} file ...")
     manifest = dict()
     convert_log = self.target_path.joinpath(f"{mode}_convert.txt")
     with smart_open(texts_file, "r") as f, open(convert_log, "w") as wf:
         for line in tqdm(f,
                          total=get_num_lines(texts_file),
                          ncols=params.NCOLS):
             try:
                 uttid, text = line.strip().split(" ", 1)
             except ValueError:
                 # an utterance id without any transcript: nothing to store
                 continue
             managed_text = self.strip_text(text)
             if not managed_text:
                 continue
             if text != managed_text:
                 wf.write(f"{uttid} 0: {text}\n")
                 wf.write(f"{uttid} 1: {managed_text}\n\n")
             prefix, sep, _ = uttid.partition('-')
             if sep:
                 tar_path = self.target_path.joinpath(mode, prefix)
             else:
                 tar_path = self.target_path.joinpath(mode)
             tar_path.mkdir(mode=0o755, parents=True, exist_ok=True)
             txt_file = tar_path.joinpath(uttid + ".txt")
             with open(str(txt_file), "w") as txt:
                 txt.write(managed_text + "\n")
             manifest[uttid] = (str(txt_file), managed_text)
     return manifest
Ejemplo n.º 4
0
    def validate(self, data_loader):
        """Validate the model with the label error rate (LER).

        The LER is the accumulated edit distance between hypotheses and
        references (from ``self.unit_validate``) divided by the accumulated
        reference length, in percent. It is shown on the progress bar and
        logged. It is also plotted to visdom/tensorboard when those loggers
        are configured. An empty ``data_loader`` is reported and skipped.
        """
        self.model.eval()
        num_batches = len(data_loader)
        if num_batches == 0:
            # neither the LER nor the plot position is defined without batches
            logger.warning("validating skipped: empty data loader")
            return
        with torch.no_grad():
            N, D = 0, 0
            ler = 0.
            t = tqdm(enumerate(data_loader),
                     total=num_batches,
                     desc="validating",
                     ncols=p.NCOLS)
            for i, data in t:
                hyps, refs = self.unit_validate(data)
                # calculate ler
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                # guard against a zero total reference length (all refs empty so far)
                ler = N * 100. / D if D > 0 else 0.
                t.set_description(f"validating (LER: {ler:.2f} %)")
                t.refresh()
            logger.info(
                f"validating at epoch {self.epoch:03d}: LER {ler:.2f} %")

            title = "validate"
            x = self.epoch - 1 + i / num_batches
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'LER',
                }
                logger.visdom.add_point(title=title, x=x, y=ler, **opts)
            if logger.tensorboard is not None:
                logger.tensorboard.add_scalars(title, x, {
                    'LER': ler,
                })
Ejemplo n.º 5
0
    def split_wav(self, mode):
        """Cut the recordings of a Kaldi data dir into per-utterance wav files.

        Two files are read:

        * ``data/<mode>/segments``: ``<uttid> <wavid> <start> <end>`` lines.
          start/end are multiplied by the frame rate, so they are seconds.
        * ``data/<mode>/wav.scp``: ``<wavid> <command> |`` lines.

        For every recording that has segments, the command is run and its
        stdout is parsed as a wav stream. A leading ``sph2pipe`` is replaced
        by ``SPH2PIPE_PATH``. Each segment is widened by ``SAMPLE_MARGIN``
        frames on both sides and written to
        ``<target_path>/<mode>/<prefix>/<uttid>.wav``. ``<prefix>`` is the part
        of ``uttid`` before the first ``-``; there is no sub-directory when
        there is no ``-``.

        Segments whose widened range leaves the recording are skipped.
        Recordings whose command exits non-zero are logged and skipped.

        Returns:
            dict mapping uttid -> (wav file path as str, number of frames).
        """
        import io
        import wave

        segments_file = self.recipe_path.joinpath("data", mode, "segments")
        logger.info(f"processing {str(segments_file)} file ...")
        segments = dict()
        with smart_open(segments_file, "r") as f:
            for line in tqdm(f,
                             total=get_num_lines(segments_file),
                             ncols=params.NCOLS):
                split = line.strip().split()
                uttid, wavid = split[0], split[1]
                start, end = float(split[2]), float(split[3])
                segments.setdefault(wavid, []).append((uttid, start, end))

        wav_scp = self.recipe_path.joinpath("data", mode, "wav.scp")
        logger.info(f"processing {str(wav_scp)} file ...")
        manifest = dict()
        with smart_open(wav_scp, "r") as rf:
            for line in tqdm(rf,
                             total=get_num_lines(wav_scp),
                             ncols=params.NCOLS):
                wavid, cmd = line.strip().split(" ", 1)
                if wavid not in segments:
                    continue
                # turn the "<command> |" pipe spec into an argv list
                cmd = cmd.strip().rstrip(' |').split()
                if cmd[0] == 'sph2pipe':
                    cmd[0] = str(SPH2PIPE_PATH)
                proc = sp.run(cmd, stdout=sp.PIPE, stderr=sp.PIPE)
                if proc.returncode != 0:
                    err = proc.stderr.decode(errors="replace").strip()
                    logger.error(f"skipping {wavid}: `{' '.join(cmd)}` failed: {err}")
                    continue
                # wave.openfp() was removed in Python 3.9; wave.open() accepts file objects
                with wave.open(io.BytesIO(proc.stdout), "rb") as wav:
                    fr = wav.getframerate()
                    nf = wav.getnframes()
                    wav_params = wav.getparams()
                    for uttid, start, end in segments[wavid]:
                        fs = int(fr * start - SAMPLE_MARGIN)
                        fe = int(fr * end + SAMPLE_MARGIN)
                        if fs < 0 or fe > nf:
                            continue
                        wav.setpos(fs)  # absolute frame position
                        signal = wav.readframes(fe - fs)
                        prefix, sep, _ = uttid.partition('-')
                        if sep:
                            tar_path = self.target_path.joinpath(mode, prefix)
                        else:
                            tar_path = self.target_path.joinpath(mode)
                        tar_path.mkdir(mode=0o755, parents=True, exist_ok=True)
                        wav_file = tar_path.joinpath(uttid + ".wav")
                        with wave.open(str(wav_file), "wb") as wf:
                            wf.setparams(wav_params)
                            wf.writeframes(signal)
                        manifest[uttid] = (str(wav_file), fe - fs)
        return manifest
Ejemplo n.º 6
0
 def load(self, file_path):
     """Restore ``self.model`` from a checkpoint file read with ``torch.load``.

     Only the ``"model"`` entry of the checkpoint is used. Tensors are mapped
     to the current CUDA device when ``self.use_cuda`` is set, otherwise to the
     CPU. If ``file_path`` does not exist, an error is logged and the process
     exits with status 1.

     Args:
         file_path: checkpoint location as ``str`` or ``pathlib.Path``.
     """
     path = Path(file_path) if isinstance(file_path, str) else file_path
     if not path.exists():
         logger.error(f"no such file {path} exists")
         sys.exit(1)
     logger.info(f"loading the model from {path}")
     if self.use_cuda:
         device = f"cuda:{torch.cuda.current_device()}"
     else:
         device = "cpu"
     checkpoint = torch.load(path, map_location=device)
     self.model.load_state_dict(checkpoint["model"])
Ejemplo n.º 7
0
 def process_text_only(self, mode):
     """Regenerate transcripts and the manifest while reusing stored wavs.

     Every ``*.wav`` below ``<target_path>/<mode>`` is listed with its frame
     count and keyed by file stem (the uttid). Transcripts are re-extracted
     with ``self.get_transcripts(mode)``. Both are then passed to
     ``self.make_manifest``.
     """
     import wave
     logger.info(f"processing text only from \"{mode}\" ...")
     wav_manifest = dict()
     for wav_file in self.target_path.joinpath(mode).rglob("*.wav"):
         # wave.openfp() was removed in Python 3.9; wave.open() is the equivalent
         with wave.open(str(wav_file), "rb") as wav:
             samples = wav.getnframes()
         wav_manifest[wav_file.stem] = (str(wav_file), samples)
     txt_manifest = self.get_transcripts(mode)
     self.make_manifest(mode, wav_manifest, txt_manifest)
Ejemplo n.º 8
0
 def rebuild(self, mode):
     """Rebuild the ``mode`` manifest from files already under ``target_path``.

     Every ``*.wav`` below ``<target_path>/<mode>`` is listed with its frame
     count. A sibling ``.txt`` file with the same stem, if present, is used
     as its transcript. Its text is recorded as ``'-'`` because it is not
     read here. Both dicts are passed to ``self.make_manifest``.
     """
     import wave
     logger.info(f"rebuilding \"{mode}\" ...")
     wav_manifest, txt_manifest = dict(), dict()
     for wav_file in self.target_path.joinpath(mode).rglob("*.wav"):
         uttid = wav_file.stem
         # wave.openfp() was removed in Python 3.9; wave.open() is the equivalent
         with wave.open(str(wav_file), "rb") as wav:
             samples = wav.getnframes()
         wav_manifest[uttid] = (str(wav_file), samples)
         # swap only the extension: replacing every "wav" substring would also
         # rewrite directory names containing "wav"
         txt_file = wav_file.with_suffix(".txt")
         if txt_file.exists():
             txt_manifest[uttid] = (str(txt_file), '-')
     self.make_manifest(mode, wav_manifest, txt_manifest)
Ejemplo n.º 9
0
 def make_manifest(self, mode, wav_manifest, txt_manifest):
     """Write ``<target_path>/<mode>.csv`` and log utterance duration statistics.

     One CSV row ``uttid,wav_file,samples,txt_file`` is written for each
     utterance present in both manifests. Durations are ``samples`` divided by
     ``params.SAMPLE_RATE``. The log reports their min/max and the cumulative
     percentage of utterances per 5-second bucket up to 30 seconds.

     Args:
         mode: split name, used as the CSV file stem.
         wav_manifest: dict uttid -> (wav file path, number of samples).
         txt_manifest: dict uttid -> (txt file path, text); the text is unused.
     """
     logger.info(f"generating manifest to \"{mode}.csv\" ...")
     min_len, max_len = 1e30, 0
     # histo[k] counts utterances with ceil(duration) == k, durations < 30 s
     histo = [0] * 31
     total = 0
     with open(self.target_path.joinpath(f"{mode}.csv"), "w") as f:
         for k, v in tqdm(wav_manifest.items(), ncols=params.NCOLS):
             if k not in txt_manifest:
                 continue
             wav_file, samples = v
             txt_file, _ = txt_manifest[k]
             f.write(f"{k},{wav_file},{samples},{txt_file}\n")
             total += 1
             sec = float(samples) / params.SAMPLE_RATE
             min_len = min(min_len, sec)
             max_len = max(max_len, sec)
             if sec < 30.:
                 histo[int(np.ceil(sec))] += 1
     logger.info(f"total {total} entries listed in the manifest file.")
     if total == 0:
         # no statistics to report; also avoids dividing the histogram by zero
         logger.warning("no utterance has both a wav and a transcript; skipping statistics.")
         return
     cum_histo = np.cumsum(histo) / total * 100.
     logger.info(f"min: {min_len:.2f} sec  max: {max_len:.2f} sec")
     logger.info(f"<5 secs: {cum_histo[5]:.2f} %  "
                 f"<10 secs: {cum_histo[10]:.2f} %  "
                 f"<15 secs: {cum_histo[15]:.2f} %  "
                 f"<20 secs: {cum_histo[20]:.2f} %  "
                 f"<25 secs: {cum_histo[25]:.2f} %  "
                 f"<30 secs: {cum_histo[30]:.2f} %")
Ejemplo n.º 10
0
 def make_ctc_labels(self):
     """Derive CTC target files from the frame-level ``*.phn`` label files.

     For each ``*.phn`` file under ``self.target_path``, a sibling ``*.ctc``
     file is written. It holds the labels with duplicates collapsed by
     ``remove_duplicates``. Label counts over the same files are then
     computed with ``self.count_priors``.
     """
     # find *.phn files
     logger.info(f"finding *.phn files under {str(self.target_path)}")
     phn_files = [str(x) for x in self.target_path.rglob("*.phn")]
     for phn_file in tqdm(phn_files, ncols=params.NCOLS):
         phns = np.loadtxt(phn_file, dtype="int", ndmin=1)
         # make ctc labelings by removing duplications.
         # blank labels will be inserted in warp-ctc loss module,
         # so here the target labels must not contain interleaved blanks
         ctcs = np.array(list(remove_duplicates(phns)))
         # change only the extension; replacing every "phn" substring would
         # also rewrite directory names containing "phn"
         ctc_file = Path(phn_file).with_suffix(".ctc")
         np.savetxt(str(ctc_file), ctcs, "%d")
     # count_priors is a method of this importer, not a module-level function
     self.count_priors(phn_files)
Ejemplo n.º 11
0
 def print_result(self, filename, ys_hat, words):
     """Log the decoding result for ``filename``.

     In verbose mode, three extra lines are logged:

     * the frame labels from ``onehot2int(ys_hat)``;
     * the labels after ``remove_duplicates(..., blank=0)``;
     * their phone symbols.

     ``words`` is a tensor of word indices of any shape. It is flattened and
     mapped to text with ``self.decoder.labeler.idx2word``. An empty tensor is
     reported as null output.
     """
     logger.info(f"decoding wav file: {str(Path(filename).resolve())}")
     if self.verbose:
         labels = onehot2int(ys_hat)
         logger.info(
             f"labels: {' '.join([str(x) for x in labels.tolist()])}")
         rd = [x.item() for x in remove_duplicates(labels, blank=0)]
         logger.info(
             f"duplicated_removed: {' '.join([str(x) for x in rd])}")
         symbols = [self.decoder.labeler.idx2phone(x) for x in rd]
         logger.info(f"symbols: {' '.join(symbols)}")
     # flatten rather than squeeze: squeezing a single-word result yields a
     # 0-dim tensor, which cannot be iterated and would be mistaken for "no output"
     words = words.reshape(-1)
     if words.numel() > 0:
         text = ' '.join([self.decoder.labeler.idx2word(i) for i in words])
     else:
         text = '<null output from decoder>'
     logger.info(f"decoded text: {text}")
Ejemplo n.º 12
0
 def test(self, data_loader):
     """Evaluate the model with the word error rate (WER).

     The WER is the accumulated edit distance between hypotheses and
     references (from ``self.unit_test``) divided by the accumulated reference
     length, in percent. It is shown on the progress bar and logged. An
     empty ``data_loader`` is reported and skipped.
     """
     self.model.eval()
     num_batches = len(data_loader)
     if num_batches == 0:
         # no WER is defined without batches
         logger.warning("testing skipped: empty data loader")
         return
     with torch.no_grad():
         N, D = 0, 0
         wer = 0.
         t = tqdm(enumerate(data_loader), total=num_batches, desc="testing", ncols=params.NCOLS)
         for i, data in t:
             hyps, refs = self.unit_test(data)
             # calculate wer
             N += self.edit_distance(refs, hyps)
             D += sum(len(r) for r in refs)
             # guard against a zero total reference length (all refs empty so far)
             wer = N * 100. / D if D > 0 else 0.
             t.set_description(f"testing (WER: {wer:.2f} %)")
             t.refresh()
         logger.info(f"testing at epoch {self.epoch:03d}: WER {wer:.2f} %")
Ejemplo n.º 13
0
    def __init__(self, model, use_cuda=False, continue_from=None, verbose=False,
                 *args, **kwargs):
        """Set up decoding with a trained model.

        Args:
            model: the network to decode with; ``.cuda()`` is called on it
                when ``use_cuda`` is set.
            use_cuda: whether to run the model on CUDA.
            continue_from: checkpoint path passed to ``self.load``; required.
            verbose: stored as ``self.verbose``.
            *args, **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``continue_from`` is None.
        """
        # explicit check rather than `assert`, which is stripped under `python -O`
        if continue_from is None:
            raise ValueError("continue_from must be given: a trained model checkpoint is required")
        self.use_cuda = use_cuda
        self.verbose = verbose

        # load from args
        self.model = model
        if self.use_cuda:
            logger.info("using cuda")
            self.model.cuda()

        self.load(continue_from)

        # prepare kaldi latgen decoder
        self.decoder = LatGenCTCDecoder()
Ejemplo n.º 14
0
 def save(self, file_path, **kwargs):
     """Write a training checkpoint to ``file_path`` with ``torch.save``.

     The saved dict contains every extra keyword argument plus these keys:

     * ``"epoch"``;
     * ``"model"`` (the model state dict);
     * ``"optimizer"`` (the optimizer state dict);
     * ``"lr_scheduler"`` (its state dict, or None when no scheduler is set).

     Parent directories are created as needed.
     """
     Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
     logger.info(f"saving the model to {file_path}")
     states = kwargs
     states["epoch"] = self.epoch
     if is_distributed():
         model_state_dict = self.model.state_dict()
         # remove the wrapper prefix from every key:
         # "module." (7 chars), or "module.1." (9 chars) when fp16 is enabled
         strip_prefix = 9 if self.fp16 else 7
         states["model"] = {
             k[strip_prefix:]: v
             for k, v in model_state_dict.items()
         }
     else:
         states["model"] = self.model.state_dict()
     states["optimizer"] = self.optimizer.state_dict()
     # self.lr_scheduler may be None when no scheduler is configured
     states["lr_scheduler"] = (self.lr_scheduler.state_dict()
                               if self.lr_scheduler is not None else None)
     torch.save(states, file_path)
Ejemplo n.º 15
0
    def get_alignments(self, mode=None):
        """Convert Kaldi ``exp/tri5a`` alignments into per-utterance ``*.phn`` files.

        Every ``exp/tri5a/ali.*.gz`` archive is decompressed and piped through
        Kaldi's ``ali-to-phones --per-frame``. The most recently modified
        ``exp/tri5a/*.mdl`` is used as the model. The per-frame phone ids of
        each utterance are saved with ``np.savetxt`` to
        ``<base>/<prefix>/<uttid>.phn``, where:

        * ``<base>`` is ``<target_path>/<mode>``, or ``<target_path>`` when
          ``mode`` is None;
        * ``<prefix>`` is the part of ``uttid`` before the first ``-``.

        Archives whose conversion command fails are logged and skipped.

        Args:
            mode: split sub-directory of ``target_path`` to write into;
                presumably it should be the split the alignments were made for.

        Returns:
            dict mapping uttid -> (phn file path as str, number of frames).

        Raises:
            FileNotFoundError: if ``exp/tri5a`` contains no ``*.mdl`` model.
        """
        import gzip
        import io

        exp_dir = self.recipe_path.joinpath("exp", "tri5a").resolve()
        models = list(exp_dir.glob("*.mdl"))
        if not models:
            raise FileNotFoundError(f"no trained kaldi model (*.mdl) found in {exp_dir}")
        model = sorted(models, key=lambda x: x.stat().st_mtime)[-1]

        logger.info("processing alignment files ...")
        logger.info(f"using the trained kaldi model: {model}")
        base_path = self.target_path if mode is None else self.target_path.joinpath(mode)
        cmd = [
            str(Path(KALDI_PATH, "src", "bin", "ali-to-phones")),
            "--per-frame", str(model), "ark:-", "ark,f:-"
        ]
        manifest = dict()
        alis = list(exp_dir.glob("ali.*.gz"))
        for ali in tqdm(alis, ncols=params.NCOLS):
            with gzip.GzipFile(ali, "rb") as a:
                ali_data = a.read()
            proc = sp.run(cmd, stdout=sp.PIPE, stderr=sp.PIPE, input=ali_data)
            if proc.returncode != 0:
                err = proc.stderr.decode(errors="replace").strip()
                logger.error(f"skipping {ali}: ali-to-phones failed: {err}")
                continue
            with io.BytesIO(proc.stdout) as f:
                while True:
                    # read_string raises ValueError once the archive is exhausted
                    try:
                        uttid = read_string(f)
                    except ValueError:
                        break
                    prefix, sep, _ = uttid.partition('-')
                    tar_dir = base_path.joinpath(prefix) if sep else base_path
                    tar_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
                    # store phn file
                    phn_file = tar_dir.joinpath(uttid + ".phn")
                    phones = read_vec_int(f)
                    np.savetxt(str(phn_file), phones, "%d")
                    # prepare manifest elements
                    manifest[uttid] = (str(phn_file), len(phones))
        return manifest
Ejemplo n.º 16
0
 def count_priors(self, phn_files=None, labels_file='asr/kaldi/graph/labels.txt'):
     """Count label occurrences over ``*.phn`` files for CTC prior estimation.

     ``labels_file`` holds one ``<label> <id>`` pair per line and must define
     ``<blk>``. Every label id in each phn file is counted. The blank id also
     receives ``len(phns) + 1`` counts per file, presumably the n + 1 blanks
     interleaved around n labels in CTC. The per-id counts are written one
     per line to ``<target_path>/priors_count.txt``.

     Args:
         phn_files: phn file paths; defaults to every ``*.phn`` under
             ``self.target_path``.
         labels_file: label table path. The default is resolved relative to
             the current working directory.

     Raises:
         KeyError: if the label table has no ``<blk>`` entry.
         ValueError: if a phn file contains a negative id or one larger than
             any id in the label table.
     """
     # load labels.txt
     labels = dict()
     with open(labels_file, 'r') as f:
         for line in f:
             splits = line.split()
             if not splits:
                 continue  # tolerate blank lines
             labels[splits[0]] = int(splits[1])
     blank = labels['<blk>']
     if phn_files is None:
         # find *.phn files
         logger.info(f"finding *.phn files under {str(self.target_path)}")
         phn_files = [str(x) for x in self.target_path.rglob("*.phn")]
     # size by the largest id as well, so non-contiguous ids still fit
     num_labels = max(len(labels), max(labels.values()) + 1)
     counts = np.zeros(num_labels, dtype=np.int64)
     for phn_file in tqdm(phn_files, ncols=params.NCOLS):
         phns = np.loadtxt(phn_file, dtype="int", ndmin=1)
         # bincount rejects negative ids, which list indexing would silently wrap
         file_counts = np.bincount(phns, minlength=num_labels)
         if len(file_counts) > num_labels:
             raise ValueError(f"{phn_file} contains label ids not in {labels_file}")
         counts += file_counts
         counts[blank] += len(phns) + 1
     # write count file
     count_file = self.target_path.joinpath("priors_count.txt")
     np.savetxt(str(count_file), counts, "%d")
Ejemplo n.º 17
0
    def train_epoch(self, data_loader):
        """Train the model for one epoch over ``data_loader``.

        Before the epoch:

        * ``self.lr_scheduler`` (if any) is stepped once;
        * the distributed sampler's epoch is set.

        ``self.unit_train`` is then run on every batch, tracking a moving
        average of the returned loss (None results are ignored). Each time
        the batch index passes roughly another 10% of the epoch:

        * the loss is plotted;
        * if ``self.checkpoint`` is set, an intermediate checkpoint is saved
          (by rank 0 only when distributed) and
          ``self.train_loop_checkpoint_hook()`` is called.

        At the end:

        * ``self.epoch`` is incremented and the epoch model is saved;
        * ``__remove_ckpt_files(self.epoch - 1)`` is called, presumably to
          delete the finished epoch's intermediate checkpoints;
        * the loss and parameter histograms are plotted.

        NOTE(review): ``plot_graphs`` divides by ``len(data_loader)``, so an
        empty loader would raise ZeroDivisionError.
        """
        self.model.train()
        # moving-average window of roughly 1% of the epoch (at least 1 batch)
        meter_loss = tnt.meter.MovingAverageValueMeter(
            len(data_loader) // 100 + 1)
        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(p.NUM_CTC_LABELS, normalized=True)

        # the scheduler is stepped once per epoch, before any training step
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
            logger.debug(
                f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)

        # batch-index thresholds at 10%, 20%, ... of the epoch
        ckpt_step = 0.1
        ckpts = iter(
            len(data_loader) * np.arange(ckpt_step, 1 + ckpt_step, ckpt_step))

        def plot_graphs(loss, data_iter=0, title="train", stats=False):
            """Plot ``loss`` at fractional epoch ``self.epoch + data_iter / len``.

            Also records the position as ``self.global_step`` (in units of
            ``ckpt_step``). With ``stats``, a tensorboard histogram of every
            model parameter is added.
            """
            #if self.lr_scheduler is not None:
            #    self.lr_scheduler.step()
            x = self.epoch + data_iter / len(data_loader)
            self.global_step = int(x / ckpt_step)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'loss',
                }
                logger.visdom.add_point(title=title, x=x, y=loss, **opts)
            if logger.tensorboard is not None:
                #logger.tensorboard.add_graph(self.model, xs)
                #xs_img = tvu.make_grid(xs[0, 0], normalize=True, scale_each=True)
                #logger.tensorboard.add_image('xs', self.global_step, xs_img)
                #ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1), normalize=True, scale_each=True)
                #logger.tensorboard.add_image('ys_hat', self.global_step, ys_hat_img)
                logger.tensorboard.add_scalars(title, self.global_step, {
                    'loss': loss,
                })
                if stats:
                    for name, param in self.model.named_parameters():
                        logger.tensorboard.add_histogram(
                            name, self.global_step,
                            param.clone().cpu().data.numpy())

        self.train_loop_before_hook()
        ckpt = next(ckpts)
        t = tqdm(enumerate(data_loader),
                 total=len(data_loader),
                 desc="training",
                 ncols=p.NCOLS)
        for i, (data) in t:
            loss_value = self.unit_train(data)
            # unit_train may return None; such batches do not enter the average
            if loss_value is not None:
                meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)
            # one plot/checkpoint per crossed threshold; the threshold advances
            # by a single step each time it is crossed
            if i > ckpt:
                plot_graphs(meter_loss.value()[0], i)
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    # only one process writes checkpoints when distributed
                    if not is_distributed() or (is_distributed()
                                                and dist.get_rank() == 0):
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
                    self.train_loop_checkpoint_hook()
                ckpt = next(ckpts)

        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
        #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)
        # plotted at x = the already-incremented epoch number
        plot_graphs(meter_loss.value()[0], stats=True)
        self.train_loop_after_hook()
Ejemplo n.º 18
0
    def train_epoch(self, data_loader):
        """Train the model for one epoch over ``data_loader``.

        Before the epoch:

        * ``self.lr_scheduler`` (if any) is stepped once;
        * the distributed sampler's epoch is set.

        ``self.unit_train`` is then run on every batch, tracking a moving
        average of the returned loss (None results are ignored). Each time the
        batch index passes roughly another 10% of the epoch:

        * the loss is plotted;
        * if ``self.checkpoint`` is set, an intermediate checkpoint is saved
          (by rank 0 only when distributed).

        At the end:

        * the loss is plotted once more and ``self.epoch`` is incremented;
        * the epoch model is saved;
        * ``__remove_ckpt_files(self.epoch - 1)`` is called.
        """
        self.model.train()
        num_batches = len(data_loader)
        # moving-average window of roughly 1% of the epoch (at least 1 batch)
        meter_loss = tnt.meter.MovingAverageValueMeter(num_batches // 100 + 1)

        def plot_scalar(i, loss, title="train"):
            # x axis is the fractional epoch; an empty loader plots at the epoch start
            x = self.epoch + (i / num_batches if num_batches else 0.)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'loss',
                }
                logger.visdom.add_point(title=title, x=x, y=loss, **opts)
            if logger.tensorboard is not None:
                # scalar only: no input/output batch tensors are in scope here
                # to build graph or image summaries from
                logger.tensorboard.add_scalars(title, x, {
                    'loss': loss,
                })

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        logger.debug(
            f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)
        # batch-index thresholds at 10%, 20%, ... of the epoch
        ckpts = iter(num_batches * np.arange(0.1, 1.1, 0.1))
        ckpt = next(ckpts)
        self.train_loop_before_hook()
        # only one process writes checkpoints when distributed
        is_master = not is_distributed() or dist.get_rank() == 0
        t = tqdm(enumerate(data_loader),
                 total=num_batches,
                 desc="training",
                 ncols=p.NCOLS)
        i = 0  # keeps the final plot well-defined for an empty loader
        for i, data in t:
            loss_value = self.unit_train(data)
            if loss_value is not None:
                meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            if i > ckpt:
                plot_scalar(i, meter_loss.value()[0])
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if is_master:
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
                ckpt = next(ckpts)

        plot_scalar(i, meter_loss.value()[0])
        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
        if is_master:
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)
        self.train_loop_after_hook()
Ejemplo n.º 19
0
    def train_epoch(self, data_loader):
        """Train the model for one epoch over ``data_loader``.

        Before the epoch:

        * ``self.lr_scheduler`` (if any) is stepped once;
        * the distributed sampler's epoch is set.

        ``self.unit_train`` is then run on every batch, tracking a moving
        average of the returned loss (None results are ignored). Every
        ``ceil(len(data_loader) / 10)`` batches:

        * the loss is plotted (by rank 0 only when distributed);
        * if ``self.checkpoint`` is set, an intermediate checkpoint is saved.

        At the end:

        * ``self.epoch`` is incremented and the epoch model is saved;
        * ``__remove_ckpt_files(self.epoch - 1)`` is called.
        """
        self.model.train()
        num_batches = len(data_loader)
        # plot/checkpoint interval in batches: about 10% of the epoch
        num_ckpt = int(np.ceil(num_batches / 10))
        # moving-average window of roughly 1% of the epoch (at least 1 batch)
        meter_loss = tnt.meter.MovingAverageValueMeter(num_batches // 100 + 1)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
            # get_last_lr() (torch >= 1.4) reports the rate in use; calling
            # get_lr() outside step() is deprecated there
            get_lr = getattr(self.lr_scheduler, "get_last_lr", self.lr_scheduler.get_lr)
            logger.debug(f"current lr = {get_lr()}")
        if is_distributed() and data_loader.sampler is not None:
            data_loader.sampler.set_epoch(self.epoch)
        # only one process plots and writes checkpoints when distributed
        is_master = not is_distributed() or dist.get_rank() == 0

        t = tqdm(enumerate(data_loader),
                 total=num_batches,
                 desc="training")
        for i, data in t:
            loss_value = self.unit_train(data)
            # skip batches for which unit_train reports no loss
            if loss_value is not None:
                meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()

            if i > 0 and i % num_ckpt == 0:
                if is_master:
                    title = "train"
                    x = self.epoch + i / num_batches
                    loss = meter_loss.value()[0]
                    if logger.visdom is not None:
                        logger.visdom.add_point(title=title, x=x, y=loss)
                    if logger.tensorboard is not None:
                        # scalar only: no input/output batch tensors are in
                        # scope here to build graph or image summaries from
                        logger.tensorboard.add_scalars(title, x, {
                            'loss': loss,
                        })
                if self.checkpoint:
                    logger.info(
                        f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                        f"{meter_loss.value()[0]:5.3f}")
                    if is_master:
                        self.save(
                            self.__get_model_name(
                                f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))

        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
        if is_master:
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)
Ejemplo n.º 20
0
 def process(self, mode):
     """Run the complete import for one split.

     Per-utterance wavs from ``self.split_wav(mode)`` and transcripts from
     ``self.get_transcripts(mode)`` are produced in that order. Both are then
     handed to ``self.make_manifest``.
     """
     logger.info(f"processing \"{mode}\" ...")
     self.make_manifest(mode, self.split_wav(mode), self.get_transcripts(mode))