Example no. 1
def _compute_mmi_loss_exact_non_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    It's more readable, though it invokes k2.intersect_dense twice.

    Note:
      It uses less memory than :func:`_compute_mmi_loss_exact_optimized`, but it is slower.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=10.0)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
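
A minimal, runnable sketch of the inputs consumed above; the shapes and values are illustrative assumptions, not taken from any recipe. k2.DenseFsaVec takes the network log-probabilities as (N, T, C) plus one int32 row per supervision.

import k2
import torch

# Fake log-probabilities for 2 sequences, 8 frames, 5 classes: shape (N, T, C).
nnet_output = torch.randn(2, 8, 5).log_softmax(dim=-1)

# One row per supervision: (sequence_index, start_frame, num_frames),
# int32, sorted by decreasing num_frames.
supervision_segments = torch.tensor([[0, 0, 8],
                                     [1, 0, 6]], dtype=torch.int32)

dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
print(dense_fsa_vec.dim0())  # 2: one dense FSA per supervision segment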
Example no. 2
    def __init__(
        self,
        lang_dir: Pathlike,
        scripted_model_path: Optional[Pathlike] = None,
        model_dir: Optional[Pathlike] = None,
        average_epochs: Sequence[int] = (7, 8, 9),
        device: torch.device = 'cpu',
        sampling_rate: int = 16000,
    ):
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

        self.sampling_rate = sampling_rate
        self.extractor = Fbank(FbankConfig(num_mel_bins=80))
        self.lexicon = Lexicon(lang_dir)
        phone_ids = self.lexicon.phone_symbols()
        self.P = create_bigram_phone_lm(phone_ids)

        if model_dir is not None:
            # Read model from regular checkpoints, assume it's a Conformer
            self.model = Conformer(num_features=80,
                                   num_classes=len(phone_ids) + 1,
                                   num_decoder_layers=0)
            self.P.scores = torch.zeros_like(self.P.scores)
            self.model.P_scores = torch.nn.Parameter(self.P.scores.clone(),
                                                     requires_grad=False)
            average_checkpoint(filenames=[
                model_dir / f'epoch-{n}.pt' for n in average_epochs
            ],
                               model=self.model)
        elif scripted_model_path is not None:
            # Read model from a serialized TorchScript module, no assumptions needed
            self.model = torch.jit.load(scripted_model_path)
        else:
            raise ValueError(
                "One of scripted_model_path or model_dir needs to be provided."
            )

        # Freeze the params by default.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.compiler = MmiTrainingGraphCompiler(lexicon=self.lexicon,
                                                 device=self.device)
        self.HLG = k2.Fsa.from_dict(torch.load(lang_dir / 'HLG.pt')).to(
            self.device)
Example no. 3
def _compute_mmi_loss_pruned(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    `pruned` means it uses k2.intersect_dense_pruned to compute the denominator lattice.

    Note:
      It uses the least amount of memory, but the loss is not exact due
      to pruning.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)

    # the values for search_beam/output_beam/min_active_states/max_active_states
    # are not tuned. You may want to tune them.
    den_lats = k2.intersect_dense_pruned(den_graphs,
                                         dense_fsa_vec,
                                         search_beam=20.0,
                                         output_beam=7.0,
                                         min_active_states=30,
                                         max_active_states=10000)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
Example no. 4
def _compute_mmi_loss_exact_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    The function name contains `exact`, which means it uses a version of
    intersection without pruning.

    `optimized` in the function name means this function is optimized
    in that it calls k2.intersect_dense only once.

    Note:
      It is faster at the cost of using more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C]
      texts:
        The transcript. Each element consists of space(s) separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
      graph_compiler:
        Used to build num_graphs and den_graphs
      P:
        Represents a bigram Fsa.
      den_scale:
        The scale applied to the denominator tot_scores.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # the aux_labels of num_graphs is k2.RaggedInt
    # but it is torch.Tensor for den_graphs.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.cat() later
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # The motivation to concatenate num_graphs and den_graphs
    # is to reduce the number of calls to k2.intersect_dense.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.
    #
    # The following code computes a_to_b_map

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
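
The index bookkeeping above is easier to see on a toy case. A torch-only sketch (num_fsas = 3 is an arbitrary choice) that reproduces the interleaving and the a_to_b_map:

import torch

num_fsas = 3
num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)
den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

# Interleave the indexes: num graph i is followed by the (single, shared) den graph.
num_den_graphs_indexes = torch.stack(
    [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1)
print(num_den_graphs_indexes)  # tensor([0, 3, 1, 3, 2, 3], dtype=torch.int32)

# a_to_b_map pairs both graphs of utterance i with dense FSA i.
a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1)
print(a_to_b_map)  # tensor([0, 0, 1, 1, 2, 2], dtype=torch.int32)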
Example no. 5
class ASR:
    """
    This class is a high-level wrapper for K2 acoustic models that simplifies inference:
    reading models, computing posteriors, decoding, alignments, etc.

    Currently it only works with the Conformer model and a very specific HMM topology.
    It could be the basis for a more generic entry point to Snow(Ice?)fall.
    """
    def __init__(
        self,
        lang_dir: Pathlike,
        scripted_model_path: Optional[Pathlike] = None,
        model_dir: Optional[Pathlike] = None,
        average_epochs: Sequence[int] = (7, 8, 9),
        device: torch.device = 'cpu',
        sampling_rate: int = 16000,
    ):
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

        self.sampling_rate = sampling_rate
        self.extractor = Fbank(FbankConfig(num_mel_bins=80))
        self.lexicon = Lexicon(lang_dir)
        phone_ids = self.lexicon.phone_symbols()
        self.P = create_bigram_phone_lm(phone_ids)

        if model_dir is not None:
            # Read model from regular checkpoints, assume it's a Conformer
            self.model = Conformer(num_features=80,
                                   num_classes=len(phone_ids) + 1,
                                   num_decoder_layers=0)
            self.P.scores = torch.zeros_like(self.P.scores)
            self.model.P_scores = torch.nn.Parameter(self.P.scores.clone(),
                                                     requires_grad=False)
            average_checkpoint(filenames=[
                model_dir / f'epoch-{n}.pt' for n in average_epochs
            ],
                               model=self.model)
        elif scripted_model_path is not None:
            # Read model from a serialized TorchScript module, no assumptions needed
            self.model = torch.jit.load(scripted_model_path)
        else:
            raise ValueError(
                "One of scripted_model_path or model_dir needs to be provided."
            )

        # Freeze the params by default.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.compiler = MmiTrainingGraphCompiler(lexicon=self.lexicon,
                                                 device=self.device)
        self.HLG = k2.Fsa.from_dict(torch.load(lang_dir / 'HLG.pt')).to(
            self.device)

    def compute_features(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[
            0].sampling_rate == self.sampling_rate, f'{cuts[0].sampling_rate} != {self.sampling_rate}'
        otf = OnTheFlyFeatures(self.extractor)
        # feats: (batch, seq_len, n_feats)
        feats, _ = otf(cuts)
        return feats

    def compute_posteriors(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Run the forward pass of the acoustic model and return a tensor representing a batch of phone posteriorgrams.
        """
        # Extract feats
        # (batch, seq_len, num_feats)
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[
            0].sampling_rate == self.sampling_rate, f'{cuts[0].sampling_rate} != {self.sampling_rate}'
        otf = OnTheFlyFeatures(self.extractor)
        # feats: (batch, seq_len, n_feats)
        feats, _ = otf(cuts)
        # feats: (batch, n_feats, seq_len)
        feats = feats.permute(0, 2, 1)

        # Compute AM posteriors
        # posteriors: (batch, n_phones, ~seq_len / 4)
        posteriors, _, _ = self.model(feats)
        # returns: (batch, ~seq_len / 4, n_phones)
        return posteriors.permute(0, 2, 1)

    def decode(
            self, cuts: Union[AnyCut,
                              CutSet]) -> List[Tuple[List[str], List[str]]]:
        """
        Perform decoding with an n-gram language model (HLG graph).
        Doesn't support rescoring at this time.
        """
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        word_results = []
        # Hacky way to get batch quickly... we may need to improve on this.
        batch = K2SpeechRecognitionDataset(cuts,
                                           input_strategy=OnTheFlyFeatures(
                                               self.extractor),
                                           check_inputs=False)[list(cuts.ids)]
        features = batch['inputs'].permute(0, 2, 1).to(
            self.device)  # (B, T, F) -> (B, F, T)
        supervision_segments, texts = encode_supervisions(
            batch['supervisions'])

        # Forward pass through the acoustic model
        posteriors, _, _ = self.model(features)
        posteriors = posteriors.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

        # Wrapping into k2 "dense FSA" (representing PPG as a dense graph)
        dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

        # The actual decoding starts here:
        # First, we intersect the HLG and the PPG
        # with default pruning/beam search params from snowfall
        # The result is a batch of graphs (lattices)
        lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8,
                                             30, 10000)
        # ... then we find the shortest paths in the lattices ...
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        # ... and convert them to words with a convenience wrapper from snowfall
        hyps = get_texts(best_paths, torch.arange(len(texts)))

        # Here we read out the words from the best path graphs
        for i in range(len(texts)):
            hyp_words = [self.lexicon.words.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            word_results.append((ref_words, hyp_words))
        return word_results

    def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Perform forced alignment and return a tensor that represents a batch of frame-level alignments:
        >>> alignments = torch.tensor([
        ...     [0, 0, 0, 1, 57, 57, 35, 35, 35, ...],
        ...     [...],
        ...     ...
        ... ])

        :return: an int32 tensor with shape ``(batch_size, num_frames)``.
        """
        # Extract feats
        # (batch, seq_len, num_feats)
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[
            0].sampling_rate == self.sampling_rate, f'{cuts[0].sampling_rate} != {self.sampling_rate}'

        cuts = cuts.map_supervisions(self.normalize_text)

        otf = OnTheFlyFeatures(self.extractor)
        feats, _ = otf(cuts)
        feats = feats.permute(0, 2, 1)
        texts = [' '.join(s.text for s in cut.supervisions) for cut in cuts]

        # Compute AM posteriors
        # (batch, num_phones, ~seq_len / 4)
        posteriors, _, _ = self.model(feats)
        # Note: we are using "dummy" supervisions so that the aligner also considers
        # the padding area. We can adjust that behaviour if needed by passing actual
        # supervision segments, but then we will have a ragged tensor (will need to
        # pad the alignments themselves).
        sups = self.dummy_supervisions(feats)
        posteriors_fsa = k2.DenseFsaVec(posteriors.permute(0, 2, 1), sups)

        # Intersection with ground truth transcript graphs
        num, den = self.compiler.compile(texts, self.P)
        alignment = k2.intersect_dense(num, posteriors_fsa, output_beam=10.0)
        best_path = k2.shortest_path(alignment, use_double_scores=True)

        # Retrieve sequences of phone IDs per frame
        # (batch, seq_len ~/ 4) -- dtype int32 (num phone labels)
        frame_labels = torch.stack(
            [best_path[i].labels[:-1] for i in range(best_path.shape[0])])
        return frame_labels

    def align_ctm(self, cuts: Union[CutSet,
                                    AnyCut]) -> List[List[AlignmentItem]]:
        """
        Perform forced alignment and parse the phones into a CTM-like format:
            >>> [[0.0, 0.12, 'SIL'], [0.12, 0.2, 'AH0'], ...]
        """
        # TODO: I am not sure that this method is extracting the alignment 100% correctly:
        #       need to revise...
        # TODO: when K2/Snowfall has a standard way of indicating what is silence,
        #       or we update the model, update the constants below.
        EPS = 0
        SIL = 1
        non_speech = {EPS, SIL}

        def to_s(n: int) -> float:
            FRAME_SHIFT = 0.04  # 0.01 * 4 subsampling
            return round(n * FRAME_SHIFT, ndigits=3)

        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])

        # Uppercase and remove punctuation
        cuts = cuts.map_supervisions(self.normalize_text)
        alignments = self.align(cuts).tolist()

        ctm_alis = []
        for cut, alignment in zip(cuts, alignments):
            # First we determine the silence regions at the beginning and the end:
            # we assume that every SIL and <eps> before the first phone and after
            # the last phone represents silence.
            first_speech_idx = [
                idx for idx, s in enumerate(alignment) if s not in non_speech
            ][0]
            last_speech_idx = [
                idx for idx, s in reversed(list(enumerate(alignment)))
                if s not in non_speech
            ][0]
            speech_ali = alignment[first_speech_idx:last_speech_idx]
            ctm_ali = [
                AlignmentItem(start=0.0,
                              duration=to_s(first_speech_idx),
                              symbol=self.lexicon.phones[SIL])
            ]

            # Then, we iterate over the speech region: since the K2 model uses a
            # 2-state HMM topology that allows blank (<eps>) to follow a phone
            # symbol, we treat <eps> as a continuation of the "previous" phone.
            # TODO: I think this implementation is wrong in that it merges repeating phones...
            #       Will fix.
            # TODO: I think it could be simplified by using some smart semi-ring and FSA operations...
            start = first_speech_idx
            prev_s = speech_ali[0]
            curr_s = speech_ali[0]
            cntr = 1
            for s in speech_ali[1:]:
                curr_s = s if s != EPS else curr_s
                if curr_s != prev_s:
                    ctm_ali.append(
                        AlignmentItem(start=to_s(start),
                                      duration=to_s(cntr),
                                      symbol=self.lexicon.phones[prev_s]))
                    start = start + cntr
                    prev_s = curr_s
                    cntr = 1
                else:
                    cntr += 1
            if cntr:
                ctm_ali.append(
                    AlignmentItem(start=to_s(start),
                                  duration=to_s(cntr),
                                  symbol=self.lexicon.phones[prev_s]))

            speech_end_timestamp = to_s(last_speech_idx)
            if speech_end_timestamp > cut.duration:
                logging.warning(
                    f"speech_end_timestamp > cut.duration. Skipping cut {cut.id}"
                )
                ctm_alis.append(None)
                continue

            ctm_ali.append(
                AlignmentItem(start=speech_end_timestamp,
                              duration=round(cut.duration -
                                             speech_end_timestamp,
                                             ndigits=8),
                              symbol=self.lexicon.phones[SIL]))
            ctm_alis.append(ctm_ali)

        return ctm_alis

    def plot_alignments(self, cut: AnyCut):
        import matplotlib.pyplot as plt
        feats = self.compute_features(cut)
        phone_ids = self.align(cut)
        fig, axes = plt.subplots(2,
                                 squeeze=True,
                                 sharey=True,
                                 figsize=(10, 14))
        axes[0].imshow(np.flipud(feats[0].T))
        axes[1].imshow(
            torch.nn.functional.one_hot(
                phone_ids.repeat_interleave(4).to(torch.int64)).T)
        return fig, axes

    def plot_posteriors(self, cut: AnyCut):
        import matplotlib.pyplot as plt
        feats = self.compute_features(cut)
        posteriors = self.compute_posteriors(cut)
        fig, axes = plt.subplots(2,
                                 squeeze=True,
                                 sharey=True,
                                 figsize=(10, 14))
        axes[0].imshow(np.flipud(feats[0].T))
        axes[1].imshow(posteriors[0].exp().repeat_interleave(4, 1))
        return fig, axes

    @staticmethod
    def dummy_supervisions(feats):
        def size_after_conv(size, num_layers=2):
            for i in range(num_layers):
                size = (size - 1) // 2
            return size
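        # size_after_conv mirrors two stride-2 convolutions in the subsampling
        # frontend (an assumption about the model); each row built below is
        # (sequence_idx, start_frame, num_frames) counted in subsampled frames.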

        return torch.tensor([[
            i,
            size_after_conv(2, num_layers=2),
            size_after_conv(feats.shape[2] - 2, num_layers=2)
        ] for i in range(feats.size(0))],
                            dtype=torch.int32).clamp(min=0)

    @staticmethod
    def normalize_text(supervision):
        text = re.sub(r'[^\w\s]', '', supervision.text.upper())
        return fastcopy(supervision, text=text)
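
A hypothetical usage sketch for the ASR wrapper above. The directories and the cut manifest are assumptions and must point at real artifacts; the cuts need audio and text supervisions for decode() to work.

from pathlib import Path
from lhotse import CutSet

# Placeholder paths for an existing snowfall experiment.
asr = ASR(lang_dir=Path('data/lang_nosp'),
          model_dir=Path('exp-conformer-noam-mmi-att-musan-sa-vgg'),
          device='cpu')

cuts = CutSet.from_json('exp/data/cuts_dev-clean.json.gz').subset(first=2)
for ref_words, hyp_words in asr.decode(cuts):
    print('REF:', ' '.join(ref_words))
    print('HYP:', ' '.join(hyp_words))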
Example no. 6
def run(rank, world_size, args):
    '''
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate

    fix_random_seed(42)
    setup_dist(rank, world_size, args.master_port)

    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)
    P = P.to(device)

    mls = MLSAsrDataModule(args)
    train_dl = mls.train_dataloaders()
    valid_dl = mls.valid_dataloaders()

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "contextnet":
        model = ContextNet(num_features=80, num_classes=len(phone_ids) +
                           1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(
            num_features=80,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4)

        ali_model_fname = Path(
            f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
                f'ali model filename {ali_model_fname} does not exist!'
        ali_model.load_state_dict(
            torch.load(ali_model_fname, map_location='cpu')['state_dict'])
        ali_model.to(device)

        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler)
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            scaler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        scaler=scaler,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    torch.distributed.barrier()
    cleanup_dist()
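
The learning rate logged above is read back from the optimizer as `optimizer._rate`. As a reference, here is a minimal sketch of the standard Noam schedule, which the snowfall `Noam` optimizer is assumed to follow (parameter names mirror the constructor call above):

def noam_rate(step: int, model_size: int, factor: float, warm_step: int) -> float:
    # Linear warm-up for roughly `warm_step` steps, then ~1/sqrt(step) decay.
    return factor * model_size ** -0.5 * min(step ** -0.5,
                                             step * warm_step ** -1.5)

# Illustrative values only; the rate peaks around step == warm_step.
for step in (1000, 25000, 100000):
    print(step, noam_rate(step, model_size=512, factor=1.0, warm_step=25000))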
Example no. 7
def run(rank, world_size, args):
    '''
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate
    use_pruned_intersect = args.use_pruned_intersect

    fix_random_seed(42)
    if world_size > 1:
        setup_dist(rank, world_size, args.master_port)

    suffix = ''
    if args.context_window is not None and args.context_window > 0:
        suffix = f'ac{args.context_window}'
    giga_subset = f'giga{args.subset}'
    exp_dir = Path(
        f'exp-{model_type}-mmi-att-sa-vgg-normlayer-{giga_subset}-{suffix}')

    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    if not Path(lang_dir / f'P_{args.subset}.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P_{args.subset}.fst.txt')
        with open(lang_dir / f'P_{args.subset}.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / f'P_{args.subset}.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / f'P_{args.subset}.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()

    gigaspeech = GigaSpeechAsrDataModule(args)
    train_dl = gigaspeech.train_dataloaders()
    valid_dl = gigaspeech.valid_dataloaders()

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    if use_pruned_intersect:
        logging.info('Use pruned intersect for den_lats')
    else:
        logging.info("Don't use pruned intersect for den_lats")

    logging.info("About to create model")

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True,
            is_espnet_structure=True)
    elif model_type == "contextnet":
        model = ContextNet(num_features=80, num_classes=len(phone_ids) +
                           1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)

    model.to(device)
    describe(model)

    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(
            num_features=80,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4)

        ali_model_fname = Path(
            f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
                f'ali model filename {ali_model_fname} does not exist!'
        ali_model.load_state_dict(
            torch.load(ali_model_fname, map_location='cpu')['state_dict'])
        ali_model.to(device)

        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler)
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            scaler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank,
                            torchscript=args.torchscript_epoch != -1
                            and epoch >= args.torchscript_epoch)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        scaler=scaler,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank,
                        torchscript=args.torchscript_epoch != -1
                        and epoch >= args.torchscript_epoch)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
Example no. 8
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    subsampling_factor = model.module.subsampling_factor if isinstance(
        model, DDP) else model.subsampling_factor
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-tot_score).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if tb_writer is not None and global_batch_idx_train % 200 == 0:
            # Once in a while we will perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = -tot_score.detach().cpu().item(), tot_frames.cpu().item(
    ), all_frames.cpu().item()
    return ans
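
A toy, torch-only sketch of the supervision_segments construction at the top of get_objf; the supervision values are made up for illustration.

import torch

supervisions = {
    'sequence_idx': torch.tensor([0, 1, 2], dtype=torch.int32),
    'start_frame': torch.tensor([0, 0, 8], dtype=torch.int32),
    'num_frames': torch.tensor([100, 180, 60], dtype=torch.int32),
}
subsampling_factor = 4

supervision_segments = torch.stack(
    (supervisions['sequence_idx'],
     torch.floor_divide(supervisions['start_frame'], subsampling_factor),
     torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
    1).to(torch.int32)

# Sort by duration, longest first, before building k2.DenseFsaVec.
indices = torch.argsort(supervision_segments[:, 2], descending=True)
print(supervision_segments[indices])
# tensor([[ 1,  0, 45],
#         [ 0,  0, 25],
#         [ 2,  2, 15]], dtype=torch.int32)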
Example no. 9
def main():
    args = get_parser().parse_args()
    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank, world_size=args.world_size)
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = f'exp-lstm-adam-mmi-bigram-musan-dist'
    setup_logger('{}/log/log-train'.format(exp_dir),
                 use_console=args.local_rank == 0)
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.local_rank == 0 else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = MmiTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if not args.bucketing_sampler:
        # We don't combine cut concatenation with the bucketing sampler.
        # Here we insert concatenation before mixing so that the
        # noises from Musan are mixed onto almost-zero-energy
        # padding frames.
        transforms = [CutConcatenate()] + transforms
    train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)
    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_frames=40000,
                                         shuffle=True,
                                         num_buckets=30)
    else:
        logging.info('Using regular sampler with cut concatenation.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_frames=30000,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    # Note: we explicitly set world_size to 1 to disable the auto-detection of
    #       distributed training inside the sampler. This way, every GPU will
    #       perform the computation on the full dev set. It is a bit wasteful,
    #       but unfortunately loss aggregation between multiple processes with
    #       torch.distributed.all_reduce() tends to hang indefinitely inside
    #       NCCL after ~3000 steps. With the current approach, we can still report
    #       the loss on the full validation set.
    valid_sampler = SingleCutSampler(cuts_dev,
                                     max_frames=90000,
                                     world_size=1,
                                     rank=0)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = args.local_rank
    device = torch.device('cuda', device_id)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
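        # Concretely, the factor returned by this lambda is 1.0 for epochs 0-6,
        # then 0.8, 0.64, 0.512, ... from epoch 7 onwards.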
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    if args.world_size > 1:
        logging.info(
            'Using DistributedDataParallel in training. '
            'The reported loss, num_frames, etc. for training steps include '
            'only the batches seen in the master process (the actual loss '
            'includes batches from all GPUs, and the actual num_frames is '
            f'approx. {args.world_size}x larger).')
        # For now do not sync BatchNorm across GPUs due to NCCL hanging in all_gather...
        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank)

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)

        # The LR scheduler can hold multiple learning rates for multiple parameter
        # groups; for now we report just the first LR, which we assume concerns
        # most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            local_rank=args.local_rank,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        local_rank=args.local_rank,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
    cleanup_dist()
Example no. 10
def main():
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = f'exp-lstm-adam-mmi-bigram-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = MmiTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)
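    # Start the bigram phone LM from zero scores; a trainable copy is created
    # below as model.P_scores and learned jointly with the network.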
    P.scores = torch.zeros_like(P.scores)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=[
                                           CutConcatenate(),
                                           CutMix(cuts=cuts_musan,
                                                  prob=0.5,
                                                  snr=(10, 20))
                                       ])
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=12000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=12000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
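        # e.g. with lr=1e-3: epochs 0-6 keep 1e-3, epoch 7 uses 8e-4,
        # epoch 8 uses 6.4e-4, and so on (values shown only for illustration).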
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                             global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
Example no. 11
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None,
             den_scale: float = 1.0):
    feature = batch['features']
    supervisions = batch['supervisions']
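    # Build supervision segments as rows of (sequence_idx, start_frame,
    # num_frames), with frame indices divided by the subsampling factor so
    # that they index the subsampled nnet_output frames.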
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, output_beam=10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, output_beam=10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
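    # The MMI objective for the batch: total numerator (reference-constrained)
    # log-score minus the scaled denominator (competing hypotheses) log-score.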
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
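
For reference, a minimal sketch of how get_objf above might be driven from a validation loop; the helper name get_validation_objf and the per-frame normalization are illustrative assumptions, not part of the example:

def get_validation_objf(dataloader, model, P, device, graph_compiler):
    # Accumulate the (negated) total MMI score and the frame counts over the
    # whole validation set, then report the average objective per frame.
    total_objf = 0.0
    total_frames = 0.0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            objf, frames, _ = get_objf(batch, model, P, device,
                                       graph_compiler, is_training=False)
            total_objf += objf
            total_frames += frames
    return total_objf / total_frames, total_frames
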
Example no. 12
def main():
    fix_random_seed(42)

    exp_dir = 'exp-lstm-adam-mmi-bigram-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = MmiTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionIterableDataset(cuts_train,
                                               max_frames=30000,
                                               shuffle=True,
                                               aug_cuts=cuts_musan,
                                               aug_prob=0.5,
                                               aug_snr=(10, 20))
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionIterableDataset(cuts_dev,
                                                  max_frames=30000,
                                                  shuffle=False,
                                                  concat_cuts=False)
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           batch_size=None,
                                           num_workers=2)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    learning_rate = 1e-3
    start_epoch = 0
    num_epochs = 10
    best_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only
    global_batch_idx_valid = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        (epoch, learning_rate, objf) = load_checkpoint(filename=model_path,
                                                       model=model)
        best_objf = objf
        logging.info("epoch = {}, objf = {}".format(epoch, objf))

    model.to(device)
    describe(model)

    #  optimizer = optim.SGD(model.parameters(),
    #                       lr=learning_rate,
    #                       momentum=0.9,
    #                       weight_decay=5e-4)
    optimizer = optim.AdamW(
        model.parameters(),
        # lr=learning_rate,
        weight_decay=5e-4)

    curr_learning_rate = learning_rate
    for epoch in range(start_epoch, num_epochs):
        # curr_learning_rate = learning_rate * pow(0.4, epoch)
        if epoch > 6:
            curr_learning_rate *= 0.8
        for param_group in optimizer.param_groups:
            param_group['lr'] = curr_learning_rate

        tb_writer.add_scalar('learning_rate', curr_learning_rate, epoch)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf = train_one_epoch(dataloader=train_dl,
                               valid_dataloader=valid_dl,
                               model=model,
                               P=P,
                               device=device,
                               graph_compiler=graph_compiler,
                               optimizer=optimizer,
                               current_epoch=epoch,
                               tb_writer=tb_writer,
                               num_epochs=num_epochs,
                               global_batch_idx_train=global_batch_idx_train,
                               global_batch_idx_valid=global_batch_idx_valid)
        # the lower, the better
        if objf < best_objf:
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
Example no. 13
def main():
    args = get_parser().parse_args()
    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank,
               world_size=args.world_size,
               master_port=args.master_port)
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = 'exp-lstm-adam-mmi-bigram-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = args.local_rank
    device = torch.device('cuda', device_id)
    phone_ids = lexicon.phone_symbols()

    if not Path(lang_dir / 'P.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / 'P.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / 'P.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=[
                                           CutConcatenate(),
                                           CutMix(cuts=cuts_musan,
                                                  prob=0.5,
                                                  snr=(10, 20))
                                       ])
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=40000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=12000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                             global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
Example no. 14
def main():
    args = get_parser().parse_args()

    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    max_frames = args.max_frames
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate

    fix_random_seed(42)

    exp_dir = Path('exp-transformer-noam-mmi-att-musan')
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = MmiTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if not args.bucketing_sampler:
        # Cut concatenation is not combined with the bucketing sampler.
        # Concatenation is inserted before mixing so that the noises from
        # Musan are mixed onto the almost-zero-energy padding frames.
        transforms = [CutConcatenate()] + transforms
    train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)
    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_frames=max_frames,
                                         shuffle=True,
                                         num_buckets=args.num_buckets)
    else:
        logging.info('Using regular sampler with cut concatenation.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_frames=max_frames,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=max_frames)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0
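    # No decoder layers are needed when the attention branch is disabled
    # (att_rate == 0.0).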

    model = Transformer(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=4,
        num_decoder_layers=num_decoder_layers)

    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    optimizer = Noam(model.parameters(),
                     model_size=256,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                             global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
Example no. 15
def run(rank, world_size, args):
    '''
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate

    fix_random_seed(42)
    setup_dist(rank, world_size, args.master_port)

    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa')
    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)
    P = P.to(device)

    librispeech = LibriSpeechAsrDataModule(args)
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    model = DDP(model, device_ids=[rank])
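    # DistributedDataParallel keeps the per-rank replicas in sync by
    # all-reducing gradients during the backward pass.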

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    torch.distributed.barrier()
    # NOTE: The training process is very likely to hang at this point.
    # If you press ctrl + c, your GPU memory will not be freed.
    # To free your GPU memory, you can run:
    #
    #  $ ps aux | grep multi
    #
    # And it will print something like below:
    #
    # kuangfa+  430518 98.9  0.6 57074236 3425732 pts/21 Rl Apr02 639:01 /root/fangjun/py38/bin/python3 -c from multiprocessing.spawn
    #
    # You can kill the process manually by:
    #
    # $ kill -9 430518
    #
    # And you will see that your GPU is now not occupied anymore.
    cleanup_dist()